├── README.md
├── examples
│   ├── 0_pathwise_conditioning.ipynb
│   ├── 1_model_api.ipynb
│   ├── 2_minimizing_sample_paths.ipynb
│   ├── 3_training_via_sampling.ipynb
│   └── 4_autoencoding_mnist.ipynb
├── gpflow_sampling
│   ├── __init__.py
│   ├── bases
│   │   ├── __init__.py
│   │   ├── core.py
│   │   ├── dispatch.py
│   │   ├── fourier_bases.py
│   │   └── fourier_initializers.py
│   ├── covariances
│   │   ├── Kfus.py
│   │   ├── Kufs.py
│   │   ├── Kuus.py
│   │   └── __init__.py
│   ├── inducing_variables.py
│   ├── kernels.py
│   ├── models.py
│   ├── sampling
│   │   ├── __init__.py
│   │   ├── core.py
│   │   ├── decoupled_samplers.py
│   │   ├── priors
│   │   │   ├── __init__.py
│   │   │   └── fourier_priors.py
│   │   └── updates
│   │       ├── __init__.py
│   │       ├── cg_updates.py
│   │       ├── exact_updates.py
│   │       └── linear_updates.py
│   └── utils
│       ├── __init__.py
│       ├── array_ops.py
│       ├── conv_ops.py
│       ├── gpflow_ops.py
│       └── linalg.py
├── setup.py
└── tests
    ├── __init__.py
    ├── kernels
    │   ├── __init__.py
    │   └── test_conv2d.py
    └── sampling
        ├── __init__.py
        ├── priors
        │   ├── __init__.py
        │   ├── test_fourier_conv2d.py
        │   └── test_fourier_dense.py
        └── updates
            ├── __init__.py
            ├── common.py
            ├── test_cg.py
            ├── test_exact.py
            └── test_linear.py
/README.md:
--------------------------------------------------------------------------------
1 | # GPflowSampling
2 | Companion code for [Efficiently Sampling Functions from Gaussian Process Posteriors](https://arxiv.org/abs/2002.09309) and [Pathwise Conditioning of Gaussian processes](https://arxiv.org/abs/2011.04026).
3 | ## Overview
4 | The software provided here revolves around Matheron's update rule
5 | 
6 | $$(f \mid \boldsymbol{y})(\cdot) = f(\cdot) + k(\cdot, \boldsymbol{X})\,k(\boldsymbol{X}, \boldsymbol{X})^{-1}\big(\boldsymbol{y} - f(\boldsymbol{X})\big),$$
7 | 
8 | which allows us to represent a GP posterior as the sum of a prior random function and a data-driven update term. Thinking about conditioning at the level of random functions (rather than marginal distributions) enables us to accurately sample GP posteriors in linear time.
9 |
10 | Please see `examples` for tutorials and (hopefully) illustrative use cases.
11 |
12 | ## Installation
13 | ```
14 | git clone git@github.com:j-wilson/GPflowSampling.git
15 | cd GPflowSampling
16 | pip install -e .
17 | ```
18 | To install the dependencies needed to run `examples`, use `pip install -e .[examples]`.
19 |
20 |
21 | ## Citing Us
22 | If our work helps you in a way that you feel warrants reference, please cite the following paper:
23 | ```
24 | @inproceedings{wilson2020efficiently,
25 | title={Efficiently sampling functions from Gaussian process posteriors},
26 | author={James T. Wilson
27 | and Viacheslav Borovitskiy
28 | and Alexander Terenin
29 | and Peter Mostowsky
30 | and Marc Peter Deisenroth},
31 | booktitle={International Conference on Machine Learning},
32 | year={2020},
33 | url={https://arxiv.org/abs/2002.09309}
34 | }
35 | ```
36 |
--------------------------------------------------------------------------------
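A minimal usage sketch of the model API that the README above alludes to, based on `gpflow_sampling/models.py` further below (the toy data, kernel, and sample counts are illustrative assumptions; see `examples/1_model_api.ipynb` for the real walkthrough):

```python
import numpy as np
from gpflow.kernels import SquaredExponential
from gpflow_sampling.models import PathwiseGPR

# Toy 1d regression problem (purely illustrative)
X = np.random.rand(32, 1)
y = np.sin(6 * X) + 0.1 * np.random.randn(32, 1)
model = PathwiseGPR(data=(X, y), kernel=SquaredExponential())

# Generate pathwise posterior samples once, then evaluate them anywhere.
# Each path is a draw of an actual function, so repeated evaluations agree.
Xnew = np.linspace(0, 1, 128)[:, None]
with model.temporary_paths(num_samples=8, num_bases=1024) as paths:
  fnew = model.predict_f_samples(Xnew)  # one function value per path and input
```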
/gpflow_sampling/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from gpflow_sampling import utils, kernels, covariances, bases, sampling, models
5 |
--------------------------------------------------------------------------------
/gpflow_sampling/bases/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from gpflow_sampling.bases.dispatch import (
5 | kernel_basis as kernel,
6 | fourier_basis as fourier
7 | )
8 |
--------------------------------------------------------------------------------
/gpflow_sampling/bases/core.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | from abc import abstractmethod
9 | from typing import Union
10 | from tensorflow import Module
11 | from gpflow.base import TensorType
12 | from gpflow.kernels import Kernel, SharedIndependent, SeparateIndependent
13 | from gpflow.inducing_variables import InducingVariables
14 | from gpflow_sampling.utils import get_inducing_shape
15 | from gpflow_sampling.covariances import Kfu as Kfu_dispatch
16 |
17 | # ---- Exports
18 | __all__ = ('AbstractBasis', 'KernelBasis')
19 |
20 |
21 | # ==============================================
22 | # core
23 | # ==============================================
24 | class AbstractBasis(Module):
25 | def __init__(self, initialized: bool = False, name: str = None):
26 | super().__init__(name=name)
27 | self.initialized = initialized
28 |
29 | @abstractmethod
30 | def __call__(self, *args, **kwargs):
31 | raise NotImplementedError
32 |
33 | @abstractmethod
34 | def initialize(self, *args, **kwargs):
35 | pass
36 |
37 | def _maybe_initialize(self, *args, **kwargs):
38 | if not self.initialized:
39 | self.initialize(*args, **kwargs)
40 | self.initialized = True
41 |
42 | @property
43 | @abstractmethod
44 | def num_bases(self):
45 | raise NotImplementedError
46 |
47 |
48 | class KernelBasis(AbstractBasis):
49 | def __init__(self,
50 | kernel: Kernel,
51 | centers: Union[TensorType, InducingVariables],
52 | name: str = None,
53 | **default_kwargs):
54 |
55 | super().__init__(name=name)
56 | self.kernel = kernel
57 | self.centers = centers
58 | self.default_kwargs = default_kwargs
59 |
60 | def __call__(self, x, **kwargs):
61 | _kwargs = {**self.default_kwargs, **kwargs} # resolve keyword arguments
62 | self._maybe_initialize(x, **_kwargs)
63 | if isinstance(self.centers, InducingVariables):
64 | return Kfu_dispatch(self.centers, self.kernel, x, **_kwargs)
65 |
66 | if isinstance(self.kernel, (SharedIndependent, SeparateIndependent)):
67 | # TODO: Improve handling of "full_output_cov". Here, we're imitating
68 | # the behavior of gpflow.covariances.Kuf.
69 | _kwargs.setdefault('full_output_cov', False)
70 |
71 | return self.kernel.K(x, self.centers, **_kwargs)
72 |
73 | @property
74 | def num_bases(self):
75 | """
76 | TODO: Edge-cases?
77 | """
78 | if isinstance(self.centers, InducingVariables):
79 | return get_inducing_shape(self.centers)[-1]
80 | return self.centers.shape[-1]
81 |
--------------------------------------------------------------------------------
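A quick sketch of `KernelBasis` in isolation (sizes are arbitrary): for a plain tensor of centers, the basis simply evaluates the kernel against them, i.e. phi(x) = k(x, centers).

```python
import numpy as np
from gpflow.kernels import SquaredExponential
from gpflow_sampling.bases.core import KernelBasis

kernel = SquaredExponential()
centers = np.random.rand(12, 2)       # e.g. training inputs, 12 points in 2d
basis = KernelBasis(kernel, centers)  # canonical basis: phi(x) = k(x, centers)

Xnew = np.random.rand(30, 2)
print(basis(Xnew).shape)  # (30, 12)
```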
/gpflow_sampling/bases/dispatch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | from typing import Union
9 | from gpflow import kernels as gpflow_kernels
10 | from gpflow.base import TensorType
11 | from gpflow.utilities import Dispatcher
12 | from gpflow.inducing_variables import InducingVariables
13 | from gpflow_sampling import kernels
14 | from gpflow_sampling.bases import fourier_bases
15 | from gpflow_sampling.bases.core import KernelBasis
16 |
17 |
18 | # ---- Exports
19 | __all__ = (
20 | 'kernel_basis',
21 | 'fourier_basis',
22 | )
23 |
24 | kernel_basis = Dispatcher("kernel_basis")
25 | fourier_basis = Dispatcher("fourier_basis")
26 |
27 |
28 | # ==============================================
29 | # dispatch
30 | # ==============================================
31 | @kernel_basis.register(gpflow_kernels.Kernel)
32 | def _kernel_fallback(kern: gpflow_kernels.Kernel,
33 | centers: Union[TensorType, InducingVariables],
34 | **kwargs):
35 | return KernelBasis(kernel=kern, centers=centers, **kwargs)
36 |
37 |
38 | @fourier_basis.register(gpflow_kernels.Stationary)
39 | def _fourier_stationary(kern: gpflow_kernels.Stationary, **kwargs):
40 | return fourier_bases.Dense(kernel=kern, **kwargs)
41 |
42 |
43 | @fourier_basis.register(gpflow_kernels.MultioutputKernel)
44 | def _fourier_multioutput(kern: gpflow_kernels.MultioutputKernel, **kwargs):
45 | return fourier_bases.MultioutputDense(kernel=kern, **kwargs)
46 |
47 |
48 | @fourier_basis.register(kernels.Conv2d)
49 | def _fourier_conv2d(kern: kernels.Conv2d, **kwargs):
50 | return fourier_bases.Conv2d(kernel=kern, **kwargs)
51 |
52 |
53 | @fourier_basis.register(kernels.Conv2dTranspose)
54 | def _fourier_conv2d_transposed(kern: kernels.Conv2dTranspose, **kwargs):
55 | return fourier_bases.Conv2dTranspose(kernel=kern, **kwargs)
56 |
57 |
58 | @fourier_basis.register(kernels.DepthwiseConv2d)
59 | def _fourier_depthwise_conv2d(kern: kernels.DepthwiseConv2d, **kwargs):
60 | return fourier_bases.DepthwiseConv2d(kernel=kern, **kwargs)
61 |
--------------------------------------------------------------------------------
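A sketch of how the dispatchers above resolve on kernel type (the kernel and sizes are arbitrary): a stationary kernel routes `fourier_basis` to `fourier_bases.Dense`, whose weights and biases are lazily initialized on the first call.

```python
import numpy as np
from gpflow.kernels import Matern52
from gpflow_sampling import bases  # exports kernel_basis/fourier_basis as kernel/fourier

kernel = Matern52(lengthscales=0.5)
basis = bases.fourier(kernel, num_bases=256)  # dispatches to fourier_bases.Dense

X = np.random.rand(10, 3)  # 10 points in 3d
print(basis(X).shape)      # (10, 256)
```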
/gpflow_sampling/bases/fourier_bases.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from typing import Any
11 | from gpflow import kernels as gpflow_kernels
12 | from gpflow.base import TensorType
13 | from gpflow.inducing_variables import InducingVariables
14 | from gpflow_sampling import kernels, inducing_variables
15 | from gpflow_sampling.bases.core import AbstractBasis
16 | from gpflow_sampling.utils import (move_axis,
17 | expand_to,
18 | batch_tensordot,
19 | inducing_to_tensor)
20 | from gpflow_sampling.bases.fourier_initializers import (bias_initializer,
21 | weight_initializer)
22 |
23 |
24 | # ---- Exports
25 | __all__ = (
26 | 'Dense',
27 | 'MultioutputDense',
28 | 'Conv2d',
29 | 'Conv2dTranspose',
30 | 'DepthwiseConv2d',
31 | )
32 |
33 |
34 | # ==============================================
35 | # fourier_bases
36 | # ==============================================
37 | class AbstractFourierBasis(AbstractBasis):
38 | def __init__(self,
39 | kernel: gpflow_kernels.Kernel,
40 | num_bases: int,
41 | initialized: bool = False,
42 | name: Any = None):
43 | super().__init__(initialized=initialized, name=name)
44 | self.kernel = kernel
45 | self._num_bases = num_bases
46 |
47 | @property
48 | def num_bases(self):
49 | return self._num_bases
50 |
51 |
52 | class Dense(AbstractFourierBasis):
53 | def __init__(self,
54 | kernel: gpflow_kernels.Stationary,
55 | num_bases: int,
56 | weights: tf.Tensor = None,
57 | biases: tf.Tensor = None,
58 | name: str = None):
59 | super().__init__(name=name,
60 | kernel=kernel,
61 | num_bases=num_bases)
62 | self._weights = weights
63 | self._biases = biases
64 |
65 | def __call__(self, x: TensorType, **kwargs) -> tf.Tensor:
66 | self._maybe_initialize(x, **kwargs)
67 | if isinstance(x, InducingVariables): # TODO: Allow this behavior?
68 | x = inducing_to_tensor(x)
69 |
70 | proj = tf.tensordot(x, self.weights, axes=[-1, -1]) # [..., B]
71 | feat = tf.cos(proj + self.biases)
72 | return self.output_scale * feat
73 |
74 | def initialize(self, x: TensorType, dtype: Any = None):
75 | if isinstance(x, InducingVariables):
76 | x = inducing_to_tensor(x)
77 |
78 | if dtype is None:
79 | dtype = x.dtype
80 |
81 | self._biases = bias_initializer(self.kernel, self.num_bases, dtype=dtype)
82 | self._weights = weight_initializer(self.kernel, x.shape[-1],
83 | batch_shape=[self.num_bases],
84 | dtype=dtype)
85 |
86 | @property
87 | def weights(self):
88 | if self._weights is None:
89 | return None
90 | return tf.math.reciprocal(self.kernel.lengthscales) * self._weights
91 |
92 | @property
93 | def biases(self):
94 | return self._biases
95 |
96 | @property
97 | def output_scale(self):
98 | return tf.sqrt(2 * self.kernel.variance / self.num_bases)
99 |
100 |
101 | class MultioutputDense(Dense):
102 | def __call__(self, x: TensorType, multioutput_axis: int = None, **kwargs):
103 | self._maybe_initialize(x, **kwargs)
104 | if isinstance(x, InducingVariables): # TODO: Allow this behavior?
105 | x = inducing_to_tensor(x)
106 |
107 | # Compute (batch) tensor dot product
108 | batch_axes = None if (multioutput_axis is None) else [0, multioutput_axis]
109 | proj = move_axis(batch_tensordot(self.weights,
110 | x,
111 | axes=[-1, -1],
112 | batch_axes=batch_axes), 1, -1)
113 |
114 | ndims = proj.shape.ndims
115 | feat = tf.cos(proj + expand_to(self.biases, axis=1, ndims=ndims))
116 | return expand_to(self.output_scale, axis=1, ndims=ndims) * feat # [L, N, B]
117 |
118 | def initialize(self, x: TensorType, dtype: Any = None):
119 | if isinstance(x, InducingVariables):
120 | x = inducing_to_tensor(x)
121 |
122 | if dtype is None:
123 | dtype = x.dtype
124 |
125 | biases = []
126 | weights = []
127 | for kernel in self.kernel.latent_kernels:
128 | biases.append(
129 | bias_initializer(kernel, self.num_bases, dtype=dtype))
130 |
131 | weights.append(
132 | weight_initializer(kernel, x.shape[-1],
133 | batch_shape=[self.num_bases],
134 | dtype=dtype))
135 |
136 | self._biases = tf.stack(biases, axis=0) # [L, B]
137 | self._weights = tf.stack(weights, axis=0) # [L, B, D]
138 |
139 | @property
140 | def weights(self):
141 | if self._weights is None:
142 | return None
143 |
144 | num_lengthscales = None
145 | for kernel in self.kernel.latent_kernels:
146 | if kernel.ard:
147 | ls = kernel.lengthscales
148 | assert ls.shape.ndims == 1
149 | if num_lengthscales is None:
150 | num_lengthscales = ls.shape[0]
151 | else:
152 | assert num_lengthscales == ls.shape[0]
153 |
154 | inv_lengthscales = []
155 | for kernel in self.kernel.latent_kernels:
156 | inv_ls = tf.math.reciprocal(kernel.lengthscales)
157 | if not kernel.ard and num_lengthscales is not None:
158 | inv_ls = tf.fill([num_lengthscales], inv_ls)
159 | inv_lengthscales.append(inv_ls)
160 |
161 | # [L, 1, D] or [L, 1, 1]
162 | inv_lengthscales = expand_to(arr=tf.stack(inv_lengthscales),
163 | axis=1,
164 | ndims=self._weights.shape.ndims)
165 |
166 | return inv_lengthscales * self._weights
167 |
168 | @property
169 | def output_scale(self):
170 | variances = tf.stack([k.variance for k in self.kernel.latent_kernels])
171 | return tf.sqrt(2 * variances / self.num_bases) # [L]
172 |
173 |
174 | class Conv2d(AbstractFourierBasis):
175 | def __init__(self,
176 | kernel: kernels.Conv2d,
177 | num_bases: int,
178 | filters: tf.Tensor = None,
179 | biases: tf.Tensor = None,
180 | name: str = None):
181 |
182 | super().__init__(name=name,
183 | kernel=kernel,
184 | num_bases=num_bases)
185 | self._filters = filters
186 | self._biases = biases
187 |
188 | def __call__(self, x: TensorType) -> tf.Tensor:
189 | self._maybe_initialize(x)
190 | if isinstance(x, InducingVariables) or len(x.shape) == 4:
191 | conv = self.convolve(x)
192 | elif len(x.shape) > 4: # allow for higher order batches
193 | x_4d = tf.reshape(x, [-1] + list(x.shape[-3:]))
194 | conv = self.convolve(x_4d)
195 | conv = tf.reshape(conv, list(x.shape[:-3]) + list(conv.shape[1:]))
196 | else:
197 | raise NotImplementedError
198 | return self.output_scale * tf.cos(conv + self.biases)
199 |
200 | def convolve(self, x: TensorType) -> tf.Tensor:
201 | if isinstance(x, inducing_variables.InducingImages):
202 | return tf.nn.conv2d(input=x.as_images,
203 | filters=self.filters,
204 | strides=(1, 1, 1, 1),
205 | padding="VALID")
206 | return self.kernel.convolve(input=x, filters=self.filters)
207 |
208 | def initialize(self, x, dtype: Any = None):
209 | if isinstance(x, inducing_variables.InducingImages):
210 | x = x.as_images
211 |
212 | if dtype is None:
213 | dtype = x.dtype
214 |
215 | self._biases = bias_initializer(self.kernel.kernel,
216 | self.num_bases,
217 | dtype=dtype)
218 |
219 | patch_size = (self.kernel.channels_in
220 | * self.kernel.patch_shape[0]
221 | * self.kernel.patch_shape[1])
222 |
223 | weights = weight_initializer(self.kernel.kernel, patch_size,
224 | batch_shape=[self.num_bases],
225 | dtype=dtype)
226 |
227 | shape = self.kernel.patch_shape + [self.kernel.channels_in, self.num_bases]
228 | self._filters = tf.reshape(move_axis(weights, -1, 0), shape)
229 |
230 | @property
231 | def filters(self):
232 | if self._filters is None:
233 | return None
234 |
235 | shape = list(self.kernel.patch_shape) + [self.kernel.channels_in, 1]
236 | inv_ls = tf.math.reciprocal(self.kernel.kernel.lengthscales)
237 | if self.kernel.kernel.ard:
238 | coeffs = tf.reshape(inv_ls, shape)
239 | else:
240 | coeffs = tf.fill(shape, inv_ls)
241 |
242 | return coeffs * self._filters
243 |
244 | @property
245 | def biases(self):
246 | return self._biases
247 |
248 | @property
249 | def output_scale(self):
250 | return tf.sqrt(2 * self.kernel.kernel.variance / self.num_bases)
251 |
252 |
253 | class Conv2dTranspose(Conv2d):
254 | pass
255 |
256 |
257 | class DepthwiseConv2d(Conv2d):
258 | def convolve(self, x: TensorType) -> tf.Tensor:
259 | if isinstance(x, inducing_variables.DepthwiseInducingImages):
260 | return tf.nn.depthwise_conv2d(input=x.as_images,
261 | filter=self.filters,
262 | strides=(1, 1, 1, 1),
263 | padding="VALID")
264 |
265 | return self.kernel.convolve(input=x, filters=self.filters)
266 |
267 | def initialize(self, x, dtype: Any = None):
268 | if isinstance(x, inducing_variables.InducingImages):
269 | x = x.as_images
270 |
271 | if dtype is None:
272 | dtype = x.dtype
273 |
274 | channels_out = self.kernel.channels_in * self.num_bases
275 | self._biases = bias_initializer(self.kernel.kernel,
276 | channels_out,
277 | dtype=dtype)
278 |
279 | patch_size = self.kernel.patch_shape[0] * self.kernel.patch_shape[1]
280 | batch_shape = [self.kernel.channels_in, self.num_bases]
281 | weights = weight_initializer(self.kernel.kernel, patch_size,
282 | batch_shape=batch_shape,
283 | dtype=dtype)
284 |
285 | self._filters = tf.reshape(move_axis(weights, -1, 0),
286 | self.kernel.patch_shape + batch_shape)
287 |
288 | @property
289 | def filters(self):
290 | if self._filters is None:
291 | return None
292 |
293 | shape = list(self.kernel.patch_shape) + [self.kernel.channels_in, 1]
294 | inv_ls = tf.math.reciprocal(self.kernel.kernel.lengthscales)
295 | if self.kernel.kernel.ard:
296 | coeffs = tf.reshape(tf.transpose(inv_ls), shape)
297 | else:
298 | coeffs = tf.fill(shape, inv_ls)
299 |
300 | return coeffs * self._filters
301 |
302 | @property
303 | def output_scale(self):
304 | num_features_out = self.num_bases * self.kernel.channels_in
305 | return tf.sqrt(2 * self.kernel.kernel.variance / num_features_out)
306 |
--------------------------------------------------------------------------------
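The `Dense` basis above implements random Fourier features: scaled cosine features whose inner products approximate the kernel in expectation. A quick sanity check along these lines (sizes are arbitrary; the Monte Carlo error shrinks like `num_bases ** -0.5`):

```python
import numpy as np
import tensorflow as tf
from gpflow.kernels import SquaredExponential
from gpflow_sampling.bases.fourier_bases import Dense

kernel = SquaredExponential()
basis = Dense(kernel, num_bases=10000)

X = np.random.rand(5, 2)
feats = basis(X)  # [5, 10000]
K_rff = tf.matmul(feats, feats, transpose_b=True)  # feature-space approximation
K_true = kernel.K(X)
print(float(tf.reduce_max(tf.abs(K_rff - K_true))))  # small, typically ~1e-2
```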
/gpflow_sampling/bases/fourier_initializers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import numpy as np
9 | import tensorflow as tf
10 |
11 | from typing import Any, List
12 | from gpflow import kernels
13 | from gpflow.config import default_float
14 | from gpflow.utilities import Dispatcher
15 |
16 | # ---- Exports
17 | __all__ = ('bias_initializer', 'weight_initializer')
18 |
19 |
20 | # ==============================================
21 | # initializers
22 | # ==============================================
23 | MaternKernel = kernels.Matern12, kernels.Matern32, kernels.Matern52
24 | bias_initializer = Dispatcher("bias_initializer")
25 | weight_initializer = Dispatcher("weight_initializer")
26 |
27 |
28 | @bias_initializer.register(kernels.Stationary, int)
29 | def _bias_initializer_fallback(kern: kernels.Stationary,
30 | ndims: int,
31 | *,
32 | batch_shape: List = None,
33 | dtype: Any = None,
34 | maxval: float = 2 * np.pi) -> tf.Tensor:
35 | if dtype is None:
36 | dtype = default_float()
37 |
38 | shape = [ndims] if batch_shape is None else list(batch_shape) + [ndims]
39 | return tf.random.uniform(shape=shape, maxval=maxval, dtype=dtype)
40 |
41 |
42 | @weight_initializer.register(kernels.SquaredExponential, int)
43 | def _weight_initializer_squaredExp(kern: kernels.SquaredExponential,
44 | ndims: int,
45 | *,
46 | batch_shape: List = None,
47 | dtype: Any = None,
48 | normal_rvs: tf.Tensor = None) -> tf.Tensor:
49 | if dtype is None:
50 | dtype = default_float()
51 |
52 |   if batch_shape is None:
53 |     batch_shape = []
54 | 
55 |   shape = list(batch_shape) + [ndims]
56 | if normal_rvs is None:
57 | return tf.random.normal(shape, dtype=dtype)
58 |
59 | assert tuple(normal_rvs.shape) == tuple(shape)
60 | return tf.convert_to_tensor(normal_rvs, dtype=dtype)
61 |
62 |
63 | @weight_initializer.register(MaternKernel, int)
64 | def _weight_initializer_matern(kern: MaternKernel,
65 | ndims: int,
66 | *,
67 | batch_shape: List = None,
68 | dtype: Any = None,
69 | normal_rvs: tf.Tensor = None,
70 | gamma_rvs: tf.Tensor = None) -> tf.Tensor:
71 | if dtype is None:
72 | dtype = default_float()
73 |
74 | if isinstance(kern, kernels.Matern12):
75 | smoothness = 1/2
76 | elif isinstance(kern, kernels.Matern32):
77 | smoothness = 3/2
78 | elif isinstance(kern, kernels.Matern52):
79 | smoothness = 5/2
80 | else:
81 | raise NotImplementedError
82 |
83 | batch_shape = [] if batch_shape is None else list(batch_shape)
84 | if normal_rvs is None:
85 | normal_rvs = tf.random.normal(shape=batch_shape + [ndims], dtype=dtype)
86 | else:
87 | assert tuple(normal_rvs.shape) == tuple(batch_shape + [ndims])
88 | normal_rvs = tf.convert_to_tensor(normal_rvs, dtype=dtype)
89 |
90 | if gamma_rvs is None:
91 | gamma_rvs = tf.random.gamma(shape=batch_shape + [1],
92 | alpha=smoothness,
93 | beta=smoothness,
94 | dtype=dtype)
95 | else:
96 | assert tuple(gamma_rvs.shape) == tuple(batch_shape + [1])
97 | gamma_rvs = tf.convert_to_tensor(gamma_rvs, dtype=dtype)
98 |
99 | # Return draws from a multivariate-t distribution
100 | return tf.math.rsqrt(gamma_rvs) * normal_rvs
101 |
--------------------------------------------------------------------------------
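The Matern initializer above draws spectral frequencies from a multivariate-t distribution by dividing Gaussian draws by the square root of gamma variates; the squared-exponential case reduces to plain Gaussian draws. Calling the initializers directly looks like this (shapes are illustrative):

```python
from gpflow.kernels import Matern32
from gpflow_sampling.bases.fourier_initializers import (bias_initializer,
                                                        weight_initializer)

kernel = Matern32()
b = bias_initializer(kernel, 512)  # random phases, uniform on [0, 2*pi)
w = weight_initializer(kernel, 3, batch_shape=[512])  # frequencies for 3d inputs
print(b.shape, w.shape)  # (512,) and (512, 3)
```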
/gpflow_sampling/covariances/Kfus.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from gpflow import kernels
11 | from gpflow.base import TensorLike
12 | from gpflow.utilities import Dispatcher
13 | from gpflow.inducing_variables import (InducingVariables,
14 | SharedIndependentInducingVariables)
15 | from gpflow.covariances.dispatch import Kuf as Kuf_dispatch
16 | from gpflow_sampling.kernels import Conv2d, DepthwiseConv2d
17 | from gpflow_sampling.utils import move_axis, get_inducing_shape
18 | from gpflow_sampling.inducing_variables import (InducingImages,
19 | DepthwiseInducingImages)
20 |
21 |
22 | # ==============================================
23 | # Kfus
24 | # ==============================================
25 | Kfu = Dispatcher("Kfu")
26 |
27 |
28 | @Kfu.register(InducingVariables, kernels.Kernel, TensorLike)
29 | def _Kfu_fallback(Z, kern, X, **kwargs):
30 | Kuf = Kuf_dispatch(Z, kern, X, **kwargs)
31 |
32 | # Assume features of x and z are 1-dimensional
33 | ndims_x = X.shape.ndims - 1 # assume x lives in 1d space
34 | ndims_z = len(get_inducing_shape(Z)) - 1
35 | assert ndims_x + ndims_z == Kuf.shape.ndims
36 |
37 | # Swap the batch axes of x and z
38 | axes = list(range(ndims_x + ndims_z))
39 | perm = axes[ndims_z: ndims_z + ndims_x] + axes[:ndims_z]
40 | return tf.transpose(Kuf, perm)
41 |
42 |
43 | @Kfu.register(InducingVariables, kernels.MultioutputKernel, TensorLike)
44 | def _Kfu_fallback_multioutput(Z, kern, X, **kwargs):
45 | Kuf = Kuf_dispatch(Z, kern, X, **kwargs)
46 |
47 | # Assume features of x and z are 1-dimensional
48 | ndims_x = X.shape.ndims - 1 # assume x lives in 1d space
49 | ndims_z = 1 # shared Z live in 1d space, separate Z are 2d but 1-to-1 with L
50 | assert ndims_x + ndims_z == Kuf.shape.ndims - 1
51 |
52 | # Swap the batch axes of x and z
53 | axes = list(range(1, ndims_x + ndims_z + 1)) # keep L output-features first
54 | perm = [0] + axes[ndims_z: ndims_z + ndims_x] + axes[:ndims_z]
55 | return tf.transpose(Kuf, perm)
56 |
57 |
58 | @Kfu.register(SharedIndependentInducingVariables,
59 | kernels.SharedIndependent,
60 | TensorLike)
61 | def _Kfu_fallback_shared(Z, kern, X, **kwargs):
63 |   return _Kfu_fallback(Z, kern, X, **kwargs)  # Edge-case where L is suppressed
63 |
64 |
65 | def _Kfu_conv2d_fallback(feat: InducingImages,
66 | kern: Conv2d,
67 | Xnew: tf.Tensor,
68 | full_spatial: bool = False):
69 |
70 | Zp = feat.as_patches # [M, patch_len]
71 | Xp = kern.get_patches(Xnew, full_spatial=full_spatial)
72 | Kxz = kern.kernel.K(Xp, Zp) # [N, H * W, M] or [N, H, W, M]
73 | if full_spatial: # convert to 4D image format
74 | spatial_out = kern.get_spatial_out(Xnew.shape[-3:-1]) # to [N, H, W, M]
75 | return tf.reshape(Kxz, list(Kxz.shape[:-2]) + spatial_out + [Kxz.shape[-1]])
76 |
77 | if kern.weights is None:
78 | return tf.reduce_mean(Kxz, axis=-2)
79 |
80 | return tf.tensordot(Kxz, kern.weights, [-2, -1])
81 |
82 |
83 | @Kfu.register(InducingImages, Conv2d, object)
84 | def _Kfu_conv2d(feat: InducingImages,
85 | kern: Conv2d,
86 | Xnew: tf.Tensor,
87 | full_spatial: bool = False):
88 |
89 | if not isinstance(kern.kernel, kernels.Stationary):
90 | return _Kfu_conv2d_fallback(feat, kern, Xnew, full_spatial)
91 |
92 | # Compute (squared) Mahalanobis distances between patches
93 | patch_shape = list(kern.patch_shape)
94 | channels_in = Xnew.shape[-3 if kern.data_format == "NCHW" else -1]
95 | precis = tf.square(tf.math.reciprocal(kern.kernel.lengthscales))
96 |
97 | # Construct lengthscale filters [h, w, channels_in, 1]
98 | if kern.kernel.ard:
99 | filters = tf.reshape(precis, patch_shape + [channels_in, 1])
100 | else:
101 | filters = tf.fill(patch_shape + [channels_in, 1], precis)
102 |
103 | r2 = tf.transpose(tf.nn.conv2d(input=tf.square(feat.as_images),
104 | filters=filters,
105 | strides=[1, 1],
106 | padding="VALID"))
107 |
108 | X = tf.reshape(Xnew, [-1] + list(Xnew.shape)[-3:]) # stack as 4d images
109 | r2 += kern.convolve(tf.square(X), filters) # [N, height_out, width_out, M]
110 |
111 | filters *= feat.as_filters # [h, w, channels_in, M]
112 | r2 -= 2 * kern.convolve(X, filters)
113 |
114 | Kxz = kern.kernel.K_r2(r2)
115 | if not full_spatial:
116 | Kxz = tf.reshape(Kxz, list(Kxz.shape[:-3]) + [-1, len(feat)]) # [N, P, M]
117 | if kern.weights is None:
118 | Kxz = tf.reduce_mean(Kxz, axis=-2)
119 | else:
120 | Kxz = tf.tensordot(Kxz, kern.weights, [-2, -1])
121 |
122 | # Undo stacking of Xnew as 4d images X
123 | return tf.reshape(Kxz, list(Xnew.shape[:-3]) + list(Kxz.shape[1:]))
124 |
125 |
126 | def _Kfu_depthwise_conv2d_fallback(feat: DepthwiseInducingImages,
127 | kern: DepthwiseConv2d,
128 | Xnew: tf.Tensor,
129 | full_spatial: bool = False):
130 |
131 | Zp = feat.as_patches # [M, channels_in, patch_len]
132 | Xp = kern.get_patches(Xnew, full_spatial=full_spatial)
133 | r2 = tf.reduce_sum(tf.math.squared_difference( # compute square distances
134 | tf.expand_dims(kern.kernel.scale(Xp), -Zp.shape.ndims),
135 | kern.kernel.scale(Zp)), axis=-1)
136 |
137 | Kxz = kern.kernel.K_r2(r2)
138 | if full_spatial: # convert to 4D image format as [N, H, W, channels_in * M]
139 | return tf.reshape(move_axis(Kxz, -1, -2), list(Kxz.shape[:-2]) + [-1])
140 |
141 | if kern.weights is None: # reduce over channels and patches
142 | return tf.reduce_mean(Kxz, axis=[-3, -1])
143 |
144 | return tf.tensordot(kern.weights, Kxz, axes=[(0, 1), (-3, -1)])
145 |
146 |
147 | @Kfu.register(DepthwiseInducingImages, DepthwiseConv2d, object)
148 | def _Kfu_depthwise_conv2d(feat: DepthwiseInducingImages,
149 | kern: DepthwiseConv2d,
150 | Xnew: tf.Tensor,
151 | full_spatial: bool = False):
152 |
153 | if not isinstance(kern.kernel, kernels.Stationary):
154 | return _Kfu_depthwise_conv2d_fallback(feat, kern, Xnew, full_spatial)
155 |
156 | # Compute (squared) Mahalanobis distances between patches
157 | patch_shape = list(kern.patch_shape)
158 | channels_in = Xnew.shape[-3 if kern.data_format == "NCHW" else -1]
159 | channels_out = len(feat) * channels_in
160 | precis = tf.square(tf.math.reciprocal(kern.kernel.lengthscales))
161 |
162 | # Construct lengthscale filters [h, w, channels_in, 1]
163 | if kern.kernel.ard: # notice the transpose!
164 | assert tuple(precis.shape) == (channels_in, tf.reduce_prod(patch_shape))
165 | filters = tf.reshape(tf.transpose(precis), patch_shape + [channels_in, 1])
166 | else:
167 | filters = tf.fill(patch_shape + [channels_in, 1], precis)
168 |
169 | ZZ = tf.nn.depthwise_conv2d(input=tf.square(feat.as_images),
170 | filter=filters,
171 | strides=[1, 1, 1, 1],
172 | padding="VALID") # [M, 1, 1, channels_in]
173 |
174 | r2 = tf.reshape(move_axis(ZZ, 0, -1), [1, 1, 1, channels_out])
175 |
176 | X = tf.reshape(Xnew, [-1] + list(Xnew.shape)[-3:]) # stack as 4d images
177 | r2 += tf.repeat(kern.convolve(tf.square(X), filters), len(feat), axis=-1)
178 |
179 | filters *= feat.as_filters # [h, w, channels_in, M]
180 | r2 -= 2 * kern.convolve(X, filters) # [N, height_out, width_out, chan_out]
181 |
182 | Kxz = kern.kernel.K_r2(r2)
183 | if full_spatial:
184 | Kxz = tf.reduce_mean(
185 | tf.reshape(Kxz, list(Kxz.shape[:-1]) + [channels_in, -1]),
186 | axis=-2) # average over input channels
187 | else:
188 | Kxz = tf.reshape(Kxz, list(Kxz.shape[:-3]) + [-1, len(feat)]) # [N, P, M]
189 | if kern.weights is None:
190 | Kxz = tf.reduce_mean(Kxz, axis=-2)
191 | else:
192 | div = tf.cast(1/channels_in, Kxz.dtype)
193 | Kxz = div * tf.tensordot(Kxz, tf.reshape(kern.weights, [-1]), [-2, -1])
194 |
195 | # Undo stacking of Xnew as 4d images X
196 | return tf.reshape(Kxz, list(Xnew.shape[:-3]) + list(Kxz.shape[1:]))
197 |
--------------------------------------------------------------------------------
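For the generic fallbacks above, `Kfu` is just `Kuf` with the batch axes of X and Z swapped. A minimal check with standard GPflow inducing points (sizes are arbitrary):

```python
import tensorflow as tf
from gpflow.kernels import SquaredExponential
from gpflow.inducing_variables import InducingPoints
from gpflow.covariances import Kuf
from gpflow_sampling.covariances import Kfu

kernel = SquaredExponential()
Z = InducingPoints(tf.random.uniform([7, 2], dtype=tf.float64))  # M = 7
X = tf.random.uniform([11, 2], dtype=tf.float64)                 # N = 11

print(Kuf(Z, kernel, X).shape)  # (7, 11)
print(Kfu(Z, kernel, X).shape)  # (11, 7)
```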
/gpflow_sampling/covariances/Kufs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from gpflow.base import TensorLike
11 | from gpflow.covariances.dispatch import Kuf
12 | from gpflow_sampling.utils import swap_axes
13 | from gpflow_sampling.kernels import Conv2d
14 | from gpflow_sampling.covariances.Kfus import Kfu as Kfu_dispatch
15 | from gpflow_sampling.inducing_variables import InducingImages
16 |
17 |
18 | # ==============================================
19 | # Kufs
20 | # ==============================================
21 | @Kuf.register(InducingImages, Conv2d, TensorLike)
22 | def _Kuf_conv2d_fallback(Z, kernel, X, full_spatial: bool = False, **kwargs):
23 | Kfu = Kfu_dispatch(Z, kernel, X, full_spatial=full_spatial, **kwargs)
24 |
25 | ndims_x = X.shape.ndims - 3 # assume x lives in 3d image space
26 | ndims_z = Z.as_images.shape.ndims - 3
27 |
28 | if full_spatial:
29 | assert Kfu.shape.ndims == ndims_x + ndims_z + 2
30 | return swap_axes(Kfu, -4, -1) # TODO: this is a hack
31 |
32 | # Swap the batch axes of x and z
33 | assert Kfu.shape.ndims == ndims_x + ndims_z
34 | axes = list(range(ndims_x + ndims_z))
35 | perm = axes[ndims_x: ndims_x + ndims_z] + axes[:ndims_x]
36 | return tf.transpose(Kfu, perm)
37 |
--------------------------------------------------------------------------------
/gpflow_sampling/covariances/Kuus.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from gpflow.utilities.ops import square_distance
11 | from gpflow.covariances.dispatch import Kuu
12 | from gpflow_sampling.utils import move_axis
13 | from gpflow_sampling.kernels import Conv2d, DepthwiseConv2d
14 | from gpflow_sampling.inducing_variables import (InducingImages,
15 | DepthwiseInducingImages)
16 |
17 | # ==============================================
18 | # Kuus
19 | # ==============================================
20 | @Kuu.register(InducingImages, Conv2d)
21 | def _Kuu_conv2d(feat: InducingImages,
22 | kern: Conv2d,
23 | jitter: float = 0.0):
24 | _Kuu = kern.kernel.K(feat.as_patches)
25 | return tf.linalg.set_diag(_Kuu, tf.linalg.diag_part(_Kuu) + jitter)
26 |
27 |
28 | @Kuu.register(DepthwiseInducingImages, DepthwiseConv2d)
29 | def _Kuu_depthwise_conv2d(feat: DepthwiseInducingImages,
30 | kern: DepthwiseConv2d,
31 | jitter: float = 0.0):
32 |
33 | # Prepare scaled inducing patches; shape(Zp) = [channels_in, M, patch_len]
34 | Zp = move_axis(kern.kernel.scale(feat.as_patches), -2, 0)
35 | r2 = square_distance(Zp, None)
36 | _Kuu = tf.reduce_mean(kern.kernel.K_r2(r2), axis=0) # [M, M]
37 | return tf.linalg.set_diag(_Kuu, tf.linalg.diag_part(_Kuu) + jitter)
38 |
--------------------------------------------------------------------------------
/gpflow_sampling/covariances/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from gpflow_sampling.covariances.Kuus import Kuu
5 | from gpflow_sampling.covariances.Kufs import Kuf
6 | from gpflow_sampling.covariances.Kfus import Kfu
--------------------------------------------------------------------------------
/gpflow_sampling/inducing_variables.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 | from typing import Optional
10 | from gpflow import inducing_variables
11 | from gpflow.base import TensorData, Parameter
12 | from gpflow.config import default_float
13 | from gpflow_sampling.utils import move_axis
14 |
15 |
16 | # ---- Exports
17 | __all__ = (
18 | 'InducingImages',
19 | 'SharedInducingImages',
20 | 'DepthwiseInducingImages',
21 | 'SharedDepthwiseInducingImages',
22 | )
23 |
24 |
25 | # ==============================================
26 | # inducing_variables
27 | # ==============================================
28 | class InducingImages(inducing_variables.InducingVariables):
29 | def __init__(self, images: TensorData, name: Optional[str] = None):
30 | """
31 | :param images: initial values of inducing locations in image form.
32 |
33 | The shape of the inducing variables varies by representation:
34 | - as Z: [M, height * width * channels_in]
35 | - as images: [M, height, width, channels_in]
36 | - as patches: [M, height * width * channels_in]
37 | - as filters: [height, width, channels_in, M]
38 |
39 | TODO:
40 | - Generalize to allow for inducing image with multiple patches?
41 | - Work on naming convention? The term 'image' is a bit too general.
42 | Patch works, however this term usually refers to a vectorized form
43 | and (for now) overlaps with GPflow's own inducing class. Alternatives
44 | include: filter, window, glimpse
45 | """
46 | super().__init__(name=name)
47 | self._images = Parameter(images, dtype=default_float())
48 |
49 | def __len__(self) -> int:
50 | return self._images.shape[0]
51 |
52 | @property
53 | def Z(self) -> tf.Tensor:
54 | return tf.reshape(self._images, [len(self), -1])
55 |
56 | @property
57 | def as_patches(self) -> tf.Tensor:
58 | return tf.reshape(self.as_images, [len(self), -1])
59 |
60 | @property
61 | def as_filters(self) -> tf.Tensor:
62 | return move_axis(self.as_images, 0, -1)
63 |
64 | @property
65 | def as_images(self) -> tf.Tensor:
66 | return tf.convert_to_tensor(self._images, dtype=self._images.dtype)
67 |
68 |
69 | class SharedInducingImages(InducingImages):
70 | def __init__(self,
71 | images: TensorData,
72 | channels_in: int,
73 | name: Optional[str] = None):
74 | """
75 | :param images: initial values of inducing locations in image form.
76 | :param channels_in: number of input channels to share across
77 |
78 |     Same as `InducingImages`, but with the same single-channel inducing
79 |     images shared across all input channels.
80 |
81 | The shape of the inducing variables varies by representation:
82 | - as Z: [M, height * width] (new!)
83 | - as images: [M, height, width, channels_in]
84 |       - as patches: [M, channels_in, height * width]
85 | - as filters: [height, width, channels_in, M]
86 | """
87 | assert images.shape.ndims == 4 and images.shape[-1] == 1
88 | self.channels_in = channels_in
89 | super().__init__(images, name=name)
90 |
91 | @property
92 | def as_images(self) -> tf.Tensor:
93 | return tf.tile(self._images, [1, 1, 1, self.channels_in])
94 |
95 |
96 | class DepthwiseInducingImages(InducingImages):
97 | """
98 |   Same as `InducingImages`, but for depthwise convolutions.
99 |
100 | The shape of the inducing variables varies by representation:
101 | - as Z: [M, height * width * channels_in]
102 | - as images: [M, height, width, channels_in]
103 |     - as patches: [M, channels_in, height * width] (new!)
104 | - as filters: [height, width, channels_in, M]
105 | """
106 | @property
107 | def as_patches(self) -> tf.Tensor:
108 | images = self.as_images
109 | patches = tf.reshape(images, [images.shape[0], -1, images.shape[-1]])
110 | return tf.transpose(patches, [0, 2, 1]) # [M, channels_in, patch_len]
111 |
112 |
113 | class SharedDepthwiseInducingImages(SharedInducingImages,
114 | DepthwiseInducingImages):
115 | def __init__(self,
116 | images: TensorData,
117 | channels_in: int,
118 | name: Optional[str] = None):
119 | """
120 | :param images: initial values of inducing locations in image form.
121 | :param channels_in: number of input channels to share across.
122 | """
123 | SharedInducingImages.__init__(self,
124 | name=name,
125 | images=images,
126 | channels_in=channels_in)
127 |
--------------------------------------------------------------------------------
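A short sketch of the representations listed in the docstrings above (sizes are arbitrary):

```python
import numpy as np
from gpflow_sampling.inducing_variables import InducingImages

# M = 4 inducing images, each [height=3, width=3, channels_in=2]
Z = InducingImages(np.random.rand(4, 3, 3, 2))

print(len(Z))              # 4
print(Z.as_images.shape)   # (4, 3, 3, 2)
print(Z.as_patches.shape)  # (4, 18)
print(Z.as_filters.shape)  # (3, 3, 2, 4)
```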
/gpflow_sampling/kernels.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from typing import List
11 | from warnings import warn
12 | from gpflow import kernels
13 | from gpflow.base import TensorType, Parameter
14 | from gpflow.config import default_float
15 | from gpflow_sampling.utils import (conv_ops,
16 | swap_axes,
17 | move_axis,
18 | batch_tensordot)
19 | from tensorflow.python.keras.utils.conv_utils import (conv_output_length,
20 | deconv_output_length)
21 |
22 | # ---- Exports
23 | __all__ = (
24 | 'Conv2d',
25 | 'Conv2dTranspose',
26 | 'DepthwiseConv2d',
27 | )
28 |
29 |
30 | # ==============================================
31 | # kernels
32 | # ==============================================
33 | class Conv2d(kernels.MultioutputKernel):
34 | def __init__(self,
35 | kernel: kernels.Kernel,
36 | image_shape: List,
37 | patch_shape: List,
38 | channels_in: int = 1,
39 | channels_out: int = 1,
40 | weights: TensorType = "default",
41 | strides: List = None,
42 | padding: str = "VALID",
43 | dilations: List = None,
44 | data_format: str = "NHWC"):
45 |
46 | strides = list((1, 1) if strides is None else strides)
47 | dilations = list((1, 1) if dilations is None else dilations)
48 |
49 | # Sanity checks
50 | assert len(strides) == 2
51 | assert len(dilations) == 2
52 | assert padding in ("VALID", "SAME")
53 | assert data_format in ("NHWC", "NCHW")
54 |
55 | if isinstance(weights, str) and weights == "default": # TODO: improve me
56 | spatial_out = self.get_spatial_out(spatial_in=image_shape,
57 | filter_shape=patch_shape,
58 | strides=strides,
59 | padding=padding,
60 | dilations=dilations)
61 |
62 | weights = tf.ones([tf.reduce_prod(spatial_out)], dtype=default_float())
63 |
64 | super().__init__()
65 | self.kernel = kernel
66 | self.image_shape = image_shape
67 | self.patch_shape = patch_shape
68 | self.channels_in = channels_in
69 | self.channels_out = channels_out
70 |
71 | self.strides = strides
72 | self.padding = padding
73 | self.dilations = dilations
74 | self.data_format = data_format
75 | self._weights = None if (weights is None) else Parameter(weights)
76 |
77 | def __call__(self,
78 | X: TensorType,
79 | X2: TensorType=None,
80 | *,
81 | full_cov: bool = False,
82 | full_spatial: bool = False,
83 | presliced: bool = False):
84 |
85 | if not presliced:
86 | X, X2 = self.slice(X, X2)
87 |
88 | if not full_cov and X2 is not None:
89 | raise ValueError(
90 | "Ambiguous inputs: passing in `X2` is not compatible with `full_cov=False`."
91 | )
92 |
93 | if not full_cov:
94 | return self.K_diag(X, full_spatial=full_spatial)
95 | return self.K(X, X2, full_spatial=full_spatial)
96 |
97 | def K(self, X: tf.Tensor, X2: tf.Tensor = None, full_spatial: bool = False):
98 | """
99 | TODO: For stationary kernels, implement this using convolutions?
100 | """
101 | P = self.get_patches(X, full_spatial=full_spatial)
102 | P2 = P if X2 is None else self.get_patches(X2, full_spatial=full_spatial)
103 | K = self.kernel.K(P, P2)
104 | if full_spatial:
105 | return K # [N, H1, W1, N2, H2, W2]
106 |
107 | # At this point, shape(K) = [N, H1 * W1, N2, H2 * W2]
108 | if self.weights is None:
109 | return tf.reduce_mean(K, axis=[-3, -1])
110 | return tf.tensordot(tf.linalg.matvec(K, self.weights),
111 | self.weights,
112 | axes=[-2, 0])
113 |
114 | def K_diag(self, X: tf.Tensor, full_spatial: bool = False):
115 | P = self.get_patches(X, full_spatial=full_spatial)
116 |     K = self.kernel.K(P)  # patch-wise Gram matrix for each input
117 |     if full_spatial:
118 |       return K  # [N, H1, W1, H2, W2]
119 | 
120 |     # At this point, shape(K) = [N, H1 * W1, H2 * W2]
121 | if self.weights is None:
122 | return tf.reduce_mean(K, axis=[-2, -1])
123 | return tf.linalg.matvec(tf.linalg.matvec(K, self.weights), self.weights)
124 |
125 | def convolve(self,
126 | input,
127 | filters,
128 | strides: List = None,
129 | padding: str = None,
130 | dilations: List = None,
131 | data_format: str = None):
132 |
133 | if strides is None:
134 | strides = self.strides
135 |
136 | if padding is None:
137 | padding = self.padding
138 |
139 | if dilations is None:
140 | dilations = self.dilations
141 |
142 | if data_format is None:
143 | data_format = self.data_format
144 |
145 | return tf.nn.conv2d(input=input,
146 | filters=filters,
147 | strides=strides,
148 | padding=padding,
149 | dilations=dilations,
150 | data_format=data_format)
151 |
152 | def get_patches(self, X: TensorType, full_spatial: bool = True):
153 | # Extract image patches
154 | X_nchw = conv_ops.reformat_data(X, self.data_format, "NHWC")
155 | patches = tf.image.extract_patches(images=X_nchw,
156 | sizes=[1] + self.patch_shape + [1],
157 | strides=[1] + self.strides + [1],
158 | rates=[1] + self.dilations + [1],
159 | padding=self.padding)
160 |
161 | if full_spatial:
162 | output_shape = list(X.shape[:-3]) + list(patches.shape[-3:])
163 | else:
164 | output_shape = list(X.shape[:-3]) + [-1, patches.shape[-1]]
165 |
166 | return tf.reshape(patches, output_shape)
167 |
168 | def get_shape_out(self,
169 | shape_in: List,
170 | filter_shape: List,
171 | strides: List = None,
172 | dilations: List = None,
173 | data_format: str = None) -> List:
174 |
175 | if data_format is None:
176 | data_format = self.data_format
177 |
178 | if data_format == "NHWC":
179 | *batch, height, width, _ = list(shape_in)
180 | else:
181 | *batch, _, height, width = list(shape_in)
182 |
183 | spatial_out = self.get_spatial_out(spatial_in=[height, width],
184 | filter_shape=filter_shape[:2],
185 | strides=strides,
186 | dilations=dilations)
187 |
188 | nhwc_out = batch + spatial_out + [filter_shape[-1]]
189 | return conv_ops.reformat_shape(nhwc_out, "NHWC", data_format)
190 |
191 | def get_spatial_out(self,
192 | spatial_in: List = None,
193 | filter_shape: List = None,
194 | strides: List = None,
195 | padding: str = None,
196 | dilations: List = None) -> List:
197 |
198 | if spatial_in is None:
199 | spatial_in = self.image_shape
200 |
201 | if filter_shape is None:
202 | filter_shape = self.patch_shape
203 | else:
204 | assert len(filter_shape) == 2
205 |
206 | if strides is None:
207 | strides = self.strides
208 |
209 | if padding is None:
210 | padding = self.padding
211 |
212 | if dilations is None:
213 | dilations = self.dilations
214 |
215 | return [conv_output_length(input_length=spatial_in[i],
216 | filter_size=filter_shape[i],
217 | stride=strides[i],
218 | padding=padding.lower(),
219 | dilation=dilations[i]) for i in range(2)]
220 |
221 | @property
222 | def num_patches(self):
223 | return tf.reduce_prod(self.get_spatial_out())
224 |
225 | @property
226 | def weights(self):
227 | if self._weights is None:
228 | return None
229 | return tf.cast(1/self.num_patches, self._weights.dtype) * self._weights
230 |
231 | @property
232 | def num_latent_gps(self):
233 | return self.channels_out
234 |
235 | @property
236 | def latent_kernels(self):
237 | return self.kernel,
238 |
239 |
240 | class Conv2dTranspose(Conv2d):
241 | def __init__(self,
242 | *args,
243 | strides: List = None,
244 | dilations: List = None,
245 | **kwargs):
246 |
247 | strides = list((1, 1) if strides is None else strides)
248 | dilations = list((1, 1) if dilations is None else dilations)
249 | if strides != [1, 1] and dilations != [1, 1]:
250 |       raise NotImplementedError('Tensorflow does not implement transposed '
251 |                                 'convolutions with strides and dilations.')
252 |
253 | super().__init__(*args, strides=strides, dilations=dilations, **kwargs)
254 |
255 | def convolve(self,
256 | input,
257 | filters,
258 | strides: List = None,
259 | padding: str = None,
260 | dilations: List = None,
261 | data_format: str = None):
262 |
263 | if strides is None:
264 | strides = self.strides
265 |
266 | if padding is None:
267 | padding = self.padding
268 |
269 | if dilations is None:
270 | dilations = self.dilations
271 |
272 | if data_format is None:
273 | data_format = self.data_format
274 |
275 | shape_out = self.get_shape_out(shape_in=input.shape,
276 | filter_shape=filters.shape,
277 | strides=strides,
278 | dilations=dilations,
279 | data_format=data_format)
280 |
281 | _filters = swap_axes(filters, -2, -1)
282 | conv_kwargs = dict(filters=_filters,
283 | padding=padding,
284 | output_shape=shape_out)
285 |
286 | if dilations != [1, 1]: # TODO: improve me
287 | assert data_format == 'NHWC'
288 | assert list(strides) == [1, 1]
289 | assert len(dilations) == 2 and dilations[0] == dilations[1]
290 | return tf.nn.atrous_conv2d_transpose(input,
291 | rate=dilations[0],
292 | **conv_kwargs)
293 |
294 | return tf.nn.conv2d_transpose(input,
295 | strides=strides,
296 | dilations=dilations,
297 | data_format=data_format,
298 | **conv_kwargs)
299 |
300 | def get_patches(self, X, full_spatial: bool = False):
301 | """
302 | Returns the patches used by a 2d transposed convolution.
303 | """
304 | spatial_in = X.shape[-3: -1]
305 |
306 | # Pad X with (stride - 1) zeros in between each pixel
307 | if any(stride != 1 for stride in self.strides):
308 | shape = list(X.shape[:-3])
309 | terms = [tf.range(size) for size in shape]
310 | for i, stride in enumerate(self.strides):
311 | size = X.shape[i - 3]
312 | shape.append(stride * (size - 1) + 1)
313 | terms.append(tf.range(stride * size, delta=stride))
314 | shape.append(X.shape[-1])
315 | grid = tf.meshgrid(*terms, indexing='ij')
316 | X = tf.scatter_nd(tf.stack(grid, -1), X, shape)
317 |
318 | # Prepare padding info
319 | if self.padding == "VALID":
320 | h_pad = 2 * [self.dilations[0] * (self.patch_shape[0] - 1)]
321 | w_pad = 2 * [self.dilations[1] * (self.patch_shape[1] - 1)]
322 | elif self.padding == "SAME":
323 | height_out, width_out = self.get_spatial_out(spatial_in)
324 | extra = height_out - X.shape[-3]
325 | h_pad = list(map(lambda x: tf.cast(x, tf.int64),
326 | (tf.math.ceil(0.5 * extra), tf.math.floor(0.5 * extra))))
327 |
328 | extra = width_out - X.shape[-2]
329 | w_pad = list(map(lambda x: tf.cast(x, tf.int64),
330 | (tf.math.ceil(0.5 * extra), tf.math.floor(0.5 * extra))))
331 |
332 | # Extract (flipped) image patches
333 | X_pad = tf.pad(X, [2 * [0], h_pad, w_pad, 2 * [0]])
334 | patches = tf.image.extract_patches(images=X_pad,
335 | sizes=[1] + self.patch_shape + [1],
336 | strides=[1, 1, 1, 1],
337 | rates=[1] + self.dilations + [1],
338 | padding=self.padding)
339 |
340 | if full_spatial:
341 | output_shape = list(X.shape[:-3]) + list(patches.shape[-3:])
342 | else:
343 | output_shape = list(X.shape[:-3]) + [-1, patches.shape[-1]]
344 |
345 | return tf.reshape(tf.reverse( # reverse channel-wise patches and reshape
346 | tf.reshape(patches, list(patches.shape[:-1]) + [-1, X.shape[-1]]),
347 | axis=[-2]), output_shape)
348 |
349 | def get_spatial_out(self,
350 | spatial_in: List = None,
351 | filter_shape: List = None,
352 | strides: List = None,
353 | padding: str = None,
354 | dilations: List = None) -> List:
355 |
356 | if spatial_in is None:
357 | spatial_in = self.image_shape
358 |
359 | if filter_shape is None:
360 | filter_shape = self.patch_shape
361 | else:
362 | assert len(filter_shape) == 2
363 |
364 | if strides is None:
365 | strides = self.strides
366 |
367 | if padding is None:
368 | padding = self.padding
369 |
370 | if dilations is None:
371 | dilations = self.dilations
372 |
373 | return [deconv_output_length(input_length=spatial_in[i],
374 | filter_size=filter_shape[i],
375 | stride=strides[i],
376 | padding=padding.lower(),
377 | dilation=dilations[i]) for i in range(2)]
378 |
379 |
380 | class DepthwiseConv2d(Conv2d):
381 | def __init__(self,
382 | kernel: kernels.Kernel,
383 | image_shape: List,
384 | patch_shape: List,
385 | channels_in: int = 1,
386 | channels_out: int = 1,
387 | weights: TensorType = "default",
388 | strides: List = None,
389 | padding: str = "VALID",
390 | dilations: List = None,
391 | data_format: str = "NHWC",
392 | **kwargs):
393 |
394 | strides = list((1, 1) if strides is None else strides)
395 | dilations = list((1, 1) if dilations is None else dilations)
396 | if strides != [1, 1] and dilations != [1, 1]:
397 | warn(f"{self.__class__} does not pass unit tests when strides != [1, 1]"
398 | f" and dilations != [1, 1] simultaneously.")
399 |
400 | if isinstance(weights, str) and weights == "default": # TODO: improve me
401 | spatial_out = self.get_spatial_out(spatial_in=image_shape,
402 | filter_shape=patch_shape,
403 | strides=strides,
404 | padding=padding,
405 | dilations=dilations)
406 |
407 | weights = tf.ones([tf.reduce_prod(spatial_out), channels_in],
408 | dtype=default_float())
409 |
410 | super().__init__(kernel=kernel,
411 | image_shape=image_shape,
412 | patch_shape=patch_shape,
413 | channels_in=channels_in,
414 | channels_out=channels_out,
415 | weights=weights,
416 | strides=strides,
417 | padding=padding,
418 | dilations=dilations,
419 | data_format=data_format)
420 |
421 | def K(self, X: tf.Tensor, X2: tf.Tensor = None, full_spatial: bool = False):
422 | P = self.get_patches(X, full_spatial=full_spatial)
423 | P2 = P if X2 is None else self.get_patches(X2, full_spatial=full_spatial)
424 |
425 | # TODO: Temporary hack, use of self.kernel should be deprecated
426 | K = move_axis(
427 | tf.linalg.diag_part(
428 | move_axis(self.kernel.K(P, P2), P.shape.ndims - 2, -2)), -1, 0)
429 |
430 | if full_spatial:
431 | return K # [channels_in, N, H1, W1, N2, H2, W2]
432 |
433 | # At this point, shape(K) = [N, num_patches, N2, num_patches]
434 | if self.weights is None:
435 | return tf.reduce_mean(K, axis=[0, -3, -1])
436 |
437 | K = batch_tensordot(K, self.weights, axes=[-1, 0], batch_axes=[0, 1])
438 | K = batch_tensordot(K, self.weights, axes=[-2, 0], batch_axes=[0, 1])
439 | return tf.reduce_mean(K, axis=0)
440 |
441 | def K_diag(self, X: tf.Tensor, full_spatial: bool = False):
442 |     raise NotImplementedError  # TODO: draft implementation below is disabled
443 |
444 | P = self.get_patches(X, full_spatial=full_spatial)
445 | K = tf.reduce_mean(self.kernel.K(P), axis=-2) # average over channels
446 | if full_spatial:
447 | return K # [num_channels, N, H1, W1, H1, W1]
448 |
449 | # At this point, K has shape # [num_channels, N, num_patches, num_patches]
450 | if self.weights is None:
451 | return tf.reduce_mean(K, axis=[-2, -1])
452 |
453 | K = batch_tensordot(K, self.weights, axes=[-1, 0], batch_axes=[0, 1])
454 | K = batch_tensordot(K, self.weights, axes=[-1, 0], batch_axes=[0, 1])
455 | return tf.reduce_mean(K, axis=0)
456 |
457 | def convolve(self,
458 | input,
459 | filters,
460 | strides: List = None,
461 | padding: str = None,
462 | dilations: List = None,
463 | data_format: str = None):
464 |
465 | if strides is None:
466 | strides = self.strides
467 |
468 | if padding is None:
469 | padding = self.padding
470 |
471 | if dilations is None:
472 | dilations = self.dilations
473 |
474 | if data_format is None:
475 | data_format = self.data_format
476 |
477 | return tf.nn.depthwise_conv2d(input=input,
478 | filter=filters,
479 | strides=[1] + strides + [1],
480 | padding=padding,
481 | dilations=dilations,
482 | data_format=data_format)
483 |
484 | def get_patches(self, X: TensorType, full_spatial: bool = False):
485 | """
486 | Returns the patches used by a 2d depthwise convolution.
487 | """
488 | patches = super().get_patches(X, full_spatial=full_spatial)
489 | channels_in = X.shape[-3 if self.data_format == "NCHW" else -1]
490 | depthwise_patches = tf.reshape(patches,
491 | list(patches.shape[:-1]) + [-1, channels_in])
492 | return move_axis(depthwise_patches, -2, -1)
493 |
494 | def get_shape_out(self,
495 | shape_in: List,
496 | filter_shape: List,
497 | strides: List = None,
498 | dilations: List = None,
499 | data_format: str = None) -> List:
500 |
501 | if data_format is None:
502 | data_format = self.data_format
503 |
504 | if data_format == "NHWC":
505 | *batch, height, width, _ = list(shape_in)
506 | else:
507 | *batch, _, height, width = list(shape_in)
508 |
509 | spatial_out = self.get_spatial_out(spatial_in=[height, width],
510 | filter_shape=filter_shape[:2],
511 | strides=strides,
512 | dilations=dilations)
513 |
514 | nhwc_out = batch + spatial_out + [filter_shape[-2] * filter_shape[-1]]
515 | return conv_ops.reformat_shape(nhwc_out, "NHWC", data_format)
516 |
--------------------------------------------------------------------------------
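A sketch of the covariance shapes produced by the patch-based `Conv2d` kernel above (image and patch sizes are arbitrary): each image is decomposed into patches, the base kernel compares patches, and patch responses are averaged via the (trainable) weights.

```python
import tensorflow as tf
from gpflow.kernels import SquaredExponential
from gpflow_sampling.kernels import Conv2d

kernel = Conv2d(kernel=SquaredExponential(),
                image_shape=[8, 8],
                patch_shape=[3, 3],
                channels_in=1)

X = tf.random.uniform([5, 8, 8, 1], dtype=tf.float64)  # 5 images, NHWC
print(kernel.K(X).shape)       # (5, 5): weighted sums over patch responses
print(kernel.K_diag(X).shape)  # (5,)
```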
/gpflow_sampling/models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from abc import abstractmethod
11 | from typing import Optional
12 | from contextlib import contextmanager
13 | from gpflow.base import TensorLike
14 | from gpflow.config import default_float, default_jitter
15 | from gpflow.models import GPModel, SVGP, GPR
16 | from gpflow_sampling import sampling, covariances
17 | from gpflow_sampling.sampling.core import AbstractSampler, CompositeSampler
18 |
19 |
20 | # ---- Exports
21 | __all__ = ('PathwiseGPModel', 'PathwiseGPR', 'PathwiseSVGP')
22 |
23 |
24 | # ==============================================
25 | # models
26 | # ==============================================
27 | class PathwiseGPModel(GPModel):
28 | def __init__(self, *args, paths: AbstractSampler = None, **kwargs):
29 | super().__init__(*args, **kwargs)
30 | self._paths = paths
31 |
32 | @abstractmethod
33 | def generate_paths(self, *args, **kwargs) -> AbstractSampler:
34 | raise NotImplementedError
35 |
36 | def predict_f_samples(self,
37 | Xnew: TensorLike,
38 | num_samples: Optional[int] = None,
39 | full_cov: bool = True,
40 | full_output_cov: bool = True,
41 | **kwargs) -> tf.Tensor:
42 |
43 | assert full_cov and full_output_cov, NotImplementedError
44 | if self.paths is None:
45 | raise RuntimeError("Paths were not initialized.")
46 |
47 | if num_samples is not None:
48 |       assert [num_samples] == self.paths.sample_shape,\
49 |         ValueError("Requested number of samples does not match path count.")
50 |
51 | return self.paths(Xnew, **kwargs)
52 |
53 | @contextmanager
54 | def temporary_paths(self, *args, **kwargs):
55 | try:
56 | init_paths = self.paths
57 | temp_paths = self.generate_paths(*args, **kwargs)
58 | self.set_paths(temp_paths)
59 | yield temp_paths
60 | finally:
61 | self.set_paths(init_paths)
62 |
63 | def set_paths(self, paths) -> AbstractSampler:
64 | self._paths = paths
65 | return paths
66 |
67 | @contextmanager
68 | def set_temporary_paths(self, paths):
69 | try:
70 | init_paths = self.paths
71 | self.set_paths(paths)
72 | yield paths
73 | finally:
74 | self.set_paths(init_paths)
75 |
76 | @property
77 | def paths(self) -> AbstractSampler:
78 | return self._paths
79 |
80 |
81 | class PathwiseGPR(GPR, PathwiseGPModel):
82 | def __init__(self, *args, paths: AbstractSampler = None, **kwargs):
83 | GPR.__init__(self, *args, **kwargs)
84 | self._paths = paths
85 |
86 | def generate_paths(self,
87 | num_samples: int,
88 | num_bases: int = None,
89 | prior: AbstractSampler = None,
90 | sample_axis: int = None,
91 | **kwargs) -> CompositeSampler:
92 |
93 | if prior is None:
94 | prior = sampling.priors.random_fourier(self.kernel,
95 | num_bases=num_bases,
96 | sample_shape=[num_samples],
97 | sample_axis=sample_axis)
98 | elif num_bases is not None:
99 | assert prior.sample_shape == [num_samples]
100 |
101 | diag = tf.convert_to_tensor(self.likelihood.variance)
102 | return sampling.decoupled(self.kernel,
103 | prior,
104 | *self.data,
105 | mean_function=self.mean_function,
106 | diag=diag,
107 | sample_axis=sample_axis,
108 | **kwargs)
109 |
110 |
111 | class PathwiseSVGP(SVGP, PathwiseGPModel):
112 | def __init__(self, *args, paths: AbstractSampler = None, **kwargs):
113 | SVGP.__init__(self, *args, **kwargs)
114 | self._paths = paths
115 |
116 | def generate_paths(self,
117 | num_samples: int,
118 | num_bases: int = None,
119 | prior: AbstractSampler = None,
120 | sample_axis: int = None,
121 | **kwargs) -> CompositeSampler:
122 |
123 | if prior is None:
124 | prior = sampling.priors.random_fourier(self.kernel,
125 | num_bases=num_bases,
126 | sample_shape=[num_samples],
127 | sample_axis=sample_axis)
128 | elif num_bases is not None:
129 | assert prior.sample_shape == [num_samples]
130 |
131 | return sampling.decoupled(self.kernel,
132 | prior,
133 | self.inducing_variable,
134 | self._generate_u(num_samples),
135 | mean_function=self.mean_function,
136 | sample_axis=sample_axis,
137 | **kwargs)
138 |
139 | def _generate_u(self, num_samples: int, L: tf.Tensor = None):
140 | """
141 | Returns samples $u ~ q(u)$.
142 | """
143 | q_sqrt = tf.linalg.band_part(self.q_sqrt, -1, 0)
144 | shape = self.num_latent_gps, q_sqrt.shape[-1], num_samples
145 | rvs = tf.random.normal(shape, dtype=default_float()) # [L, M, S]
146 | uT = q_sqrt @ rvs + tf.transpose(self.q_mu)[..., None]
147 | if self.whiten:
148 | if L is None:
149 | Z = self.inducing_variable
150 | K = covariances.Kuu(Z, self.kernel, jitter=default_jitter())
151 | L = tf.linalg.cholesky(K)
152 | uT = L @ uT
153 | return tf.transpose(uT) # [S, M, L]
154 |
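A minimal usage sketch for the `PathwiseGPR` model above. The data, kernel choice, and sample counts are illustrative assumptions rather than anything prescribed by the repository:

```python
import tensorflow as tf
import gpflow
from gpflow_sampling.models import PathwiseGPR

# Toy 1D regression data (illustrative only)
X = tf.random.uniform([32, 1], dtype=tf.float64)
Y = tf.math.sin(6 * X) + 0.1 * tf.random.normal([32, 1], dtype=tf.float64)
Xnew = tf.linspace(tf.constant(0.0, tf.float64), 1.0, 128)[:, None]

model = PathwiseGPR(data=(X, Y), kernel=gpflow.kernels.SquaredExponential())

# Draw 8 posterior sample paths, each represented via 1024 Fourier basis
# functions, and evaluate them at the test locations.
with model.temporary_paths(num_samples=8, num_bases=1024) as paths:
    samples = model.predict_f_samples(Xnew)  # [8, 128, 1]
```

Because `predict_f_samples` merely evaluates the currently active paths, repeated calls inside the context return consistent evaluations of the same random functions, which is what enables, e.g., gradient-based minimization of sample paths as in `examples`.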
--------------------------------------------------------------------------------
/gpflow_sampling/sampling/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 | from gpflow_sampling.sampling import core
6 | from gpflow_sampling.sampling import priors, updates
7 | from gpflow_sampling.sampling.decoupled_samplers import decoupled
8 |
9 |
--------------------------------------------------------------------------------
/gpflow_sampling/sampling/core.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from abc import abstractmethod
11 | from typing import List, Callable, Union
12 | from gpflow.inducing_variables import InducingVariables
13 | from gpflow_sampling.utils import (move_axis,
14 | normalize_axis,
15 | batch_tensordot,
16 | get_inducing_shape)
17 |
18 | # ---- Exports
19 | __all__ = (
20 | 'AbstractSampler',
21 | 'DenseSampler',
22 | 'MultioutputDenseSampler',
23 | 'CompositeSampler',
24 | )
25 |
26 |
27 | # ==============================================
28 | # core
29 | # ==============================================
30 | class AbstractSampler(tf.Module):
31 | @abstractmethod
32 | def __call__(self, *args, **kwargs):
33 | raise NotImplementedError
34 |
35 | @property
36 | def sample_shape(self):
37 | raise NotImplementedError
38 |
39 |
40 | class CompositeSampler(AbstractSampler):
41 | def __init__(self,
42 | join_rule: Callable,
43 | samplers: List[Callable],
44 | mean_function: Callable = None,
45 | name: str = None):
46 | """
47 | Combine base samples via a specified join rule.
48 | """
49 | super().__init__(name=name)
50 | self._join_rule = join_rule
51 | self._samplers = samplers
52 | self.mean_function = mean_function
53 |
54 | def __call__(self, x: tf.Tensor, **kwargs) -> tf.Tensor:
55 | samples = [sampler(x, **kwargs) for sampler in self.samplers]
56 | vals = self.join_rule(samples)
57 | return vals if self.mean_function is None else vals + self.mean_function(x)
58 |
59 | @property
60 | def join_rule(self) -> Callable:
61 | return self._join_rule
62 |
63 | @property
64 | def samplers(self):
65 | return self._samplers
66 |
67 | @property
68 | def sample_shape(self):
69 | for i, sampler in enumerate(self.samplers):
70 | if i == 0:
71 | sample_shape = sampler.sample_shape
72 | else:
73 | assert sample_shape == sampler.sample_shape
74 | return sample_shape
75 |
76 |
77 | class DenseSampler(AbstractSampler):
78 | def __init__(self,
79 | weights: Union[tf.Tensor, tf.Variable],
80 | basis: Callable = None,
81 | mean_function: Callable = None,
82 | sample_axis: int = None,
83 | name: str = None):
84 | """
85 | Return samples as weighted sums of features.
86 | """
87 | assert weights.shape.ndims > 1
88 | super().__init__(name=name)
89 | self.weights = weights
90 | self.basis = basis
91 | self.mean_function = mean_function
92 | self.sample_axis = sample_axis
93 |
94 | def __call__(self, x: tf.Tensor, sample_axis: int = "default", **kwargs):
95 | """
96 | :param sample_axis: Specify an axis of inputs x as corresponding 1-to-1 with
97 | sample-specific slices of weight tensor w when computing tensor dot
98 | products.
99 | """
100 | if sample_axis == "default":
101 | sample_axis = self.sample_axis
102 |
103 | feat = x if self.basis is None else self.basis(x, **kwargs)
104 | if sample_axis is None:
105 | batch_axes = None
106 | else:
107 | assert len(self.sample_shape), "Received sample_axis but self.weights has" \
108 | " no dedicated axis for samples; this" \
109 | " usually implies that sample_shape=[]."
110 |
111 | ndims_x = len(get_inducing_shape(x) if
112 | isinstance(x, InducingVariables) else x.shape)
113 |
114 | batch_axes = [-3, normalize_axis(sample_axis, ndims_x) - ndims_x]
115 |
116 | # Batch-axes notwithstanding, shape(vals) = [N, S] (after move_axis)
117 | vals = move_axis(batch_tensordot(self.weights,
118 | feat,
119 | axes=[-1, -1],
120 | batch_axes=batch_axes),
121 | len(self.sample_shape), # axis=-2 houses scalar output 1
122 | -1)
123 |
124 | return vals if self.mean_function is None else vals + self.mean_function(x)
125 |
126 | @property
127 | def sample_shape(self):
128 | w_shape = list(self.weights.shape)
129 | if len(w_shape) == 2:
130 | return []
131 | return w_shape[:-2]
132 |
133 |
134 | class MultioutputDenseSampler(DenseSampler):
135 | def __init__(self,
136 | weights: Union[tf.Tensor, tf.Variable],
137 | *args,
138 | multioutput_axis: int = None,
139 | **kwargs):
140 | """
141 | :param multioutput_axis: Specify an axis of inputs x (or features thereof)
142 | as corresponding 1-to-1 with output-feature-specific slices of weight
143 | tensor w when computing tensor dot products.
144 | """
145 | self.multioutput_axis = multioutput_axis
146 | super().__init__(weights, *args, **kwargs)
147 |
148 | def __call__(self, x: tf.Tensor, sample_axis: int = "default", **kwargs):
149 | """
150 | :param sample_axis: Specify an axis of inputs x as corresponding 1-to-1 with
151 | sample-specific slices of weight tensor w when computing tensor dot
152 | products.
153 |
154 | TODO: Improve hard-coding of multioutput-/sample-axis of weights.
155 | """
156 | if sample_axis == "default":
157 | sample_axis = self.sample_axis
158 |
159 | if self.multioutput_axis is None:
160 | batch_w = [] # batch axes for w
161 | batch_x = [] # batch axes for x
162 | else:
163 | batch_w = [-2]
164 | batch_x = [self.multioutput_axis]
165 |
166 | if sample_axis is not None:
167 | assert len(self.sample_shape), "Received sample_axis but self.weights has" \
168 | " no dedicated axis for samples; this" \
169 | " usually implies that sample_shape=[]."
170 | batch_w.append(-3)
171 |
172 | # TODO: If basis(x) grows the rank of x, it should only do so from the
173 | # left, such that the negative i-th axis (i > 1) remains the same.
174 | ndims_x = len(get_inducing_shape(x) if
175 | isinstance(x, InducingVariables) else x.shape)
176 | batch_x.append(normalize_axis(sample_axis, ndims_x) - ndims_x)
177 |
178 | feat = x if self.basis is None else self.basis(x, **kwargs)
179 | vals = move_axis(batch_tensordot(self.weights, # output features go last
180 | feat,
181 | axes=[-1, -1],
182 | batch_axes=[batch_w, batch_x]),
183 | len(self.sample_shape), # axis=-2 houses multioutputs L
184 | -1)
185 |
186 | return vals if self.mean_function is None else vals + self.mean_function(x)
187 |
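To make the shape conventions above concrete, here is a small sketch of `DenseSampler` with no basis, in which case each sample path is just a linear function of its inputs; all sizes are made up:

```python
import tensorflow as tf
from gpflow_sampling.sampling.core import DenseSampler

# Weights for S = 4 sample paths over a D = 3 dimensional feature space.
# With basis=None, the sampler computes f_s(x) = <w_s, x> per path.
weights = tf.random.normal([4, 1, 3], dtype=tf.float64)
sampler = DenseSampler(weights=weights)

x = tf.random.normal([10, 3], dtype=tf.float64)
print(sampler.sample_shape)  # [4]
print(sampler(x).shape)      # (4, 10, 1): [samples, points, outputs]
```

Passing a `basis` simply swaps the raw inputs for features thereof, so the same shape logic carries over to the Fourier and kernel bases used elsewhere in the package.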
--------------------------------------------------------------------------------
/gpflow_sampling/sampling/decoupled_samplers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from typing import List, Callable
11 | from gpflow import inducing_variables
12 | from gpflow.base import TensorLike
13 | from gpflow.utilities import Dispatcher
14 | from gpflow.kernels import Kernel, MultioutputKernel, LinearCoregionalization
15 | from gpflow_sampling.sampling.updates import exact as exact_update
16 | from gpflow_sampling.sampling.core import AbstractSampler, CompositeSampler
17 | from gpflow_sampling.kernels import Conv2d
18 | from gpflow_sampling.inducing_variables import InducingImages
19 |
20 | # ---- Exports
21 | __all__ = ('decoupled',)
22 | decoupled = Dispatcher("decoupled")
23 |
24 |
25 | # ==============================================
26 | # decoupled_samplers
27 | # ==============================================
28 | @decoupled.register(Kernel, AbstractSampler, TensorLike, TensorLike)
29 | def _decoupled_fallback(kern: Kernel,
30 | prior: AbstractSampler,
31 | Z: TensorLike,
32 | u: TensorLike,
33 | *,
34 | mean_function: Callable = None,
35 | update_rule: Callable = exact_update,
36 | join_rule: Callable = sum,
37 | **kwargs):
38 |
39 | f = prior(Z, sample_axis=None) # [S, M, L]
40 | update = update_rule(kern, Z, u, f, **kwargs)
41 | return CompositeSampler(samplers=[prior, update],
42 | join_rule=join_rule,
43 | mean_function=mean_function)
44 |
45 |
46 | @decoupled.register(MultioutputKernel, AbstractSampler, TensorLike, TensorLike)
47 | def _decoupled_multioutput(kern: MultioutputKernel,
48 | prior: AbstractSampler,
49 | Z: TensorLike,
50 | u: TensorLike,
51 | *,
52 | mean_function: Callable = None,
53 | update_rule: Callable = exact_update,
54 | join_rule: Callable = sum,
55 | multioutput_axis_Z: int = "default",
56 | **kwargs):
57 |
58 |   # Determine whether or not to evaluate Z pathwise (per output feature)
59 | # TODO: Ugly. This argument is actually being passed to the prior's basis.
60 | # Disallow non-inducing-variable Z for multioutput cases of decoupled?
61 | if multioutput_axis_Z == "default":
62 | if isinstance(Z, inducing_variables.MultioutputInducingVariables) and not\
63 | isinstance(Z, inducing_variables.SharedIndependentInducingVariables):
64 | multioutput_axis_Z = 0
65 | else:
66 | multioutput_axis_Z = None
67 |
68 | f = prior(Z, sample_axis=None, multioutput_axis=multioutput_axis_Z)
69 | update = update_rule(kern, Z, u, f, **kwargs)
70 | return CompositeSampler(samplers=[prior, update],
71 | join_rule=join_rule,
72 | mean_function=mean_function)
73 |
74 |
75 | @decoupled.register(LinearCoregionalization, AbstractSampler, TensorLike, TensorLike)
76 | def _decoupled_lcm(kern: LinearCoregionalization,
77 | prior: AbstractSampler,
78 | Z: TensorLike,
79 | u: TensorLike,
80 | *,
81 | join_rule: Callable = None,
82 | **kwargs):
83 | if join_rule is None:
84 | def join_rule(terms: List[tf.Tensor]) -> tf.Tensor:
85 | return tf.tensordot(kern.W, sum(terms), axes=[-1, 0])
86 | return _decoupled_multioutput(kern, prior, Z, u, join_rule=join_rule, **kwargs)
87 |
88 |
89 | @decoupled.register(Conv2d, AbstractSampler, InducingImages, TensorLike)
90 | def _decoupled_conv(kern: Conv2d,
91 | prior: AbstractSampler,
92 | Z: InducingImages,
93 | u: TensorLike,
94 | *,
95 | mean_function: Callable = None,
96 | update_rule: Callable = exact_update,
97 | join_rule: Callable = sum,
98 | **kwargs):
99 |
100 | f = tf.squeeze(prior(Z, sample_axis=None), axis=[-3, -2]) # [S, M, L]
101 | update = update_rule(kern, Z, u, f, **kwargs)
102 | return CompositeSampler(samplers=[prior, update],
103 | join_rule=join_rule,
104 | mean_function=mean_function)
105 |
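The dispatcher can also be called directly. Below is a sketch of the fallback case: Fourier-feature prior paths conditioned on synthetic observations `(Z, u)`; all shapes and settings are illustrative:

```python
import tensorflow as tf
import gpflow
from gpflow_sampling import sampling

kernel = gpflow.kernels.Matern52()
Z = tf.random.normal([16, 2], dtype=tf.float64)  # conditioning inputs
u = tf.random.normal([16, 1], dtype=tf.float64)  # conditioning values

# Prior with S = 8 sample paths, then pathwise conditioning on (Z, u)
prior = sampling.priors.random_fourier(kernel,
                                       num_bases=512,
                                       sample_shape=[8])
paths = sampling.decoupled(kernel, prior, Z, u)  # CompositeSampler

Xnew = tf.random.normal([100, 2], dtype=tf.float64)
fnew = paths(Xnew)  # [8, 100, 1]
```

The returned `CompositeSampler` evaluates the prior and update terms at `Xnew` and combines them with `join_rule=sum`, i.e. Matheron's rule applied at the level of sample paths.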
--------------------------------------------------------------------------------
/gpflow_sampling/sampling/priors/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from gpflow_sampling.sampling.priors.fourier_priors import random_fourier
5 |
6 |
--------------------------------------------------------------------------------
/gpflow_sampling/sampling/priors/fourier_priors.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from typing import Any, List, Callable
11 | from gpflow.config import default_float
12 | from gpflow.kernels import Kernel, MultioutputKernel
13 | from gpflow.utilities import Dispatcher
14 | from gpflow_sampling.bases import fourier as fourier_basis
15 | from gpflow_sampling.sampling.core import DenseSampler, MultioutputDenseSampler
16 | from gpflow_sampling.kernels import Conv2d, DepthwiseConv2d
17 |
18 |
19 | # ---- Exports
20 | __all__ = ('random_fourier',)
21 | random_fourier = Dispatcher("random_fourier")
22 |
23 |
24 | # ==============================================
25 | # fourier_priors
26 | # ==============================================
27 | @random_fourier.register(Kernel)
28 | def _random_fourier(kernel: Kernel,
29 | sample_shape: List,
30 | num_bases: int,
31 | basis: Callable = None,
32 | dtype: Any = None,
33 | name: str = None,
34 | **kwargs):
35 |
36 | if dtype is None:
37 | dtype = default_float()
38 |
39 | if basis is None:
40 | basis = fourier_basis(kernel, num_bases=num_bases)
41 |
42 | weights = tf.random.normal(list(sample_shape) + [1, num_bases], dtype=dtype)
43 | return DenseSampler(weights=weights, basis=basis, name=name, **kwargs)
44 |
45 |
46 | @random_fourier.register(MultioutputKernel)
47 | def _random_fourier_multioutput(kernel: MultioutputKernel,
48 | sample_shape: List,
49 | num_bases: int,
50 | basis: Callable = None,
51 | dtype: Any = None,
52 | name: str = None,
53 | multioutput_axis: int = 0,
54 | **kwargs):
55 | if dtype is None:
56 | dtype = default_float()
57 |
58 | if basis is None:
59 | basis = fourier_basis(kernel, num_bases=num_bases)
60 |
61 | shape = list(sample_shape) + [kernel.num_latent_gps, num_bases]
62 | weights = tf.random.normal(shape, dtype=dtype)
63 | return MultioutputDenseSampler(name=name,
64 | basis=basis,
65 | weights=weights,
66 | multioutput_axis=multioutput_axis,
67 | **kwargs)
68 |
69 |
70 | @random_fourier.register(Conv2d)
71 | def _random_fourier_conv(kernel: Conv2d,
72 | sample_shape: List,
73 | num_bases: int,
74 | basis: Callable = None,
75 | dtype: Any = None,
76 | name: str = None,
77 | **kwargs):
78 |
79 | if dtype is None:
80 | dtype = default_float()
81 |
82 | if basis is None:
83 | basis = fourier_basis(kernel, num_bases=num_bases)
84 |
85 | shape = list(sample_shape) + [kernel.num_latent_gps, num_bases]
86 | weights = tf.random.normal(shape, dtype=dtype)
87 | return MultioutputDenseSampler(weights=weights,
88 | basis=basis,
89 | name=name,
90 | **kwargs)
91 |
92 |
93 | @random_fourier.register(DepthwiseConv2d)
94 | def _random_fourier_depthwise_conv(kernel: DepthwiseConv2d,
95 | sample_shape: List,
96 | num_bases: int,
97 | basis: Callable = None,
98 | dtype: Any = None,
99 | name: str = None,
100 | **kwargs):
101 |
102 | if dtype is None:
103 | dtype = default_float()
104 |
105 | if basis is None:
106 | basis = fourier_basis(kernel, num_bases=num_bases)
107 |
108 | channels_out = num_bases * kernel.channels_in
109 | shape = list(sample_shape) + [kernel.num_latent_gps, channels_out]
110 | weights = tf.random.normal(shape, dtype=dtype)
111 | return MultioutputDenseSampler(weights=weights,
112 | basis=basis,
113 | name=name,
114 | **kwargs)
115 |
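As a sanity check (mirroring the tests further below), the empirical covariance of these random Fourier prior draws should approach the kernel as the numbers of bases and samples grow. A sketch with made-up sizes:

```python
import tensorflow as tf
import gpflow
from gpflow_sampling.sampling.priors import random_fourier

kernel = gpflow.kernels.SquaredExponential()
X = tf.random.normal([32, 2], dtype=tf.float64)

funcs = random_fourier(kernel, num_bases=1024, sample_shape=[4096])
fx = tf.squeeze(funcs(X), axis=-1)  # [4096, 32]

# Monte Carlo estimate of Cov(f(X), f(X)); the error should be roughly
# O(num_bases**-0.5 + num_samples**-0.5)
emp = tf.matmul(fx, fx, transpose_a=True) / 4096
err = tf.reduce_max(tf.abs(emp - kernel(X, full_cov=True)))
```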
--------------------------------------------------------------------------------
/gpflow_sampling/sampling/updates/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from gpflow_sampling.sampling.updates.linear_updates import linear
5 | from gpflow_sampling.sampling.updates.exact_updates import exact
6 | from gpflow_sampling.sampling.updates.cg_updates import cg
7 |
--------------------------------------------------------------------------------
/gpflow_sampling/sampling/updates/cg_updates.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from warnings import warn
11 | from gpflow.base import TensorLike
12 | from gpflow.config import default_jitter
13 | from gpflow.utilities import Dispatcher
14 | from gpflow import kernels, inducing_variables
15 | from gpflow_sampling import covariances, kernels as kernels_ext
16 | from gpflow_sampling.bases import kernel as kernel_basis
17 | from gpflow_sampling.bases.core import AbstractBasis
18 | from gpflow_sampling.sampling.core import DenseSampler, MultioutputDenseSampler
19 | from gpflow_sampling.inducing_variables import InducingImages
20 | from gpflow_sampling.utils import get_default_preconditioner
21 |
22 |
23 | # ==============================================
24 | # cg_updates
25 | # ==============================================
26 | cg = Dispatcher("cg_updates")
27 |
28 |
29 | @cg.register(kernels.Kernel, TensorLike, TensorLike, TensorLike)
30 | def _cg_fallback(kern: kernels.Kernel,
31 | Z: TensorLike,
32 | u: TensorLike,
33 | f: TensorLike,
34 | *,
35 | diag: TensorLike = None,
36 | basis: AbstractBasis = None,
37 | preconditioner: tf.linalg.LinearOperator = "default",
38 | tol: float = 1e-3,
39 | max_iter: int = 100,
40 | **kwargs):
41 | """
42 |   Return pathwise updates of a prior process $f$ subject to the
43 | condition $p(f | u) = N(f | u, diag)$ on $f = f(Z)$.
44 | """
45 | u_shape = tuple(u.shape)
46 | f_shape = tuple(f.shape)
47 |   assert u_shape[-1] == 1, "Received multiple output features"
48 | assert u_shape == f_shape[-len(u_shape):], "Incompatible shapes detected"
49 | if basis is None: # finite-dimensional basis used to express the update
50 | basis = kernel_basis(kern, centers=Z)
51 |
52 | if diag is None:
53 | diag = default_jitter()
54 |
55 | # Prepare linear system for CG solver
56 | if isinstance(Z, inducing_variables.InducingVariables):
57 | Kff = covariances.Kuu(Z, kern, jitter=0.0)
58 | else:
59 | Kff = kern(Z, full_cov=True)
60 | Kuu = tf.linalg.set_diag(Kff, tf.linalg.diag_part(Kff) + diag)
61 | operator = tf.linalg.LinearOperatorFullMatrix(Kuu,
62 | is_non_singular=True,
63 | is_self_adjoint=True,
64 | is_positive_definite=True,
65 | is_square=True)
66 |
67 | if preconditioner == "default":
68 | preconditioner = get_default_preconditioner(Kff, diag=diag)
69 |
70 | # Compute error term and CG initializer
71 | err = tf.linalg.adjoint(u - f) # [S, 1, M]
72 | err -= (diag ** 0.5) * tf.random.normal(err.shape, dtype=err.dtype)
73 | if preconditioner is None:
74 | initializer = None
75 | else:
76 | initializer = preconditioner.matvec(err)
77 |
78 | # Approximately solve for $Cov(u, u)^{-1}(u - f(Z))$ using linear CG
79 | res = tf.linalg.experimental.conjugate_gradient(operator=operator,
80 | rhs=err,
81 | preconditioner=preconditioner,
82 | x=initializer,
83 | tol=tol,
84 | max_iter=max_iter)
85 |
86 | weights = res.x
87 | if tf.math.count_nonzero(tf.math.is_nan(weights)):
88 | warn("One or more update weights returned by CG are NaN")
89 |
90 | return DenseSampler(basis=basis, weights=weights, **kwargs)
91 |
92 |
93 | @cg.register((kernels.SharedIndependent,
94 | kernels.SeparateIndependent,
95 | kernels.LinearCoregionalization),
96 | TensorLike,
97 | TensorLike,
98 | TensorLike)
99 | def _cg_independent(kern: kernels.MultioutputKernel,
100 | Z: TensorLike,
101 | u: TensorLike,
102 | f: TensorLike,
103 | *,
104 | diag: TensorLike = None,
105 | basis: AbstractBasis = None,
106 | preconditioner: tf.linalg.LinearOperator = "default",
107 | tol: float = 1e-3,
108 | max_iter: int = 100,
109 | multioutput_axis: int = 0,
110 | **kwargs):
111 | """
112 | Return (independent) pathwise updates for each of the latent prior processes
113 | $f$ subject to the condition $p(f | u) = N(f | u, diag)$ on $f = f(Z)$.
114 | """
115 | u_shape = tuple(u.shape)
116 | f_shape = tuple(f.shape)
117 | assert u_shape[-1] == kern.num_latent_gps, "Num. outputs != num. latent GPs"
118 | assert u_shape == f_shape[-len(u_shape):], "Incompatible shapes detected"
119 | if basis is None: # finite-dimensional basis used to express the update
120 | basis = kernel_basis(kern, centers=Z)
121 |
122 | if diag is None:
123 | diag = default_jitter()
124 |
125 | # Prepare linear system for CG solver
126 | if isinstance(Z, inducing_variables.InducingVariables):
127 | Kff = covariances.Kuu(Z, kern, jitter=0.0)
128 | else:
129 | Kff = kern(Z, full_cov=True, full_output_cov=False)
130 | Kuu = tf.linalg.set_diag(Kff, tf.linalg.diag_part(Kff) + diag)
131 | operator = tf.linalg.LinearOperatorFullMatrix(Kuu,
132 | is_non_singular=True,
133 | is_self_adjoint=True,
134 | is_positive_definite=True,
135 | is_square=True)
136 |
137 | if preconditioner == "default":
138 | preconditioner = get_default_preconditioner(Kff, diag=diag)
139 |
140 | err = tf.linalg.adjoint(u - f) # [S, L, M]
141 | err -= (diag ** 0.5) * tf.random.normal(err.shape, dtype=err.dtype)
142 | if preconditioner is None:
143 | initializer = None
144 | else:
145 | initializer = preconditioner.matvec(err)
146 |
147 | # Approximately solve for $Cov(u, u)^{-1}(u - f(Z))$ using linear CG
148 | res = tf.linalg.experimental.conjugate_gradient(operator=operator,
149 | rhs=err,
150 | preconditioner=preconditioner,
151 | x=initializer,
152 | tol=tol,
153 | max_iter=max_iter)
154 |
155 | weights = res.x
156 | if tf.math.count_nonzero(tf.math.is_nan(weights)):
157 | warn("One or more update weights returned by CG are NaN")
158 |
159 | return MultioutputDenseSampler(basis=basis,
160 | weights=weights,
161 | multioutput_axis=multioutput_axis,
162 | **kwargs)
163 |
164 |
165 | @cg.register(kernels.SharedIndependent,
166 | inducing_variables.SharedIndependentInducingVariables,
167 | TensorLike,
168 | TensorLike)
169 | def _cg_shared(kern, Z, u, f, *, multioutput_axis=None, **kwargs):
170 | """
171 | Edge-case where the multioutput axis gets suppressed.
172 | """
173 | return _cg_independent(kern,
174 | Z,
175 | u,
176 | f,
177 | multioutput_axis=multioutput_axis,
178 | **kwargs)
179 |
180 |
181 | @cg.register(kernels_ext.Conv2d, InducingImages, TensorLike, TensorLike)
182 | def _cg_conv2d(kern, Z, u, f, *, basis=None, multioutput_axis=None, **kwargs):
183 | if basis is None: # finite-dimensional basis used to express the update
184 | basis = kernel_basis(kern, centers=Z, full_spatial=True)
185 | return _cg_independent(kern,
186 | Z,
187 | u,
188 | f,
189 | basis=basis,
190 | multioutput_axis=multioutput_axis,
191 | **kwargs)
192 |
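In practice, the CG rule is selected by passing it as the `update_rule` of `decoupled`; extra keyword arguments are forwarded to the solver. A sketch with illustrative settings:

```python
import tensorflow as tf
import gpflow
from gpflow_sampling import sampling
from gpflow_sampling.sampling.updates import cg as cg_update

kernel = gpflow.kernels.Matern32()
Z = tf.random.normal([512, 2], dtype=tf.float64)
u = tf.random.normal([512, 1], dtype=tf.float64)

prior = sampling.priors.random_fourier(kernel,
                                       num_bases=1024,
                                       sample_shape=[4])

# As in the exact case, but update weights are found with (preconditioned)
# conjugate gradients rather than a Cholesky-based solve; tol and max_iter
# trade off accuracy against cost.
paths = sampling.decoupled(kernel, prior, Z, u,
                           update_rule=cg_update,
                           tol=1e-4,
                           max_iter=200)
```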
--------------------------------------------------------------------------------
/gpflow_sampling/sampling/updates/exact_updates.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from gpflow import kernels, inducing_variables
11 | from gpflow.base import TensorLike
12 | from gpflow.config import default_jitter
13 | from gpflow.utilities import Dispatcher
14 | from gpflow_sampling import covariances, kernels as kernels_ext
15 | from gpflow_sampling.utils import swap_axes, move_axis
16 | from gpflow_sampling.bases import kernel as kernel_basis
17 | from gpflow_sampling.bases.core import AbstractBasis
18 | from gpflow_sampling.sampling.core import DenseSampler, MultioutputDenseSampler
19 | from gpflow_sampling.inducing_variables import InducingImages
20 |
21 |
22 | # ==============================================
23 | # exact_updates
24 | # ==============================================
25 | exact = Dispatcher("exact_updates")
26 |
27 |
28 | @exact.register(kernels.Kernel, TensorLike, TensorLike, TensorLike)
29 | def _exact_fallback(kern: kernels.Kernel,
30 | Z: TensorLike,
31 | u: TensorLike,
32 | f: TensorLike,
33 | *,
34 | L : TensorLike = None,
35 | diag: TensorLike = None,
36 | basis: AbstractBasis = None,
37 | **kwargs):
38 | """
39 |   Return pathwise updates of a prior process $f$ subject to the
40 | condition $p(f | u) = N(f | u, diag)$ on $f = f(Z)$.
41 | """
42 | u_shape = tuple(u.shape)
43 | f_shape = tuple(f.shape)
44 |   assert u_shape[-1] == 1, "Received multiple output features"
45 | assert u_shape == f_shape[-len(u_shape):], "Incompatible shapes detected"
46 | if basis is None: # finite-dimensional basis used to express the update
47 | basis = kernel_basis(kern, centers=Z)
48 |
49 | # Prepare diagonal term
50 | if diag is None:
51 | diag = default_jitter()
52 | if isinstance(diag, float):
53 | diag = tf.convert_to_tensor(diag, dtype=f.dtype)
54 | diag = tf.expand_dims(diag, axis=-1) # [M, 1] or [1, 1] or [1]
55 |
56 | # Compute error term and matrix square root $Cov(u, u)^{1/2}$
57 | err = u - f # [S, M, 1]
58 | err -= tf.sqrt(diag) * tf.random.normal(err.shape, dtype=err.dtype)
59 | if L is None:
60 | if isinstance(Z, inducing_variables.InducingVariables):
61 | K = covariances.Kuu(Z, kern, jitter=0.0)
62 | else:
63 | K = kern(Z, full_cov=True)
64 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag[..., 0])
65 | L = tf.linalg.cholesky(K)
66 |
67 | # Solve for $Cov(u, u)^{-1}(u - f(Z))$
68 | weights = tf.linalg.adjoint(tf.linalg.cholesky_solve(L, err))
69 | return DenseSampler(basis=basis, weights=weights, **kwargs)
70 |
71 |
72 | @exact.register((kernels.SharedIndependent,
73 | kernels.SeparateIndependent,
74 | kernels.LinearCoregionalization),
75 | TensorLike,
76 | TensorLike,
77 | TensorLike)
78 | def _exact_independent(kern: kernels.MultioutputKernel,
79 | Z: TensorLike,
80 | u: TensorLike,
81 | f: TensorLike,
82 | *,
83 | L: TensorLike = None,
84 | diag: TensorLike = None,
85 | basis: AbstractBasis = None,
86 | multioutput_axis: int = 0,
87 | **kwargs):
88 | """
89 | Return (independent) pathwise updates for each of the latent prior processes
90 | $f$ subject to the condition $p(f | u) = N(f | u, diag)$ on $f = f(Z)$.
91 | """
92 | u_shape = tuple(u.shape)
93 | f_shape = tuple(f.shape)
94 | assert u_shape[-1] == kern.num_latent_gps, "Num. outputs != num. latent GPs"
95 | assert u_shape == f_shape[-len(u_shape):], "Incompatible shapes detected"
96 | if basis is None: # finite-dimensional basis used to express the update
97 | basis = kernel_basis(kern, centers=Z)
98 |
99 | # Prepare diagonal term
100 |   if diag is None:
101 | diag = default_jitter()
102 | if isinstance(diag, float):
103 | diag = tf.convert_to_tensor(diag, dtype=f.dtype)
104 | diag = tf.expand_dims(diag, axis=-1) # ([L] or []) + ([M] or []) + [1]
105 |
106 | # Compute error term and matrix square root $Cov(u, u)^{1/2}$
107 | err = swap_axes(u - f, -3, -1) # [L, M, S]
108 | err -= tf.sqrt(diag) * tf.random.normal(err.shape, dtype=err.dtype)
109 | if L is None:
110 | if isinstance(Z, inducing_variables.InducingVariables):
111 | K = covariances.Kuu(Z, kern, jitter=0.0)
112 | else:
113 | K = kern(Z, full_cov=True, full_output_cov=False)
114 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag[..., 0])
115 | L = tf.linalg.cholesky(K)
116 |
117 | # Solve for $Cov(u, u)^{-1}(u - f(Z))$
118 | weights = move_axis(tf.linalg.cholesky_solve(L, err), -1, -3) # [S, L, M]
119 | return MultioutputDenseSampler(basis=basis,
120 | weights=weights,
121 | multioutput_axis=multioutput_axis,
122 | **kwargs)
123 |
124 |
125 | @exact.register(kernels.SharedIndependent,
126 | inducing_variables.SharedIndependentInducingVariables,
127 | TensorLike,
128 | TensorLike)
129 | def _exact_shared(kern, Z, u, f, *, multioutput_axis=None, **kwargs):
130 | """
131 | Edge-case where the multioutput axis gets suppressed.
132 | """
133 | return _exact_independent(kern,
134 | Z,
135 | u,
136 | f,
137 | multioutput_axis=multioutput_axis,
138 | **kwargs)
139 |
140 |
141 | @exact.register(kernels_ext.Conv2d, InducingImages, TensorLike, TensorLike)
142 | def _exact_conv2d(kern, Z, u, f, *, basis=None, multioutput_axis=None, **kwargs):
143 | if basis is None: # finite-dimensional basis used to express the update
144 | basis = kernel_basis(kern, centers=Z, full_spatial=True)
145 | return _exact_independent(kern,
146 | Z,
147 | u,
148 | f,
149 | basis=basis,
150 | multioutput_axis=multioutput_axis,
151 | **kwargs)
152 |
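In the notation of the accompanying papers, the samplers constructed above represent only the data-driven term of Matheron's rule. Writing $\Lambda$ for `diag`, `_exact_fallback` implements

```latex
(f \mid u)(\cdot)
  = f(\cdot) + k(\cdot, Z)\big(k(Z, Z) + \Lambda\big)^{-1}
    \big(u - f(Z) - \varepsilon\big),
\qquad \varepsilon \sim \mathcal{N}(0, \Lambda),
```

where the prior term $f(\cdot)$ is added back in by the `CompositeSampler`'s join rule and the noise draw $\varepsilon$ accounts for conditioning on noisy observations.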
--------------------------------------------------------------------------------
/gpflow_sampling/sampling/updates/linear_updates.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | """
8 | Pathwise updates for Gaussian processes with linear
9 | kernels in some explicit, finite-dimensional basis.
10 | """
11 | # ---- Imports
12 | import tensorflow as tf
13 |
14 | from gpflow import inducing_variables
15 | from gpflow.base import TensorLike
16 | from gpflow.config import default_jitter
17 | from gpflow.utilities import Dispatcher
18 | from gpflow_sampling.utils import swap_axes, move_axis, inducing_to_tensor
19 | from gpflow_sampling.bases.core import AbstractBasis
20 | from gpflow_sampling.sampling.core import DenseSampler, MultioutputDenseSampler
21 |
22 |
23 | # ==============================================
24 | # linear_update
25 | # ==============================================
26 | linear = Dispatcher("linear_updates")
27 |
28 |
29 | @linear.register(TensorLike, TensorLike, TensorLike)
30 | def _linear_fallback(Z: TensorLike,
31 | u: TensorLike,
32 | f: TensorLike,
33 | *,
34 | L : TensorLike = None,
35 | diag: TensorLike = None,
36 | basis: AbstractBasis = None,
37 | **kwargs):
38 |
39 | u_shape = tuple(u.shape)
40 | f_shape = tuple(f.shape)
41 |   assert u_shape[-1] == 1, "Received multiple output features"
42 | assert u_shape == f_shape[-len(u_shape):], "Incompatible shapes detected"
43 |
44 | # Prepare diagonal term
45 |   if diag is None:
46 | diag = default_jitter()
47 | if isinstance(diag, float):
48 | diag = tf.convert_to_tensor(diag, dtype=f.dtype)
49 | diag = tf.expand_dims(diag, axis=-1) # [M, 1] or [1, 1] or [1]
50 |
51 | # Extract "features" of Z
52 | if basis is None:
53 | if isinstance(Z, inducing_variables.InducingVariables):
54 | feat = inducing_to_tensor(Z) # [M, D]
55 | else:
56 | feat = Z
57 | else:
58 | feat = basis(Z) # [M, D] (maybe a different "D" than above)
59 |
60 | # Compute error term and matrix square root $Cov(u, u)^{1/2}$
61 | err = swap_axes(u - f, -3, -1) # [1, M, S]
62 | err -= tf.sqrt(diag) * tf.random.normal(err.shape, dtype=err.dtype)
63 | M, D = feat.shape[-2:]
64 | if L is None:
65 | if D < M:
66 | feat_iDiag = feat * tf.math.reciprocal(diag)
67 | S = tf.matmul(feat_iDiag, feat, transpose_a=True) # [D, D]
68 | L = tf.linalg.cholesky(S + tf.eye(S.shape[-1], dtype=S.dtype))
69 | else:
70 | K = tf.matmul(feat, feat, transpose_b=True) # [M, M]
71 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag[..., 0])
72 | L = tf.linalg.cholesky(K)
73 | else:
74 | assert L.shape[-1] == min(M, D) # TODO: improve me
75 |
76 | # Solve for $Cov(u, u)^{-1}(u - f(Z))$
77 | if D < M:
78 | feat_iDiag = feat * tf.math.reciprocal(diag)
79 | weights = tf.linalg.adjoint(tf.linalg.cholesky_solve(L,
80 | tf.matmul(feat_iDiag, err, transpose_a=True)))
81 | else:
82 | iK_err = tf.linalg.cholesky_solve(L, err) # [S, M, 1]
83 | weights = tf.matmul(iK_err, feat, transpose_a=True) # [S, 1, D]
84 |
85 | return DenseSampler(basis=basis, weights=move_axis(weights, -2, -3), **kwargs)
86 |
87 |
88 | @linear.register(inducing_variables.MultioutputInducingVariables,
89 | TensorLike,
90 | TensorLike)
91 | def _linear_multioutput(Z: inducing_variables.MultioutputInducingVariables,
92 | u: TensorLike,
93 | f: TensorLike,
94 | *,
95 | L: TensorLike = None,
96 | diag: TensorLike = None,
97 | basis: AbstractBasis = None,
98 | multioutput_axis: int = "default",
99 | **kwargs):
100 | assert tuple(u.shape) == tuple(f.shape)
101 | if multioutput_axis == "default":
102 | multioutput_axis = None if (basis is None) else 0
103 |
104 | # Prepare diagonal term
105 |   if diag is None:
106 | diag = default_jitter()
107 | if isinstance(diag, float):
108 | diag = tf.convert_to_tensor(diag, dtype=f.dtype)
109 | diag = tf.expand_dims(diag, axis=-1) # ([L] or []) + ([M] or []) + [1]
110 |
111 | # Extract "features" of Z
112 | if basis is None:
113 | if isinstance(Z, inducing_variables.InducingVariables):
114 | feat = inducing_to_tensor(Z) # [L, M, D] or [M, D]
115 | else:
116 | feat = Z
117 | elif isinstance(Z, inducing_variables.SharedIndependentInducingVariables):
118 | feat = basis(Z)
119 | else:
120 | feat = basis(Z, multioutput_axis=0) # first axis of Z is output-specific
121 |
122 | # Compute error term and matrix square root $Cov(u, u)^{1/2}$
123 | err = swap_axes(u - f, -3, -1) # [L, M, S]
124 | err -= tf.sqrt(diag) * tf.random.normal(err.shape, dtype=err.dtype)
125 | M, D = feat.shape[-2:]
126 | if L is None:
127 | if D < M:
128 | feat_iDiag = feat * tf.math.reciprocal(diag)
129 | S = tf.matmul(feat_iDiag, feat, transpose_a=True) # [L, D, D] or [D, D]
130 | L = tf.linalg.cholesky(S + tf.eye(S.shape[-1], dtype=S.dtype))
131 | else:
132 | K = tf.matmul(feat, feat, transpose_b=True) # [L, M, M] or [M, M]
133 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag[..., 0])
134 | L = tf.linalg.cholesky(K)
135 | else:
136 | assert L.shape[-1] == min(M, D) # TODO: improve me
137 |
138 | # Solve for $Cov(u, u)^{-1}(u - f(Z))$
139 | if D < M:
140 | feat_iDiag = feat * tf.math.reciprocal(diag)
141 | weights = tf.linalg.adjoint(tf.linalg.cholesky_solve(L,
142 | tf.matmul(feat_iDiag, err, transpose_a=True)))
143 | else:
144 | iK_err = tf.linalg.cholesky_solve(L, err) # [L, S, M]
145 | weights = tf.matmul(iK_err, feat, transpose_a=True) # [L, S, D]
146 |
147 | return MultioutputDenseSampler(basis=basis,
148 | weights=swap_axes(weights, -3, -2), # [S, L, D]
149 | multioutput_axis=multioutput_axis,
150 | **kwargs)
151 |
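When the feature dimension is smaller than the number of conditioning points ($D < M$), the code above avoids forming an $M \times M$ system by using the push-through (Woodbury-style) identity

```latex
\Phi^{\top}\big(\Phi\Phi^{\top} + \Lambda\big)^{-1} e
  = \big(\Phi^{\top}\Lambda^{-1}\Phi + I\big)^{-1}\Phi^{\top}\Lambda^{-1} e,
```

with $\Phi$ the $M \times D$ feature matrix (`feat`), $\Lambda$ the diagonal term, and $e$ the error term (`err`). The $D \ge M$ branch computes the left-hand side directly via an $M \times M$ Cholesky factor.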
--------------------------------------------------------------------------------
/gpflow_sampling/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | from gpflow_sampling.utils.array_ops import *
5 | from gpflow_sampling.utils.linalg import *
6 | from gpflow_sampling.utils.gpflow_ops import *
7 |
--------------------------------------------------------------------------------
/gpflow_sampling/utils/array_ops.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import numpy as np
9 | import tensorflow as tf
10 |
11 | # ---- Exports
12 | __all__ = (
13 | 'normalize_axis',
14 | 'swap_axes',
15 | 'move_axis',
16 | 'expand_n',
17 | 'expand_to',
18 | )
19 |
20 |
21 | # ==============================================
22 | # misc
23 | # ==============================================
24 | def normalize_axis(axis, ndims, minval=None, maxval=None):
25 | if minval is None:
26 | minval = -ndims
27 |
28 | if maxval is None:
29 | maxval = ndims - 1
30 |
31 | assert maxval >= axis >= minval
32 | return ndims + axis if (axis < 0) else axis
33 |
34 |
35 | def move_axis(arr: tf.Tensor, src: int, dest: int):
36 | ndims = len(arr.shape)
37 | src = ndims + src if (src < 0) else src
38 | dest = ndims + dest if (dest < 0) else dest
39 |
40 | src = normalize_axis(src, ndims)
41 | dest = normalize_axis(dest, ndims)
42 |
43 | perm = list(range(ndims))
44 | perm.insert(dest, perm.pop(src))
45 | return tf.transpose(arr, perm)
46 |
47 |
48 | def swap_axes(arr: tf.Tensor, a: int, b: int):
49 | ndims = len(arr.shape)
50 | a = normalize_axis(a, ndims)
51 | b = normalize_axis(b, ndims)
52 |
53 | perm = list(range(ndims))
54 | perm[a] = b
55 | perm[b] = a
56 | return tf.transpose(arr, perm)
57 |
58 |
59 | def expand_n(arr: tf.Tensor, axis: int, n: int):
60 | ndims = len(arr.shape)
61 | axis = normalize_axis(axis, ndims, maxval=ndims)
62 | return arr[axis * (slice(None),) + n * (np.newaxis,)]
63 |
64 |
65 | def expand_to(arr: tf.Tensor, axis: int, ndims: int):
66 | _ndims = len(arr.shape)
67 | if _ndims == ndims:
68 | return tf.identity(arr)
69 | assert ndims > _ndims
70 | return expand_n(arr, axis, ndims - _ndims)
71 |
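A quick sketch of the axis helpers above (shapes chosen arbitrarily):

```python
import tensorflow as tf
from gpflow_sampling.utils import move_axis, swap_axes, expand_to

x = tf.zeros([2, 3, 4])
print(move_axis(x, 0, -1).shape)            # (3, 4, 2): axis 0 moved last
print(swap_axes(x, 0, 2).shape)             # (4, 3, 2): axes 0 and 2 exchanged
print(expand_to(x, axis=0, ndims=5).shape)  # (1, 1, 2, 3, 4): new leading axes
```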
--------------------------------------------------------------------------------
/gpflow_sampling/utils/conv_ops.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 | from typing import List
10 | from gpflow.base import TensorType
11 | from gpflow_sampling.utils.array_ops import move_axis
12 |
13 |
14 | # ---- Exports
15 | __all__ = (
16 | 'reformat_shape',
17 | 'reformat_data',
18 | )
19 |
20 |
21 | # ==============================================
22 | # conv_ops
23 | # ==============================================
24 | def reformat_shape(shape: List,
25 | input_format: str,
26 | output_format: str) -> List:
27 | """
28 |   Helper method for converting shapes between NHWC and NCHW formats.
29 | """
30 | if input_format == output_format:
31 | return shape # noop
32 |
33 | if input_format == 'NHWC':
34 | assert output_format == "NCHW"
35 | return shape[:-3] + shape[-1:] + shape[-3: -1]
36 |
37 | if input_format == 'NCHW':
38 | assert output_format == "NHWC"
39 | return shape[:-3] + shape[-2:] + [shape[-3]]
40 |
41 | raise NotImplementedError
42 |
43 |
44 | def reformat_data(x: TensorType,
45 | format_in: str,
46 | format_out: str):
47 | """
48 | Helper method for converting image data between NHWC and NCHW formats.
49 | """
50 | if format_in == format_out:
51 | return x # noop
52 |
53 | if format_in == "NHWC":
54 | assert format_out == "NCHW"
55 | return move_axis(x, -1, -3)
56 |
57 | if format_in == "NCHW":
58 | assert format_out == "NHWC"
59 | return move_axis(x, -3, -1)
60 |
61 | raise NotImplementedError
62 |
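For instance (note that `conv_ops` is not re-exported by `utils/__init__.py`, so it is imported by module path here):

```python
import tensorflow as tf
from gpflow_sampling.utils.conv_ops import reformat_shape, reformat_data

nhwc = [16, 28, 28, 3]  # batch, height, width, channels
print(reformat_shape(nhwc, "NHWC", "NCHW"))    # [16, 3, 28, 28]

x = tf.zeros(nhwc)
print(reformat_data(x, "NHWC", "NCHW").shape)  # (16, 3, 28, 28)
```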
--------------------------------------------------------------------------------
/gpflow_sampling/utils/gpflow_ops.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 | from gpflow.inducing_variables import (InducingVariables,
10 | MultioutputInducingVariables,
11 | SharedIndependentInducingVariables,
12 | SeparateIndependentInducingVariables)
13 |
14 | from gpflow.utilities import Dispatcher
15 |
16 |
17 | # ---- Exports
18 | __all__ = ('get_inducing_shape', 'inducing_to_tensor')
19 |
20 |
21 | # ==============================================
22 | # gpflow_utils
23 | # ==============================================
24 | get_inducing_shape = Dispatcher("get_inducing_shape")
25 | inducing_to_tensor = Dispatcher("inducing_to_tensor")
26 |
27 | @get_inducing_shape.register(InducingVariables)
28 | def _getter(x):
29 |   assert not isinstance(x, MultioutputInducingVariables)
30 | return list(x.Z.shape)
31 |
32 |
33 | @get_inducing_shape.register(SharedIndependentInducingVariables)
34 | def _getter(x: SharedIndependentInducingVariables):
35 | assert len(x.inducing_variables) == 1
36 | return get_inducing_shape(x.inducing_variables[0])
37 |
38 |
39 | @get_inducing_shape.register(SeparateIndependentInducingVariables)
40 | def _getter(x: SeparateIndependentInducingVariables):
41 | for n, z in enumerate(x.inducing_variables):
42 | if n == 0:
43 | shape = get_inducing_shape(z)
44 | else:
45 | assert shape == get_inducing_shape(z)
46 | return [n + 1] + shape
47 |
48 |
49 | @inducing_to_tensor.register(InducingVariables)
50 | def _convert(x: InducingVariables, **kwargs):
51 |   assert not isinstance(x, MultioutputInducingVariables)
52 | return tf.convert_to_tensor(x.Z, **kwargs)
53 |
54 |
55 | @inducing_to_tensor.register(SharedIndependentInducingVariables)
56 | def _convert(x: InducingVariables, **kwargs):
57 | assert len(x.inducing_variables) == 1
58 | return inducing_to_tensor(x.inducing_variables[0], **kwargs)
59 |
60 |
61 | @inducing_to_tensor.register(SeparateIndependentInducingVariables)
62 | def _convert(x: InducingVariables, **kwargs):
63 | return tf.stack([inducing_to_tensor(z, **kwargs)
64 | for z in x.inducing_variables], axis=0)
65 |
66 |
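A short sketch of both helpers on separate-independent inducing points (sizes are illustrative):

```python
import tensorflow as tf
from gpflow.inducing_variables import (InducingPoints,
                                       SeparateIndependentInducingVariables)
from gpflow_sampling.utils import get_inducing_shape, inducing_to_tensor

Zs = [InducingPoints(tf.random.normal([8, 2], dtype=tf.float64))
      for _ in range(3)]
Z = SeparateIndependentInducingVariables(Zs)

print(get_inducing_shape(Z))        # [3, 8, 2]
print(inducing_to_tensor(Z).shape)  # (3, 8, 2)
```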
--------------------------------------------------------------------------------
/gpflow_sampling/utils/linalg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import numpy as np
9 | import tensorflow as tf
10 | from typing import List, Union
11 | from string import ascii_lowercase
12 | from collections.abc import Iterable
13 | from gpflow.config import default_jitter
14 | from gpflow_sampling.utils.array_ops import normalize_axis
15 | from tensorflow_probability.python.math import pivoted_cholesky
16 |
17 | # ---- Exports
18 | __all__ = (
19 | 'batch_tensordot',
20 | 'get_default_preconditioner',
21 | )
22 |
23 |
24 | # ==============================================
25 | # linalg
26 | # ==============================================
27 | def batch_tensordot(a: tf.Tensor,
28 | b: tf.Tensor,
29 | axes: List,
30 | batch_axes: List = None) -> tf.Tensor:
31 | """
32 |   Computes tensor products, like tf.tensordot, but with:
33 |     - Support for batch dimensions; 1-to-1 rather than Cartesian product
34 |     - Broadcasting of batch dimensions
35 |
36 |   Example:
37 |     a = tf.random.uniform([5, 4, 3, 2])
38 |     b = tf.random.uniform([6, 4, 2, 1])
39 |     c = batch_tensordot(a, b, [-1, -2], [[-3, -2], [-3, -1]])  # [5, 4, 3, 6]
40 | """
41 | ndims_a = len(a.shape)
42 | ndims_b = len(b.shape)
43 | assert min(ndims_a, ndims_b) > 0
44 | assert max(ndims_a, ndims_b) < 26
45 |
46 | # Prepare batch and contraction axes
47 | def parse_axes(axes):
48 | if axes is None:
49 | return [], []
50 |
51 | assert len(axes) == 2
52 | axes_a, axes_b = axes
53 |
54 | a_is_int = isinstance(axes_a, int)
55 | b_is_int = isinstance(axes_b, int)
56 | assert a_is_int == b_is_int
57 | if a_is_int:
58 | return [normalize_axis(axes_a, ndims_a)], \
59 | [normalize_axis(axes_b, ndims_b)]
60 |
61 | assert isinstance(axes_a, Iterable)
62 | assert isinstance(axes_b, Iterable)
63 | axes_a = list(axes_a)
64 | axes_b = list(axes_b)
65 | length = len(axes_a)
66 | assert length == len(axes_b)
67 | if length == 0:
68 | return [], []
69 |
70 | axes_a = [normalize_axis(ax, ndims_a) for ax in axes_a]
71 | axes_b = [normalize_axis(ax, ndims_b) for ax in axes_b]
72 | return map(list, zip(*sorted(zip(axes_a, axes_b)))) # sort according to a
73 |
74 | reduce_a, reduce_b = parse_axes(axes) # defines the tensor contraction
75 | batch_a, batch_b = parse_axes(batch_axes) # group these together 1-to-1
76 |
77 | # Construct left-hand side einsum conventions
78 | active_a = reduce_a + batch_a
79 | active_b = reduce_b + batch_b
80 | assert len(active_a) == len(set(active_a)) # check for duplicates
81 | assert len(active_b) == len(set(active_b))
82 |
83 | lhs_a = list(ascii_lowercase[:ndims_a])
84 | lhs_b = list(ascii_lowercase[ndims_a: ndims_a + ndims_b])
85 | for (pos_a, pos_b) in zip(active_a, active_b):
86 | lhs_b[pos_b] = lhs_a[pos_a]
87 |
88 | # Construct right-hand side einsum convention
89 | rhs = []
90 | for i, char_a in enumerate(lhs_a):
91 | if i not in reduce_a:
92 | rhs.append(char_a)
93 |
94 | for i, char_b in enumerate(lhs_b):
95 | if i not in active_b:
96 | rhs.append(char_b)
97 |
98 |   # Enable broadcasting by eliminating singleton dimensions
99 | for (pos_a, pos_b) in zip(batch_a, batch_b):
100 | if a.shape[pos_a] == b.shape[pos_b]:
101 | continue # TODO: test for edge cases
102 |
103 | if a.shape[pos_a] == 1:
104 | a = tf.squeeze(a, axis=pos_a)
105 | del lhs_a[pos_a]
106 |
107 | if b.shape[pos_b] == 1:
108 | b = tf.squeeze(b, axis=pos_b)
109 | del lhs_b[pos_b]
110 |
111 | # Compute einsum
112 | return tf.einsum(f"{''.join(lhs_a)},{''.join(lhs_b)}->{''.join(rhs)}", a, b)
113 |
114 |
115 | def get_default_preconditioner(arr: tf.Tensor,
116 | diag: Union[tf.Tensor, tf.linalg.LinearOperator],
117 | max_rank: int = 100,
118 | diag_rtol: float = None):
119 | """
120 | Returns a preconditioner representing
121 | $(D + LL^{T})^{-1}$ where $D$ is a diagonal matrix and $L$ is the
122 | partial pivoted Cholesky factor of a symmetric PSD matrix $A$.
123 | """
124 | if diag_rtol is None:
125 | diag_rtol = default_jitter()
126 |
127 | N = arr.shape[-1]
128 | if not isinstance(diag, tf.linalg.LinearOperator):
129 | diag = tf.convert_to_tensor(diag, dtype=arr.dtype)
130 | if N == 1 or (diag.shape.ndims and diag.shape[-1] > 1):
131 | diag = tf.linalg.LinearOperatorDiag(diag)
132 | else:
133 | diag = tf.linalg.LinearOperatorScaledIdentity(N, diag)
134 |
135 | piv_chol = pivoted_cholesky(arr, max_rank=max_rank, diag_rtol=diag_rtol)
136 | low_rank = tf.linalg.LinearOperatorLowRankUpdate(diag, piv_chol)
137 | return low_rank.inverse()
138 |
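A sketch of how the preconditioner pairs with TensorFlow's CG solver, echoing its use in `cg_updates.py`; matrix sizes and ranks are made up:

```python
import tensorflow as tf
from gpflow_sampling.utils import get_default_preconditioner

# Random PSD matrix K and a noisy linear system (K + sigma2*I) x = b
A = tf.random.normal([256, 64], dtype=tf.float64)
K = tf.matmul(A, A, transpose_b=True) / 64
sigma2 = tf.constant(1e-2, dtype=tf.float64)
b = tf.random.normal([256], dtype=tf.float64)

operator = tf.linalg.LinearOperatorFullMatrix(
    tf.linalg.set_diag(K, tf.linalg.diag_part(K) + sigma2),
    is_self_adjoint=True,
    is_positive_definite=True)
preconditioner = get_default_preconditioner(K, diag=sigma2, max_rank=32)
result = tf.linalg.experimental.conjugate_gradient(
    operator, b, preconditioner=preconditioner, max_iter=50)
x = result.x  # approximate solution
```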
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | requirements = (
4 | 'numpy>=1.18.0',
5 | 'tensorflow>=2.2.0',
6 | 'tensorflow-probability>=0.9.0',
7 | 'gpflow>=2.0.3',
8 | )
9 |
10 | extra_requirements = {
11 | 'examples': (
12 | 'matplotlib',
13 | 'seaborn',
14 | 'tqdm',
15 | 'tensorflow-datasets',
16 | ),
17 | }
18 |
19 | setup(name='gpflow_sampling',
20 | version='0.2',
21 | license='Creative Commons Attribution-Noncommercial-Share Alike license',
22 | packages=find_packages(exclude=["examples*"]),
23 | python_requires='>=3.6',
24 | install_requires=requirements,
25 | extras_require=extra_requirements)
26 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/j-wilson/GPflowSampling/bc9c553fe4d8f522726f002c18df9965246df345/tests/__init__.py
--------------------------------------------------------------------------------
/tests/kernels/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/j-wilson/GPflowSampling/bc9c553fe4d8f522726f002c18df9965246df345/tests/kernels/__init__.py
--------------------------------------------------------------------------------
/tests/kernels/test_conv2d.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from numpy import allclose
11 | from typing import Any, List, NamedTuple
12 | from gpflow import config as gpflow_config, kernels as gpflow_kernels
13 | from gpflow.config import default_float as floatx
14 | from gpflow_sampling import kernels, covariances
15 | from gpflow_sampling.inducing_variables import *
16 | from gpflow_sampling.covariances.Kfus import (_Kfu_conv2d_fallback,
17 | _Kfu_depthwise_conv2d_fallback)
18 |
19 | SupportedBaseKernels = (gpflow_kernels.Matern12,
20 | gpflow_kernels.Matern32,
21 | gpflow_kernels.Matern52,
22 | gpflow_kernels.SquaredExponential,)
23 |
24 |
25 | # ==============================================
26 | # test_conv2d
27 | # ==============================================
28 | class ConfigConv2d(NamedTuple):
29 | seed: int = 1
30 | floatx: Any = 'float64'
31 | jitter: float = 1e-16
32 |
33 | num_test: int = 16
34 | num_cond: int = 32
35 | kernel_variance: float = 1.0
36 | rel_lengthscales_min: float = 0.1
37 | rel_lengthscales_max: float = 1.0
38 |
39 |
40 | channels_in: int = 3
41 | channels_out: int = 2
42 | image_shape: List = [28, 28]
43 | patch_shape: List = [5, 5]
44 | strides: List = [2, 2]
45 |
46 | padding: str = "SAME"
47 | dilations: List = [1, 1]
48 |
49 |
50 | def test_conv2d(config: ConfigConv2d = None):
51 | if config is None:
52 | config = ConfigConv2d()
53 |
54 | tf.random.set_seed(config.seed)
55 | gpflow_config.set_default_float(config.floatx)
56 | gpflow_config.set_default_jitter(config.jitter)
57 |
58 | X_shape = [config.num_test] + config.image_shape + [config.channels_in]
59 | X = tf.reshape(tf.range(tf.reduce_prod(X_shape), dtype=floatx()), X_shape)
60 | X /= tf.reduce_max(X)
61 |
62 | patch_len = config.channels_in * int(tf.reduce_prod(config.patch_shape))
63 | for cls in SupportedBaseKernels:
64 | minval = config.rel_lengthscales_min * (patch_len ** 0.5)
65 | maxval = config.rel_lengthscales_max * (patch_len ** 0.5)
66 | lenscales = tf.random.uniform(shape=[patch_len],
67 | minval=minval,
68 | maxval=maxval,
69 | dtype=floatx())
70 |
71 | base = cls(lengthscales=lenscales, variance=config.kernel_variance)
72 | kern = kernels.Conv2d(kernel=base,
73 | image_shape=config.image_shape,
74 | patch_shape=config.patch_shape,
75 | channels_in=config.channels_in,
76 | channels_out=config.channels_out,
77 | dilations=config.dilations,
78 | padding=config.padding,
79 | strides=config.strides)
80 |
81 | kern._weights = tf.random.normal(kern._weights.shape, dtype=floatx())
82 |
83 | # Test full and shared inducing images
84 | Z_shape = [config.num_cond] + config.patch_shape + [config.channels_in]
85 | Zsrc = tf.random.normal(Z_shape, dtype=floatx())
86 | for Z in (InducingImages(Zsrc),
87 | SharedInducingImages(Zsrc[..., :1], config.channels_in)):
88 |
89 | test = _Kfu_conv2d_fallback(Z, kern, X)
90 | allclose(covariances.Kfu(Z, kern, X), test)
91 |
92 |
93 | def test_conv2d_transpose(config: ConfigConv2d = None):
94 | if config is None:
95 | config = ConfigConv2d()
96 |
97 | tf.random.set_seed(config.seed)
98 | gpflow_config.set_default_float(config.floatx)
99 | gpflow_config.set_default_jitter(config.jitter)
100 |
101 | X_shape = [config.num_test] + config.image_shape + [config.channels_in]
102 | X = tf.reshape(tf.range(tf.reduce_prod(X_shape), dtype=floatx()), X_shape)
103 | X /= tf.reduce_max(X)
104 |
105 | patch_len = config.channels_in * int(tf.reduce_prod(config.patch_shape))
106 | for cls in SupportedBaseKernels:
107 | minval = config.rel_lengthscales_min * (patch_len ** 0.5)
108 | maxval = config.rel_lengthscales_max * (patch_len ** 0.5)
109 | lenscales = tf.random.uniform(shape=[patch_len],
110 | minval=minval,
111 | maxval=maxval,
112 | dtype=floatx())
113 |
114 | base = cls(lengthscales=lenscales, variance=config.kernel_variance)
115 | kern = kernels.Conv2dTranspose(kernel=base,
116 | image_shape=config.image_shape,
117 | patch_shape=config.patch_shape,
118 | channels_in=config.channels_in,
119 | channels_out=config.channels_out,
120 | dilations=config.dilations,
121 | padding=config.padding,
122 | strides=config.strides)
123 |
124 | kern._weights = tf.random.normal(kern._weights.shape, dtype=floatx())
125 |
126 | # Test full and shared inducing images
127 | Z_shape = [config.num_cond] + config.patch_shape + [config.channels_in]
128 | Zsrc = tf.random.normal(Z_shape, dtype=floatx())
129 | for Z in (InducingImages(Zsrc),
130 | SharedInducingImages(Zsrc[..., :1], config.channels_in)):
131 |
132 | test = _Kfu_conv2d_fallback(Z, kern, X)
133 | allclose(covariances.Kfu(Z, kern, X), test)
134 |
135 |
136 | def test_depthwise_conv2d(config: ConfigConv2d = None):
137 | if config is None:
138 | config = ConfigConv2d()
139 |
140 | tf.random.set_seed(config.seed)
141 | gpflow_config.set_default_float(config.floatx)
142 | gpflow_config.set_default_jitter(config.jitter)
143 |
144 | X_shape = [config.num_test] + config.image_shape + [config.channels_in]
145 | X = tf.reshape(tf.range(tf.reduce_prod(X_shape), dtype=floatx()), X_shape)
146 | X /= tf.reduce_max(X)
147 |
148 | patch_len = int(tf.reduce_prod(config.patch_shape))
149 | for cls in SupportedBaseKernels:
150 | minval = config.rel_lengthscales_min * (patch_len ** 0.5)
151 | maxval = config.rel_lengthscales_max * (patch_len ** 0.5)
152 | lenscales = tf.random.uniform(shape=[config.channels_in, patch_len],
153 | minval=minval,
154 | maxval=maxval,
155 | dtype=floatx())
156 |
157 | base = cls(lengthscales=lenscales, variance=config.kernel_variance)
158 | kern = kernels.DepthwiseConv2d(kernel=base,
159 | image_shape=config.image_shape,
160 | patch_shape=config.patch_shape,
161 | channels_in=config.channels_in,
162 | channels_out=config.channels_out,
163 | dilations=config.dilations,
164 | padding=config.padding,
165 | strides=config.strides)
166 |
167 | kern._weights = tf.random.normal(kern._weights.shape, dtype=floatx())
168 |
169 | # Test full and shared inducing images
170 | Z_shape = [config.num_cond] + config.patch_shape + [config.channels_in]
171 | Zsrc = tf.random.normal(Z_shape, dtype=floatx())
172 | for Z in (DepthwiseInducingImages(Zsrc),
173 | SharedDepthwiseInducingImages(Zsrc[..., :1], config.channels_in)):
174 |
175 | test = _Kfu_depthwise_conv2d_fallback(Z, kern, X)
176 | allclose(covariances.Kfu(Z, kern, X), test)
177 |
178 |
--------------------------------------------------------------------------------
/tests/sampling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/j-wilson/GPflowSampling/bc9c553fe4d8f522726f002c18df9965246df345/tests/sampling/__init__.py
--------------------------------------------------------------------------------
/tests/sampling/priors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/j-wilson/GPflowSampling/bc9c553fe4d8f522726f002c18df9965246df345/tests/sampling/priors/__init__.py
--------------------------------------------------------------------------------
/tests/sampling/priors/test_fourier_conv2d.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from numpy import allclose
11 | from typing import Any, List, NamedTuple
12 | from gpflow import config as gpflow_config, kernels as gpflow_kernels
13 | from gpflow.config import default_float as floatx
14 | from gpflow_sampling import kernels, covariances, inducing_variables
15 | from gpflow_sampling.sampling.priors import random_fourier
16 | from gpflow_sampling.bases import fourier as fourier_basis
17 | from gpflow_sampling.utils import batch_tensordot, swap_axes
18 |
19 | SupportedBaseKernels = (gpflow_kernels.Matern12,
20 | gpflow_kernels.Matern32,
21 | gpflow_kernels.Matern52,
22 | gpflow_kernels.SquaredExponential,)
23 |
24 |
25 | # ==============================================
26 | # test_fourier_conv2d
27 | # ==============================================
28 | class ConfigFourierConv2d(NamedTuple):
29 | seed: int = 1
30 | floatx: Any = 'float64'
31 | jitter: float = 1e-6
32 |
33 | num_test: int = 16
34 | num_cond: int = 5
35 | num_bases: int = 4096
36 | num_samples: int = 16384
37 | shard_size: int = 1024
38 |
39 | kernel_variance: float = 0.9 # keep this near 1 since it impacts MC error
40 | rel_lengthscales_min: float = 0.5
41 | rel_lengthscales_max: float = 2.0
42 | num_latent_gps: int = 3
43 |
44 | # Convolutional settings
45 | channels_in: int = 2
46 | image_shape: List = [5, 5]
47 | patch_shape: List = [3, 3]
48 | strides: List = [1, 1]
49 | dilations: List = [1, 1]
50 |
51 |
52 | def _avg_spatial_inner_product(a, b=None, batch_dims: int = 0):
53 | _a = tf.reshape(a, list(a.shape[:-3]) + [-1, a.shape[-1]])
54 | if b is None:
55 | _b = _a
56 | else:
57 | _b = tf.reshape(b, list(b.shape[:-3]) + [-1, b.shape[-1]])
58 | batch_axes = 2 * [list(range(batch_dims))]
59 |
60 | prod = batch_tensordot(_a, _b, axes=[-1, -1], batch_axes=batch_axes)
61 | return tf.reduce_mean(prod, [-3, -1])
62 |
63 |
64 | def _test_fourier_conv2d_common(config, kern, X, Z):
65 | # Use closed-form evaluations as ground truth
66 | Kuu = covariances.Kuu(Z, kern)
67 | Kfu = covariances.Kfu(Z, kern, X)
68 | Kff = kern(X, full_cov=True)
69 |
70 | # Test Fourier-feature-based kernel approximator
71 | basis = fourier_basis(kern, num_bases=config.num_bases)
72 | feat_x = basis(X) # [N, B] or [N, L, B]
73 | feat_z = basis(Z)
74 |
75 | tol = 3 * config.num_bases ** -0.5
76 | assert allclose(_avg_spatial_inner_product(feat_x, feat_x), Kff, tol, tol)
77 | assert allclose(_avg_spatial_inner_product(feat_x, feat_z), Kfu, tol, tol)
78 | assert allclose(_avg_spatial_inner_product(feat_z, feat_z), Kuu, tol, tol)
79 | del feat_x, feat_z
80 |
81 | # Test covariance of functions drawn from the approximate prior
82 | fx = []
83 | fz = []
84 | count = 0
85 | while count < config.num_samples:
86 | size = min(config.shard_size, config.num_samples - count)
87 | funcs = random_fourier(kern,
88 | basis=basis,
89 | num_bases=config.num_bases,
90 | sample_shape=[size])
91 |
92 | fx.append(funcs(X))
93 | fz.append(funcs(Z))
94 | count += size
95 |
96 | fx = swap_axes(tf.concat(fx, axis=0), 0, -1) # [L, N, H, W, S]
97 | fz = swap_axes(tf.concat(fz, axis=0), 0, -1) # [L, M, 1, 1, S]
98 | nb = fx.shape.ndims - 4 # num. of batch dimensions
99 | tol += 3 * config.num_samples ** -0.5
100 | frac = 1 / config.num_samples
101 |
102 | assert allclose(frac * _avg_spatial_inner_product(fx, fx, nb), Kff, tol, tol)
103 | assert allclose(frac * _avg_spatial_inner_product(fx, fz, nb), Kfu, tol, tol)
104 | assert allclose(frac * _avg_spatial_inner_product(fz, fz, nb), Kuu, tol, tol)
105 |
106 |
107 | def test_conv2d(config: ConfigFourierConv2d = None):
108 | """
109 | TODO: Consider separating out the test for Conv2dTranspose since it only
110 | supports a subset of strides/dilations.
111 | """
112 | if config is None:
113 | config = ConfigFourierConv2d()
114 |
115 | tf.random.set_seed(config.seed)
116 | gpflow_config.set_default_float(config.floatx)
117 | gpflow_config.set_default_jitter(config.jitter)
118 |
119 | X_shape = [config.num_test] + config.image_shape + [config.channels_in]
120 | X = tf.reshape(tf.range(tf.reduce_prod(X_shape), dtype=floatx()), X_shape)
121 | X /= tf.reduce_max(X)
122 |
123 | Z_shape = [config.num_cond] + config.patch_shape + [config.channels_in]
124 | Zsrc = tf.random.normal(Z_shape, dtype=floatx())
125 | Z = inducing_variables.InducingImages(Zsrc)
126 |
127 | patch_len = config.channels_in * config.patch_shape[0] * config.patch_shape[1]
128 | for base_cls in SupportedBaseKernels:
129 | minval = config.rel_lengthscales_min * (patch_len ** 0.5)
130 | maxval = config.rel_lengthscales_max * (patch_len ** 0.5)
131 | lenscales = tf.random.uniform(shape=[patch_len],
132 | minval=minval,
133 | maxval=maxval,
134 | dtype=floatx())
135 |
136 | base = base_cls(lengthscales=lenscales, variance=config.kernel_variance)
137 | for cls in (kernels.Conv2d, kernels.Conv2dTranspose):
138 | kern = cls(kernel=base,
139 | image_shape=config.image_shape,
140 | patch_shape=config.patch_shape,
141 | channels_in=config.channels_in,
142 | channels_out=config.num_latent_gps,
143 | dilations=config.dilations,
144 | strides=config.strides)
145 |
146 | _test_fourier_conv2d_common(config, kern, X, Z)
147 |
148 |
149 | def test_depthwise_conv2d(config: ConfigFourierConv2d = None):
150 | if config is None:
151 | config = ConfigFourierConv2d()
152 |
153 | assert config.num_bases % config.channels_in == 0
154 | tf.random.set_seed(config.seed)
155 | gpflow_config.set_default_float(config.floatx)
156 | gpflow_config.set_default_jitter(config.jitter)
157 |
158 | X_shape = [config.num_test] + config.image_shape + [config.channels_in]
159 | X = tf.random.uniform(X_shape, dtype=floatx())
160 |
161 | img_shape = [config.num_cond] + config.patch_shape + [config.channels_in]
162 | Zsrc = tf.random.normal(img_shape, dtype=floatx())
163 | Z = inducing_variables.DepthwiseInducingImages(Zsrc)
164 |
165 | patch_len = config.patch_shape[0] * config.patch_shape[1]
166 | for base_cls in SupportedBaseKernels:
167 | minval = config.rel_lengthscales_min * (patch_len ** 0.5)
168 | maxval = config.rel_lengthscales_max * (patch_len ** 0.5)
169 | lenscales = tf.random.uniform(shape=[config.channels_in, patch_len],
170 | minval=minval,
171 | maxval=maxval,
172 | dtype=floatx())
173 |
174 | base = base_cls(lengthscales=lenscales, variance=config.kernel_variance)
175 | for cls in (kernels.DepthwiseConv2d,):
176 | kern = cls(kernel=base,
177 | image_shape=config.image_shape,
178 | patch_shape=config.patch_shape,
179 | channels_in=config.channels_in,
180 | channels_out=config.num_latent_gps,
181 | dilations=config.dilations,
182 | strides=config.strides)
183 |
184 | _test_fourier_conv2d_common(config, kern, X, Z)
185 |
--------------------------------------------------------------------------------
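
The tolerances above (`3 * config.num_bases ** -0.5`) come from the O(B^-1/2) Monte Carlo error of random Fourier features. Below is a minimal, self-contained sketch of the identity being exercised, written for a squared exponential kernel in NumPy with a slightly looser tolerance; none of it is library code.

```
import numpy as np

def rff_features(X, lengthscales, variance, num_bases, rng):
  # Spectral samples for the squared exponential kernel: w ~ N(0, diag(1/l^2))
  W = rng.standard_normal((X.shape[-1], num_bases)) / lengthscales[:, None]
  b = rng.uniform(0.0, 2 * np.pi, size=num_bases)
  return np.sqrt(2 * variance / num_bases) * np.cos(X @ W + b)

rng = np.random.default_rng(0)
X = rng.uniform(size=(8, 3))
ls = np.full(3, 0.7)

phi = rff_features(X, ls, variance=0.9, num_bases=4096, rng=rng)
K_rff = phi @ phi.T  # Monte Carlo estimate of the kernel matrix

d2 = (((X[:, None, :] - X[None, :, :]) / ls) ** 2).sum(-1)
K_se = 0.9 * np.exp(-0.5 * d2)  # exact kernel matrix

tol = 5 * 4096 ** -0.5  # error decays as O(num_bases ** -0.5)
assert np.allclose(K_rff, K_se, atol=tol, rtol=tol)
```
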
/tests/sampling/priors/test_fourier_dense.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from numpy import allclose
11 | from typing import Any, NamedTuple
12 | from gpflow import kernels, config as gpflow_config
13 | from gpflow.inducing_variables import (InducingPoints,
14 | MultioutputInducingVariables,
15 | SharedIndependentInducingVariables,
16 | SeparateIndependentInducingVariables)
17 | from gpflow.config import default_float as floatx
18 | from gpflow_sampling import covariances
19 | from gpflow_sampling.bases import fourier as fourier_basis
20 | from gpflow_sampling.sampling.priors import random_fourier
21 |
22 | SupportedBaseKernels = (kernels.Matern12,
23 | kernels.Matern32,
24 | kernels.Matern52,
25 | kernels.SquaredExponential)
26 |
27 |
28 | # ==============================================
29 | # test_fourier_dense
30 | # ==============================================
31 | class ConfigFourierDense(NamedTuple):
32 | seed: int = 1
33 | floatx: Any = 'float64'
34 | jitter: float = 1e-6
35 |
36 | num_test: int = 16
37 | num_cond: int = 8
38 | num_bases: int = 4096
39 | num_samples: int = 16384
40 | shard_size: int = 1024
41 | input_dims: int = 3
42 |
43 | kernel_variance: float = 0.9 # keep this near 1 since it impacts MC error
44 | lengthscales_min: float = 0.1
45 | lengthscales_max: float = 1.0
46 |
47 |
48 | def _test_fourier_dense_common(config, kern, X, Z):
49 | # Test Fourier-feature-based kernel approximator
50 | Kuu = covariances.Kuu(Z, kern)
51 | Kfu = covariances.Kfu(Z, kern, X)
52 | basis = fourier_basis(kern, num_bases=config.num_bases)
53 | Z_opt = dict() # options used when evaluating basis/prior at Z
54 | if isinstance(Z, MultioutputInducingVariables):
55 | Kff = kern(X, full_cov=True, full_output_cov=False)
56 | if not isinstance(Z, SharedIndependentInducingVariables):
57 | # Handling for non-shared multioutput inducing variables.
58 | # We need to indicate that Z's outermost axis should be
59 | # evaluated 1-to-1 with the L latent GPs
60 | Z_opt.setdefault("multioutput_axis", 0)
61 |
62 | feat_x = basis(X) # [L, N, B]
63 | feat_z = basis(Z, **Z_opt) # [L, M, B]
64 | else:
65 | Kff = kern(X, full_cov=True)
66 | feat_x = basis(X)
67 | feat_z = basis(Z)
68 |
69 | tol = 3 * config.num_bases ** -0.5
70 | assert allclose(tf.matmul(feat_x, feat_x, transpose_b=True), Kff, tol, tol)
71 | assert allclose(tf.matmul(feat_x, feat_z, transpose_b=True), Kfu, tol, tol)
72 | assert allclose(tf.matmul(feat_z, feat_z, transpose_b=True), Kuu, tol, tol)
73 | del feat_x, feat_z
74 |
75 | # Test covariance of functions drawn from the approximate prior
76 | fx = []
77 | fz = []
78 | count = 0
79 | while count < config.num_samples:
80 | size = min(config.shard_size, config.num_samples - count)
81 | funcs = random_fourier(kern,
82 | basis=basis,
83 | num_bases=config.num_bases,
84 | sample_shape=[size])
85 |
86 | fx.append(funcs(X))
87 | fz.append(funcs(Z, **Z_opt))
88 | count += size
89 |
90 | fx = tf.transpose(tf.concat(fx, axis=0)) # [L, N, S]
91 | fz = tf.transpose(tf.concat(fz, axis=0)) # [L, M, S]
92 | tol += 3 * config.num_samples ** -0.5
93 | frac = 1 / config.num_samples
94 | assert allclose(frac * tf.matmul(fx, fx, transpose_b=True), Kff, tol, tol)
95 | assert allclose(frac * tf.matmul(fx, fz, transpose_b=True), Kfu, tol, tol)
96 | assert allclose(frac * tf.matmul(fz, fz, transpose_b=True), Kuu, tol, tol)
97 |
98 |
99 | def test_dense(config: ConfigFourierDense = None):
100 | if config is None:
101 | config = ConfigFourierDense()
102 |
103 | tf.random.set_seed(config.seed)
104 | gpflow_config.set_default_float(config.floatx)
105 | gpflow_config.set_default_jitter(config.jitter)
106 |
107 | X = tf.random.uniform([config.num_test, config.input_dims], dtype=floatx())
108 | Z = tf.random.uniform([config.num_cond, config.input_dims], dtype=floatx())
109 | Z = InducingPoints(Z)
110 | for cls in SupportedBaseKernels:
111 | lenscales = tf.random.uniform(shape=[config.input_dims],
112 | minval=config.lengthscales_min,
113 | maxval=config.lengthscales_max,
114 | dtype=floatx())
115 |
116 | kern = cls(lengthscales=lenscales, variance=config.kernel_variance)
117 | _test_fourier_dense_common(config, kern, X, Z)
118 |
119 |
120 | def test_dense_shared(config: ConfigFourierDense = None, output_dim: int = 2):
121 | if config is None:
122 | config = ConfigFourierDense()
123 |
124 | tf.random.set_seed(config.seed)
125 | gpflow_config.set_default_float(config.floatx)
126 | gpflow_config.set_default_jitter(config.jitter)
127 |
128 | X = tf.random.uniform([config.num_test, config.input_dims], dtype=floatx())
129 | Z = tf.random.uniform([config.num_cond, config.input_dims], dtype=floatx())
130 | Z = SharedIndependentInducingVariables(InducingPoints(Z))
131 | for cls in SupportedBaseKernels:
132 | lenscales = tf.random.uniform(shape=[config.input_dims],
133 | minval=config.lengthscales_min,
134 | maxval=config.lengthscales_max,
135 | dtype=floatx())
136 |
137 | base = cls(lengthscales=lenscales, variance=config.kernel_variance)
138 | kern = kernels.SharedIndependent(base, output_dim=output_dim)
139 | _test_fourier_dense_common(config, kern, X, Z)
140 |
141 |
142 | def test_dense_separate(config: ConfigFourierDense = None):
143 | if config is None:
144 | config = ConfigFourierDense()
145 |
146 | tf.random.set_seed(config.seed)
147 | gpflow_config.set_default_float(config.floatx)
148 | gpflow_config.set_default_jitter(config.jitter)
149 |
150 | allZ = []
151 | allK = []
152 | for cls in SupportedBaseKernels:
153 | lenscales = tf.random.uniform(shape=[config.input_dims],
154 | minval=config.lengthscales_min,
155 | maxval=config.lengthscales_max,
156 | dtype=floatx())
157 |
158 | rel_variance = tf.random.uniform(shape=[],
159 | minval=0.9,
160 | maxval=1.1,
161 | dtype=floatx())
162 |
163 | allK.append(cls(lengthscales=lenscales,
164 | variance=config.kernel_variance * rel_variance))
165 |
166 | allZ.append(InducingPoints(
167 | tf.random.uniform([config.num_cond, config.input_dims], dtype=floatx())))
168 |
169 | kern = kernels.SeparateIndependent(allK)
170 | Z = SeparateIndependentInducingVariables(allZ)
171 | X = tf.random.uniform([config.num_test, config.input_dims], dtype=floatx())
172 | _test_fourier_dense_common(config, kern, X, Z)
173 |
174 |
175 | if __name__ == "__main__":
176 | test_dense()
177 | test_dense_shared()
178 | test_dense_separate()
179 |
--------------------------------------------------------------------------------
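
For reference, the sharded loop in `_test_fourier_dense_common` doubles as a serviceable recipe for estimating prior covariances in user code. A condensed, single-output usage sketch follows; the `[size, N, 1]` output shape of `funcs(X)` is assumed from the shape comments above.

```
import tensorflow as tf
from gpflow.kernels import SquaredExponential
from gpflow_sampling.bases import fourier as fourier_basis
from gpflow_sampling.sampling.priors import random_fourier

kern = SquaredExponential(lengthscales=0.5, variance=0.9)
X = tf.random.uniform([16, 3], dtype=tf.float64)

num_samples, shard_size, num_bases = 16384, 1024, 4096
basis = fourier_basis(kern, num_bases=num_bases)
fx, count = [], 0
while count < num_samples:  # draw functions in shards to bound memory use
  size = min(shard_size, num_samples - count)
  funcs = random_fourier(kern, basis=basis, num_bases=num_bases,
                         sample_shape=[size])
  fx.append(funcs(X))  # assumed [size, 16, 1]
  count += size

fx = tf.concat(fx, axis=0)[..., 0]  # [S, N]
K_mc = tf.matmul(fx, fx, transpose_a=True) / num_samples  # ~ kern(X, full_cov=True)
```
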
/tests/sampling/updates/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/j-wilson/GPflowSampling/bc9c553fe4d8f522726f002c18df9965246df345/tests/sampling/updates/__init__.py
--------------------------------------------------------------------------------
/tests/sampling/updates/common.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 | from numpy import allclose
10 | from typing import Union, NamedTuple
11 | from functools import partial, update_wrapper
12 |
13 | from gpflow import config as gpflow_config
14 | from gpflow import kernels, mean_functions
15 | from gpflow.base import TensorLike
16 | from gpflow.config import default_jitter, default_float as floatx
17 | from gpflow.models import GPR, SVGP
18 | from gpflow.kernels import MultioutputKernel, SharedIndependent
19 | from gpflow.utilities import Dispatcher
20 | from gpflow.inducing_variables import (InducingPoints,
21 | InducingVariables,
22 | SharedIndependentInducingVariables,
23 | SeparateIndependentInducingVariables)
24 | from gpflow_sampling import covariances, kernels as kernels_ext
25 | from gpflow_sampling.utils import batch_tensordot
26 | from gpflow_sampling.inducing_variables import *
27 |
28 | SupportedBaseKernels = (kernels.Matern12,
29 | kernels.Matern32,
30 | kernels.Matern52,
31 | kernels.SquaredExponential)
32 |
33 | # ---- Export
34 | __all__ = ('sample_joint',
35 | 'avg_spatial_inner_product',
36 | 'test_update_sparse',
37 | 'test_update_sparse_shared',
38 | 'test_update_sparse_separate',
39 | 'test_update_conv2d')
40 |
41 | # ==============================================
42 | # common
43 | # ==============================================
44 | sample_joint = Dispatcher("sample_joint")
45 |
46 |
47 | @sample_joint.register(kernels.Kernel, TensorLike, TensorLike)
48 | def _sample_joint_fallback(kern,
49 | X,
50 | Xnew,
51 | num_samples: int,
52 | L: TensorLike = None,
53 | diag: Union[float, tf.Tensor] = None):
54 | """
55 | Sample from the joint distribution of $f(X), f(Xnew)$ via a
56 | location-scale transform.
57 | """
58 | if diag is None:
59 | diag = default_jitter()
60 |
61 | if L is None:
62 | K = kern(tf.concat([X, Xnew], axis=-2), full_cov=True)
63 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag)
64 | L = tf.linalg.cholesky(K)
65 |
66 | # Draw samples using a location-scale transform
67 | rvs = tf.random.normal(list(L.shape[:-1]) + [num_samples], dtype=floatx())
68 | draws = tf.expand_dims(L @ rvs, 0) # [1, N + T, S]
69 | return tf.split(tf.transpose(draws), [-1, Xnew.shape[0]], axis=-2), L
70 |
71 |
72 | @sample_joint.register(kernels.Kernel, InducingVariables, TensorLike)
73 | def _sample_joint_inducing(kern,
74 | Z,
75 | Xnew,
76 | num_samples: int,
77 | L: TensorLike = None,
78 | diag: Union[float, tf.Tensor] = None):
79 | """
80 | Sample from the joint distribution of $g(Z), f(Xnew)$ via a
81 | location-scale transform.
82 | """
83 | if diag is None:
84 | diag = default_jitter()
85 |
86 | # Construct joint covariance and compute matrix square root
87 | has_multiple_outputs = isinstance(kern, MultioutputKernel)
88 | if L is None:
89 | if has_multiple_outputs:
90 | Kff = kern(Xnew, full_cov=True, full_output_cov=False)
91 | else:
92 | Kff = kern(Xnew, full_cov=True)
93 | Kuu = covariances.Kuu(Z, kern, jitter=0.0)
94 | Kuf = covariances.Kuf(Z, kern, Xnew)
95 | if isinstance(kern, SharedIndependent) and \
96 | isinstance(Z, SharedIndependentInducingVariables):
97 | Kuu = tf.tile(Kuu[None], [Kff.shape[0], 1, 1])
98 | Kuf = tf.tile(Kuf[None], [Kff.shape[0], 1, 1])
99 |
100 | K = tf.concat([tf.concat([Kuu, Kuf], axis=-1),
101 | tf.concat([tf.linalg.adjoint(Kuf), Kff], axis=-1)], axis=-2)
102 |
103 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag)
104 | L = tf.linalg.cholesky(K)
105 |
106 | # Draw samples using a location-scale transform
107 | rvs = tf.random.normal(list(L.shape[:-1]) + [num_samples], dtype=floatx())
108 | draws = L @ rvs # [L, M + N, S] or [M + N, S]
109 | if not has_multiple_outputs:
110 | draws = tf.expand_dims(draws, 0)
111 |
112 | return tf.split(tf.transpose(draws), [-1, Xnew.shape[0]], axis=-2), L
113 |
114 |
115 | @sample_joint.register(kernels_ext.Conv2d, InducingImages, TensorLike)
116 | def _sample_joint_conv2d(kern,
117 | Z,
118 | Xnew,
119 | num_samples: int,
120 | L: TensorLike = None,
121 | diag: Union[float, tf.Tensor] = None):
122 | """
123 | Sample from the joint distribution of $g(Z), f(Xnew)$ via a
124 | location-scale transform.
125 | """
126 | if diag is None:
127 | diag = default_jitter()
128 |
129 | # Construct joint covariance and compute matrix square root
130 | if L is None:
131 | Zp = Z.as_patches # [M, patch_len]
132 | Xp = kern.get_patches(Xnew, full_spatial=False)
133 | P = tf.concat([Zp, tf.reshape(Xp, [-1, Xp.shape[-1]])], axis=0)
134 | K = kern.kernel(P, full_cov=True)
135 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + diag)
136 | L = tf.linalg.cholesky(K)
137 | L = tf.tile(L[None], [kern.channels_out, 1, 1]) # TODO: Improve me
138 |
139 | # Draw samples using a location-scale transform
140 | spatial_in = Xnew.shape[-3:-1]
141 | spatial_out = kern.get_spatial_out(spatial_in)
142 | rvs = tf.random.normal(list(L.shape[:-1]) + [num_samples], dtype=floatx())
143 | draws = tf.transpose(L @ rvs) # [S, M + P, L]
144 | fz, fx = tf.split(draws, [len(Z), -1], axis=1)
145 |
146 | # Reorganize $f(X)$ as a 3d feature map
147 | fx_shape = [num_samples, Xnew.shape[0]] + spatial_out + [kern.channels_out]
148 | fx = tf.reshape(fx, fx_shape)
149 | return (fz, fx), L
150 |
151 |
152 | def avg_spatial_inner_product(a, b=None, batch_dims: int = 0):
153 | """
154 | Computes covariances of functions defined as spatial sums/averages over
155 | patch response functions, given feature maps in 4D image format [N, H, W, C]
156 | """
157 | _a = tf.reshape(a, list(a.shape[:-3]) + [-1, a.shape[-1]])
158 | if b is None:
159 | _b = _a
160 | else:
161 | _b = tf.reshape(b, list(b.shape[:-3]) + [-1, b.shape[-1]])
162 | batch_axes = 2 * [list(range(batch_dims))]
163 | prod = batch_tensordot(_a, _b, axes=[-1, -1], batch_axes=batch_axes)
164 | return tf.reduce_mean(prod, [-3, -1])
165 |
166 |
167 | def test_update_dense(default_config: NamedTuple = None):
168 | def decorator(subroutine):
169 | def main(config):
170 | assert config is not None, ValueError
171 | tf.random.set_seed(config.seed)
172 | gpflow_config.set_default_float(config.floatx)
173 | gpflow_config.set_default_jitter(config.jitter)
174 |
175 | X = tf.random.uniform([config.num_cond, config.input_dims], dtype=floatx())
176 | Xnew = tf.random.uniform([config.num_test, config.input_dims], dtype=floatx())
177 | for cls in SupportedBaseKernels:
178 | minval = config.rel_lengthscales_min * (config.input_dims ** 0.5)
179 | maxval = config.rel_lengthscales_max * (config.input_dims ** 0.5)
180 | lenscales = tf.random.uniform(shape=[config.input_dims],
181 | minval=minval,
182 | maxval=maxval,
183 | dtype=floatx())
184 |
185 | kern = cls(lengthscales=lenscales, variance=config.kernel_variance)
186 | const = tf.random.normal([1], dtype=floatx())
187 |
188 | K = kern(X, full_cov=True)
189 | K = tf.linalg.set_diag(K, tf.linalg.diag_part(K) + config.noise_variance)
190 | L = tf.linalg.cholesky(K)
191 | y = L @ tf.random.normal([L.shape[-1], 1], dtype=floatx()) + const
192 |
193 | model = GPR(kernel=kern,
194 | noise_variance=config.noise_variance,
195 | data=(X, y),
196 | mean_function=mean_functions.Constant(c=const))
197 |
198 | mf, Sff = subroutine(config, model, Xnew)
199 | mg, Sgg = model.predict_f(Xnew, full_cov=True)
200 |
201 | tol = config.error_tol
202 | assert allclose(mf, mg, tol, tol)
203 | assert allclose(Sff, Sgg, tol, tol)
204 |
205 | return update_wrapper(partial(main, config=default_config), subroutine)
206 | return decorator
207 |
208 |
209 | def test_update_sparse(default_config: NamedTuple = None):
210 | def decorator(subroutine):
211 | def main(config):
212 | assert config is not None, ValueError
213 | tf.random.set_seed(config.seed)
214 | gpflow_config.set_default_float(config.floatx)
215 | gpflow_config.set_default_jitter(config.jitter)
216 |
217 | X = tf.random.uniform([config.num_test, config.input_dims], dtype=floatx())
218 | Z_shape = config.num_cond, config.input_dims
219 | for cls in SupportedBaseKernels:
220 | minval = config.rel_lengthscales_min * (config.input_dims ** 0.5)
221 | maxval = config.rel_lengthscales_max * (config.input_dims ** 0.5)
222 | lenscales = tf.random.uniform(shape=[config.input_dims],
223 | minval=minval,
224 | maxval=maxval,
225 | dtype=floatx())
226 |
227 | q_sqrt = tf.zeros([1] + 2 * [config.num_cond], dtype=floatx())
228 | kern = cls(lengthscales=lenscales, variance=config.kernel_variance)
229 | Z = InducingPoints(tf.random.uniform(Z_shape, dtype=floatx()))
230 |
231 | const = tf.random.normal([1], dtype=floatx())
232 | model = SVGP(kernel=kern,
233 | likelihood=None,
234 | inducing_variable=Z,
235 | mean_function=mean_functions.Constant(c=const),
236 | q_sqrt=q_sqrt)
237 |
238 | mf, Sff = subroutine(config, model, X)
239 | mg, Sgg = model.predict_f(X, full_cov=True)
240 |
241 | tol = config.error_tol
242 | assert allclose(mf, mg, tol, tol)
243 | assert allclose(Sff, Sgg, tol, tol)
244 |
245 | return update_wrapper(partial(main, config=default_config), subroutine)
246 | return decorator
247 |
248 |
249 | def test_update_sparse_shared(default_config: NamedTuple = None):
250 | def decorator(subroutine):
251 | def main(config):
252 | assert config is not None, ValueError
253 | tf.random.set_seed(config.seed)
254 | gpflow_config.set_default_float(config.floatx)
255 | gpflow_config.set_default_jitter(config.jitter)
256 |
257 | X = tf.random.uniform([config.num_test, config.input_dims], dtype=floatx())
258 | Z_shape = config.num_cond, config.input_dims
259 | for cls in SupportedBaseKernels:
260 | minval = config.rel_lengthscales_min * (config.input_dims ** 0.5)
261 | maxval = config.rel_lengthscales_max * (config.input_dims ** 0.5)
262 | lenscales = tf.random.uniform(shape=[config.input_dims],
263 | minval=minval,
264 | maxval=maxval,
265 | dtype=floatx())
266 |
267 | base = cls(lengthscales=lenscales, variance=config.kernel_variance)
268 | kern = kernels.SharedIndependent(base, output_dim=2)
269 |
270 | Z = SharedIndependentInducingVariables(
271 | InducingPoints(tf.random.uniform(Z_shape, dtype=floatx())))
272 | Kuu = covariances.Kuu(Z, kern, jitter=gpflow_config.default_jitter())
273 | q_sqrt = tf.stack([tf.zeros(2 * [config.num_cond], dtype=floatx()),
274 | tf.linalg.cholesky(Kuu)])
275 |
276 | const = tf.random.normal([2], dtype=floatx())
277 | model = SVGP(kernel=kern,
278 | likelihood=None,
279 | inducing_variable=Z,
280 | mean_function=mean_functions.Constant(c=const),
281 | q_sqrt=q_sqrt,
282 | whiten=False,
283 | num_latent_gps=2)
284 |
285 | mf, Sff = subroutine(config, model, X)
286 | mg, Sgg = model.predict_f(X, full_cov=True)
287 | tol = config.error_tol
288 | assert allclose(mf, mg, tol, tol)
289 | assert allclose(Sff, Sgg, tol, tol)
290 | return update_wrapper(partial(main, config=default_config), subroutine)
291 | return decorator
292 |
293 |
294 | def test_update_sparse_separate(default_config: NamedTuple = None):
295 | def decorator(subroutine):
296 | def main(config):
297 | assert config is not None, ValueError
298 | tf.random.set_seed(config.seed)
299 | gpflow_config.set_default_float(config.floatx)
300 | gpflow_config.set_default_jitter(config.jitter)
301 |
302 | X = tf.random.uniform([config.num_test, config.input_dims], dtype=floatx())
303 | allK = []
304 | allZ = []
305 | Z_shape = config.num_cond, config.input_dims
306 | for cls in SupportedBaseKernels:
307 | minval = config.rel_lengthscales_min * (config.input_dims ** 0.5)
308 | maxval = config.rel_lengthscales_max * (config.input_dims ** 0.5)
309 | lenscales = tf.random.uniform(shape=[config.input_dims],
310 | minval=minval,
311 | maxval=maxval,
312 | dtype=floatx())
313 |
314 | rel_variance = tf.random.uniform(shape=[],
315 | minval=0.9,
316 | maxval=1.1,
317 | dtype=floatx())
318 |
319 | allK.append(cls(lengthscales=lenscales,
320 | variance=config.kernel_variance * rel_variance))
321 |
322 | allZ.append(InducingPoints(tf.random.uniform(Z_shape, dtype=floatx())))
323 |
324 | kern = kernels.SeparateIndependent(allK)
325 | Z = SeparateIndependentInducingVariables(allZ)
326 |
327 | Kuu = covariances.Kuu(Z, kern, jitter=gpflow_config.default_jitter())
328 | q_sqrt = tf.linalg.cholesky(Kuu)\
329 | * tf.random.uniform(shape=[kern.num_latent_gps, 1, 1],
330 | minval=0.0,
331 | maxval=0.5,
332 | dtype=floatx())
333 |
334 | const = tf.random.normal([len(kern.kernels)], dtype=floatx())
335 | model = SVGP(kernel=kern,
336 | likelihood=None,
337 | inducing_variable=Z,
338 | mean_function=mean_functions.Constant(c=const),
339 | q_sqrt=q_sqrt,
340 | whiten=False,
341 | num_latent_gps=len(allK))
342 |
343 | mf, Sff = subroutine(config, model, X)
344 | mg, Sgg = model.predict_f(X, full_cov=True)
345 | tol = config.error_tol
346 | assert allclose(mf, mg, tol, tol)
347 | assert allclose(Sff, Sgg, tol, tol)
348 | return update_wrapper(partial(main, config=default_config), subroutine)
349 | return decorator
350 |
351 |
352 | def test_update_conv2d(default_config: NamedTuple = None):
353 | def decorator(subroutine):
354 | def main(config):
355 | assert config is not None, ValueError
356 | tf.random.set_seed(config.seed)
357 | gpflow_config.set_default_float(config.floatx)
358 | gpflow_config.set_default_jitter(config.jitter)
359 |
360 | X_shape = [config.num_test] + config.image_shape + [config.channels_in]
361 | X = tf.reshape(tf.range(tf.reduce_prod(X_shape), dtype=floatx()), X_shape)
362 | X /= tf.reduce_max(X)
363 |
364 | patch_len = config.channels_in * int(tf.reduce_prod(config.patch_shape))
365 | for base_cls in SupportedBaseKernels:
366 | minval = config.rel_lengthscales_min * (patch_len ** 0.5)
367 | maxval = config.rel_lengthscales_max * (patch_len ** 0.5)
368 | lenscales = tf.random.uniform(shape=[patch_len],
369 | minval=minval,
370 | maxval=maxval,
371 | dtype=floatx())
372 |
373 | base = base_cls(lengthscales=lenscales, variance=config.kernel_variance)
374 | Z_shape = [config.num_cond] + config.patch_shape + [config.channels_in]
375 | for cls in (kernels_ext.Conv2d, kernels_ext.Conv2dTranspose):
376 | kern = cls(kernel=base,
377 | image_shape=config.image_shape,
378 | patch_shape=config.patch_shape,
379 | channels_in=config.channels_in,
380 | channels_out=config.num_latent_gps,
381 | strides=config.strides,
382 | padding=config.padding,
383 | dilations=config.dilations)
384 |
385 | Z = InducingImages(tf.random.uniform(Z_shape, dtype=floatx()))
386 | q_sqrt = tf.linalg.cholesky(covariances.Kuu(Z, kern))
387 | q_sqrt *= tf.random.uniform([config.num_latent_gps, 1, 1],
388 | minval=0.0,
389 | maxval=0.5,
390 | dtype=floatx())
391 |
392 | # TODO: GPflow's SVGP class is not setup to support outputs defined
393 | # as spatial feature maps. For now, we content ourselves with
394 | # the following hack...
395 | const = tf.random.normal([config.num_latent_gps], dtype=floatx())
396 | mean_function = lambda x: const
397 |
398 | model = SVGP(kernel=kern,
399 | likelihood=None,
400 | mean_function=mean_function,
401 | inducing_variable=Z,
402 | q_sqrt=q_sqrt,
403 | whiten=False,
404 | num_latent_gps=config.num_latent_gps)
405 |
406 | mf, Sff = subroutine(config, model, X)
407 | mg, Sgg = model.predict_f(X, full_cov=True)
408 |
409 | tol = config.error_tol
410 | assert allclose(mf, mg, tol, tol)
411 | assert allclose(Sff, Sgg, tol, tol)
412 |
413 | return update_wrapper(partial(main, config=default_config), subroutine)
414 | return decorator
415 |
--------------------------------------------------------------------------------
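
Each `sample_joint` variant above reduces to the same location-scale transform: if `K = L @ L.T` and `eps ~ N(0, I)`, then `L @ eps ~ N(0, K)`. A standalone NumPy sketch of that identity (not library code):

```
import numpy as np

rng = np.random.default_rng(0)
X = rng.uniform(size=(32, 2))

# Any positive-definite covariance; here a squared exponential plus jitter
d2 = ((X[:, None] - X[None]) ** 2).sum(-1)
K = np.exp(-0.5 * d2) + 1e-6 * np.eye(len(X))

L = np.linalg.cholesky(K)
eps = rng.standard_normal((len(X), 100_000))
draws = L @ eps  # each column is a draw from N(0, K)

K_emp = draws @ draws.T / draws.shape[1]
assert np.allclose(K_emp, K, atol=0.05)  # empirical covariance, O(S**-0.5) error
```
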
/tests/sampling/updates/test_cg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from . import common
11 | from typing import Any, List, NamedTuple
12 | from gpflow import covariances
13 | from gpflow.models import GPR, SVGP
14 | from gpflow.config import default_jitter, default_float as floatx
15 | from gpflow_sampling.utils import swap_axes
16 | from gpflow_sampling.utils.linalg import get_default_preconditioner
17 | from gpflow_sampling.sampling.updates.cg_updates import cg as cg_update
18 |
19 |
20 | # ==============================================
21 | # test_cg
22 | # ==============================================
23 | class ConfigDense(NamedTuple):
24 | seed: int = 0
25 | floatx: Any = 'float64'
26 | jitter: float = 1e-6
27 |
28 | num_test: int = 128
29 | num_cond: int = 32
30 | num_samples: int = 16384
31 | shard_size: int = 1024
32 | input_dims: int = 3
33 |
34 | kernel_variance: float = 0.9 # keep this near 1 since it impacts MC error
35 | rel_lengthscales_min: float = 0.05
36 | rel_lengthscales_max: float = 0.5
37 | noise_variance: float = 1e-2 # only used by GPR test
38 |
39 | @property
40 | def error_tol(self):
41 | return 4 * (self.num_samples ** -0.5)
42 |
43 |
44 | class ConfigConv2d(NamedTuple):
45 | seed: int = 1
46 | floatx: Any = 'float64'
47 | jitter: float = 1e-5
48 |
49 | num_test: int = 16
50 | num_cond: int = 32
51 | num_samples: int = 16384
52 | shard_size: int = 1024
53 |
54 | kernel_variance: float = 0.9
55 | rel_lengthscales_min: float = 0.5
56 | rel_lengthscales_max: float = 2.0
57 | num_latent_gps: int = 3
58 |
59 | # Convolutional settings
60 | channels_in: int = 2
61 | image_shape: List = [3, 3] # Keep these small, since common.sample_joint
62 | patch_shape: List = [2, 2] # becomes very expensive for Conv2dTranspose!
63 | strides: List = [1, 1]
64 | padding: str = "VALID"
65 | dilations: List = [1, 1]
66 |
67 | @property
68 | def error_tol(self):
69 | return 4 * (self.num_samples ** -0.5)
70 |
71 |
72 | def _test_cg_gpr(config: ConfigDense,
73 | model: GPR,
74 | Xnew: tf.Tensor) -> tf.Tensor:
75 | """
76 | Sample generation subroutine common to each unit test
77 | """
78 | # Prepare preconditioner for CG
79 | X, y = model.data
80 | Kff = model.kernel(X, full_cov=True)
81 | max_rank = config.num_cond//(2 if config.num_cond > 1 else 1)
82 | preconditioner = get_default_preconditioner(Kff,
83 | diag=model.likelihood.variance,
84 | max_rank=max_rank)
85 |
86 | count = 0
87 | L_joint = None
88 | samples = []
89 | while count < config.num_samples:
90 | # Determine this shard's sample count
91 | size = min(config.shard_size, config.num_samples - count)
92 |
93 | # Generate draws from the joint distribution $p(f(X), f(Xnew))$
94 | (f, fnew), L_joint = common.sample_joint(model.kernel,
95 | X,
96 | Xnew,
97 | num_samples=size,
98 | L=L_joint)
99 |
100 | # Solve for update functions
101 | update_fns = cg_update(model.kernel,
102 | X,
103 | y,
104 | f + model.mean_function(X),
105 | tol=1e-6,
106 | diag=model.likelihood.variance,
107 | max_iter=config.num_cond,
108 | preconditioner=preconditioner)
109 |
110 | samples.append(fnew + update_fns(Xnew))
111 | count += size
112 |
113 | samples = tf.concat(samples, axis=0)
114 | if model.mean_function is not None:
115 | samples += model.mean_function(Xnew)
116 | return samples
117 |
118 |
119 | def _test_cg_svgp(config: ConfigDense,
120 | model: SVGP,
121 | Xnew: tf.Tensor) -> tf.Tensor:
122 | """
123 | Sample generation subroutine common to each unit test
124 | """
125 | # Prepare preconditioner for CG
126 | Z = model.inducing_variable
127 | Kff = covariances.Kuu(Z, model.kernel, jitter=0)
128 | max_rank = config.num_cond//(2 if config.num_cond > 1 else 1)
129 | preconditioner = get_default_preconditioner(Kff,
130 | diag=default_jitter(),
131 | max_rank=max_rank)
132 |
133 | count = 0
134 | samples = []
135 | L_joint = None
136 | while count < config.num_samples:
137 | # Sample $u ~ N(q_mu, q_sqrt q_sqrt^{T})$
138 | size = min(config.shard_size, config.num_samples - count)
139 | shape = model.num_latent_gps, config.num_cond, size
140 | rvs = tf.random.normal(shape=shape, dtype=floatx())
141 | u = tf.transpose(model.q_sqrt @ rvs)
142 |
143 | # Generate draws from the joint distribution $p(g(Z), f(Xnew))$
144 | (f, fnew), L_joint = common.sample_joint(model.kernel,
145 | Z,
146 | Xnew,
147 | num_samples=size,
148 | L=L_joint)
149 |
150 | # Solve for update functions
151 | update_fns = cg_update(model.kernel,
152 | Z,
153 | u,
154 | f,
155 | tol=1e-6,
156 | max_iter=config.num_cond,
157 | preconditioner=preconditioner)
158 |
159 | samples.append(fnew + update_fns(Xnew))
160 | count += size
161 |
162 | samples = tf.concat(samples, axis=0)
163 | if model.mean_function is not None:
164 | samples += model.mean_function(Xnew)
165 | return samples
166 |
167 |
168 | @common.test_update_dense(default_config=ConfigDense())
169 | def test_cg_dense(*args, **kwargs):
170 | f = _test_cg_gpr(*args, **kwargs) # [S, N, 1]
171 | mf = tf.reduce_mean(f, axis=0)
172 | res = tf.squeeze(f - mf, axis=-1)
173 | Sff = tf.matmul(res, res, transpose_a=True)/f.shape[0]
174 | return mf, Sff
175 |
176 |
177 | @common.test_update_sparse(default_config=ConfigDense())
178 | def test_cg_sparse(*args, **kwargs):
179 | f = _test_cg_svgp(*args, **kwargs) # [S, N, 1]
180 | mf = tf.reduce_mean(f, axis=0)
181 | res = tf.squeeze(f - mf, axis=-1)
182 | Sff = tf.matmul(res, res, transpose_a=True)/f.shape[0]
183 | return mf, Sff
184 |
185 |
186 | @common.test_update_sparse_shared(default_config=ConfigDense())
187 | def test_cg_sparse_shared(*args, **kwargs):
188 | f = _test_cg_svgp(*args, **kwargs) # [S, N, L]
189 | mf = tf.reduce_mean(f, axis=0)
190 | res = tf.transpose(f - mf)
191 | Sff = tf.matmul(res, res, transpose_b=True)/f.shape[0]
192 | return mf, Sff
193 |
194 |
195 | @common.test_update_sparse_separate(default_config=ConfigDense())
196 | def test_cg_sparse_separate(*args, **kwargs):
197 | f = _test_cg_svgp(*args, **kwargs) # [S, N, L]
198 | mf = tf.reduce_mean(f, axis=0)
199 | res = tf.transpose(f - mf)
200 | Sff = tf.matmul(res, res, transpose_b=True)/f.shape[0]
201 | return mf, Sff
202 |
203 |
204 | @common.test_update_conv2d(default_config=ConfigConv2d())
205 | def test_cg_conv2d(*args, **kwargs):
206 | f = swap_axes(_test_cg_svgp(*args, **kwargs), 0, -1) # [L, N, H, W, S]
207 | mf = tf.transpose(tf.reduce_mean(f, [-3, -2, -1])) # [N, L]
208 | res = f - tf.reduce_mean(f, axis=-1, keepdims=True)
209 | Sff = common.avg_spatial_inner_product(res, batch_dims=f.shape.ndims - 4)
210 | return mf, Sff/f.shape[-1] # [L, N, N]
211 |
--------------------------------------------------------------------------------
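
The CG-based update above replaces the Cholesky solve of `exact_update` with an iterative solve of `(K + diag) v = u - f(X)`. For orientation, here is a hand-rolled conjugate gradient sketch with no preconditioner (`get_default_preconditioner` is the library's refinement of this idea); it is a generic textbook solver, not the library's implementation.

```
import numpy as np

def conjugate_gradient(A, b, tol=1e-6, max_iter=None):
  """Solve A x = b for symmetric positive-definite A."""
  x = np.zeros_like(b)
  r = b - A @ x  # residual
  p = r.copy()   # search direction
  rs = r @ r
  for _ in range(max_iter or len(b)):
    Ap = A @ p
    alpha = rs / (p @ Ap)
    x += alpha * p
    r -= alpha * Ap
    rs_new = r @ r
    if np.sqrt(rs_new) < tol:
      break
    p = r + (rs_new / rs) * p
    rs = rs_new
  return x

rng = np.random.default_rng(0)
M = rng.standard_normal((32, 32))
A = M @ M.T + 32 * np.eye(32)  # SPD and well conditioned
b = rng.standard_normal(32)
x = conjugate_gradient(A, b)
assert np.allclose(A @ x, b, atol=1e-4)
```
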
/tests/sampling/updates/test_exact.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from . import common
11 | from typing import Any, List, Union, NamedTuple
12 | from gpflow.config import default_jitter, default_float as floatx
13 | from gpflow.models import GPR, SVGP
14 | from gpflow_sampling import covariances
15 | from gpflow_sampling.utils import swap_axes
16 | from gpflow_sampling.sampling.updates import exact as exact_update
17 |
18 |
19 | # ==============================================
20 | # test_exact
21 | # ==============================================
22 | class ConfigDense(NamedTuple):
23 | seed: int = 0
24 | floatx: Any = 'float64'
25 | jitter: float = 1e-6
26 |
27 | num_test: int = 128
28 | num_cond: int = 32
29 | num_samples: int = 16384
30 | shard_size: int = 1024
31 | input_dims: int = 3
32 |
33 | kernel_variance: float = 0.9 # keep this near 1 since it impacts MC error
34 | rel_lengthscales_min: float = 0.05
35 | rel_lengthscales_max: float = 0.5
36 | noise_variance: float = 1e-2 # only used by GPR test
37 |
38 | @property
39 | def error_tol(self):
40 | return 4 * (self.num_samples ** -0.5)
41 |
42 |
43 | class ConfigConv2d(NamedTuple):
44 | seed: int = 1
45 | floatx: Any = 'float64'
46 | jitter: float = 1e-5
47 |
48 | num_test: int = 128
49 | num_cond: int = 32
50 | num_samples: int = 16384
51 | shard_size: int = 1024
52 | kernel_variance: float = 0.9
53 | rel_lengthscales_min: float = 0.5
54 | rel_lengthscales_max: float = 2.0
55 | num_latent_gps: int = 3
56 |
57 | # Convolutional settings
58 | channels_in: int = 2
59 | image_shape: List = [3, 3] # Keep these small, since common.sample_joint
60 | patch_shape: List = [2, 2] # becomes very expensive for Conv2dTranspose!
61 | strides: List = [1, 1]
62 | padding: str = "VALID"
63 | dilations: List = [1, 1]
64 |
65 | @property
66 | def error_tol(self):
67 | return 4 * (self.num_samples ** -0.5)
68 |
69 |
70 | def _test_exact_gpr(config: ConfigDense,
71 | model: GPR,
72 | Xnew: tf.Tensor) -> tf.Tensor:
73 | """
74 | Sample generation subroutine common to each unit test
75 | """
76 | # Precompute Cholesky factor (optional)
77 | X, y = model.data
78 | Kyy = model.kernel(X, full_cov=True)
79 | Kyy = tf.linalg.set_diag(Kyy,
80 | tf.linalg.diag_part(Kyy) + model.likelihood.variance)
81 | Lyy = tf.linalg.cholesky(Kyy)
82 |
83 | count = 0
84 | L_joint = None
85 | samples = []
86 | while count < config.num_samples:
87 | # Determine this shard's sample count
88 | size = min(config.shard_size, config.num_samples - count)
89 |
90 | # Generate draws from the joint distribution $p(f(X), f(Xnew))$
91 | (f, fnew), L_joint = common.sample_joint(model.kernel,
92 | X,
93 | Xnew,
94 | num_samples=size,
95 | L=L_joint)
96 |
97 | # Solve for update functions
98 | update_fns = exact_update(model.kernel,
99 | X,
100 | y,
101 | f + model.mean_function(X),
102 | L=Lyy,
103 | diag=model.likelihood.variance)
104 |
105 | samples.append(fnew + update_fns(Xnew))
106 | count += size
107 |
108 | samples = tf.concat(samples, axis=0)
109 | if model.mean_function is not None:
110 | samples += model.mean_function(Xnew)
111 | return samples
112 |
113 |
114 | def _test_exact_svgp(config: Union[ConfigDense, ConfigConv2d],
115 | model: SVGP,
116 | Xnew: tf.Tensor) -> tf.Tensor:
117 | """
118 | Sample generation subroutine common to each unit test
119 | """
120 | # Precompute Cholesky factor (optional)
121 | Z = model.inducing_variable
122 | Kuu = covariances.Kuu(Z, model.kernel, jitter=default_jitter())
123 | Luu = tf.linalg.cholesky(Kuu)
124 |
125 | count = 0
126 | L_joint = None
127 | samples = []
128 | while count < config.num_samples:
129 | # Sample $u ~ N(q_mu, q_sqrt q_sqrt^{T})$
130 | size = min(config.shard_size, config.num_samples - count)
131 | shape = model.num_latent_gps, config.num_cond, size
132 | rvs = tf.random.normal(shape=shape, dtype=floatx())
133 | u = tf.transpose(model.q_sqrt @ rvs)
134 |
135 | # Generate draws from the joint distribution $p(g(Z), f(Xnew))$
136 | (f, fnew), L_joint = common.sample_joint(model.kernel,
137 | Z,
138 | Xnew,
139 | num_samples=size,
140 | L=L_joint)
141 |
142 | # Solve for update functions
143 | update_fns = exact_update(model.kernel, Z, u, f, L=Luu)
144 | samples.append(fnew + update_fns(Xnew))
145 | count += size
146 |
147 | samples = tf.concat(samples, axis=0)
148 | if model.mean_function is not None:
149 | samples += model.mean_function(Xnew)
150 | return samples
151 |
152 |
153 | @common.test_update_dense(default_config=ConfigDense())
154 | def test_exact_dense(*args, **kwargs):
155 | f = _test_exact_gpr(*args, **kwargs) # [S, N, 1]
156 | mf = tf.reduce_mean(f, axis=0)
157 | res = tf.squeeze(f - mf, axis=-1)
158 | Sff = tf.matmul(res, res, transpose_a=True)/f.shape[0]
159 | return mf, Sff
160 |
161 |
162 | @common.test_update_sparse(default_config=ConfigDense())
163 | def test_exact_sparse(*args, **kwargs):
164 | f = _test_exact_svgp(*args, **kwargs) # [S, N, 1]
165 | mf = tf.reduce_mean(f, axis=0)
166 | res = tf.squeeze(f - mf, axis=-1)
167 | Sff = tf.matmul(res, res, transpose_a=True)/f.shape[0]
168 | return mf, Sff
169 |
170 |
171 | @common.test_update_sparse_shared(default_config=ConfigDense())
172 | def test_exact_sparse_shared(*args, **kwargs):
173 | f = _test_exact_svgp(*args, **kwargs) # [S, N, L]
174 | mf = tf.reduce_mean(f, axis=0)
175 | res = tf.transpose(f - mf)
176 | Sff = tf.matmul(res, res, transpose_b=True)/f.shape[0]
177 | return mf, Sff
178 |
179 |
180 | @common.test_update_sparse_separate(default_config=ConfigDense())
181 | def test_exact_sparse_separate(*args, **kwargs):
182 | f = _test_exact_svgp(*args, **kwargs) # [S, N, L]
183 | mf = tf.reduce_mean(f, axis=0)
184 | res = tf.transpose(f - mf)
185 | Sff = tf.matmul(res, res, transpose_b=True)/f.shape[0]
186 | return mf, Sff
187 |
188 |
189 | @common.test_update_conv2d(default_config=ConfigConv2d())
190 | def test_exact_conv2d(*args, **kwargs):
191 | f = swap_axes(_test_exact_svgp(*args, **kwargs), 0, -1) # [L, N, H, W, S]
192 | mf = tf.transpose(tf.reduce_mean(f, [-3, -2, -1])) # [N, L]
193 | res = f - tf.reduce_mean(f, axis=-1, keepdims=True)
194 | Sff = common.avg_spatial_inner_product(res, batch_dims=f.shape.ndims - 4)
195 | return mf, Sff/f.shape[-1] # [L, N, N]
196 |
--------------------------------------------------------------------------------
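
Stripped of batching and model classes, the exact (sparse, noiseless) update tested above is Matheron's rule: shift a joint prior draw so that it interpolates `u` at `Z`, i.e. `f_post(x) = f(x) + Kxz Kzz^-1 (u - f(Z))`. A NumPy sketch under stated assumptions (squared exponential kernel, zero mean, single output; not library code):

```
import numpy as np

def se_kernel(A, B, lengthscale=0.5):
  d2 = ((A[:, None] - B[None]) ** 2).sum(-1) / lengthscale**2
  return np.exp(-0.5 * d2)

rng = np.random.default_rng(0)
Z = rng.uniform(size=(8, 2))      # inducing inputs
Xnew = rng.uniform(size=(4, 2))
u = rng.standard_normal(8)        # target function values at Z

# Joint prior draw over (Z, Xnew) via a location-scale transform
P = np.concatenate([Z, Xnew], axis=0)
K = se_kernel(P, P) + 1e-9 * np.eye(len(P))
f = np.linalg.cholesky(K) @ rng.standard_normal(len(P))
fz, fx = f[:8], f[8:]

# Exact update: solve Kzz v = u - fz with a Cholesky factor
Kzz = se_kernel(Z, Z) + 1e-9 * np.eye(8)
Lzz = np.linalg.cholesky(Kzz)
v = np.linalg.solve(Lzz.T, np.linalg.solve(Lzz, u - fz))
f_post = fx + se_kernel(Xnew, Z) @ v  # pathwise posterior sample at Xnew
```
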
/tests/sampling/updates/test_linear.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | # ==============================================
5 | # Preamble
6 | # ==============================================
7 | # ---- Imports
8 | import tensorflow as tf
9 |
10 | from . import common
11 | from typing import Any, NamedTuple
12 | from gpflow.config import default_float as floatx
13 | from gpflow.models import SVGP
14 | from gpflow_sampling.bases import fourier as fourier_basis
15 | from gpflow_sampling.sampling.updates import linear as linear_update
16 |
17 |
18 | # ==============================================
19 | # test_linear
20 | # ==============================================
21 | class ConfigDense(NamedTuple):
22 | seed: int = 0
23 | floatx: Any = 'float64'
24 | jitter: float = 1e-6
25 |
26 | num_test: int = 128
27 | num_cond: int = 32
28 | num_bases: int = 4096
29 | num_samples: int = 16384
30 | shard_size: int = 1024
31 | input_dims: int = 3
32 |
33 | kernel_variance: float = 0.9 # keep this near 1 since it impacts MC error
34 | rel_lengthscales_min: float = 0.05
35 | rel_lengthscales_max: float = 0.5
36 |
37 | @property
38 | def error_tol(self):
39 | return 4 * (self.num_samples ** -0.5 + self.num_bases ** -0.5)
40 |
41 |
42 | def _test_linear_svgp(config: ConfigDense,
43 | model: SVGP,
44 | Xnew: tf.Tensor) -> tf.Tensor:
45 | """
46 | Sample generation subroutine common to each unit test
47 | """
48 | Z = model.inducing_variable
49 | count = 0
50 | basis = fourier_basis(model.kernel, num_bases=config.num_bases)
51 | L_joint = None
52 | samples = []
53 | while count < config.num_samples:
54 | # Sample $u ~ N(q_mu, q_sqrt q_sqrt^{T})$
55 | size = min(config.shard_size, config.num_samples - count)
56 | shape = model.num_latent_gps, config.num_cond, size
57 | rvs = tf.random.normal(shape=shape, dtype=floatx())
58 | u = tf.transpose(model.q_sqrt @ rvs)
59 |
60 | # Generate draws from the joint distribution $p(g(Z), f(Xnew))$
61 | (f, fnew), L_joint = common.sample_joint(model.kernel,
62 | Z,
63 | Xnew,
64 | num_samples=size,
65 | L=L_joint)
66 |
67 | # Solve for update functions
68 | update_fns = linear_update(Z, u, f, basis=basis)
69 | samples.append(fnew + update_fns(Xnew))
70 | count += size
71 |
72 | samples = tf.concat(samples, axis=0)
73 | if model.mean_function is not None:
74 | samples += model.mean_function(Xnew)
75 | return samples
76 |
77 |
78 | @common.test_update_sparse(default_config=ConfigDense())
79 | def test_linear_sparse(*args, **kwargs):
80 | f = _test_linear_svgp(*args, **kwargs) # [S, N, 1]
81 | mf = tf.reduce_mean(f, axis=0)
82 | res = tf.squeeze(f - mf, axis=-1)
83 | Sff = tf.matmul(res, res, transpose_a=True)/f.shape[0]
84 | return mf, Sff
85 |
86 |
87 | @common.test_update_sparse_shared(default_config=ConfigDense())
88 | def test_linear_sparse_shared(*args, **kwargs):
89 | f = _test_linear_svgp(*args, **kwargs) # [S, N, L]
90 | mf = tf.reduce_mean(f, axis=0)
91 | res = tf.transpose(f - mf)
92 | Sff = tf.matmul(res, res, transpose_b=True)/f.shape[0]
93 | return mf, Sff
94 |
95 |
96 | @common.test_update_sparse_separate(default_config=ConfigDense())
97 | def test_linear_sparse_separate(*args, **kwargs):
98 | f = _test_linear_svgp(*args, **kwargs) # [S, N, L]
99 | mf = tf.reduce_mean(f, axis=0)
100 | res = tf.transpose(f - mf)
101 | Sff = tf.matmul(res, res, transpose_b=True)/f.shape[0]
102 | return mf, Sff
103 |
--------------------------------------------------------------------------------
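
One plausible reading of the `linear_update` exercised here (the authoritative version lives in `gpflow_sampling/sampling/updates/linear_updates.py`): express the update term in the same Fourier basis as the prior, solving for minimum-norm weights `w` with `phi(Z) w = u - f(Z)`, which would explain why `error_tol` also carries a `num_bases ** -0.5` term. A hedged NumPy sketch with stand-in features, not the library's formulation:

```
import numpy as np

rng = np.random.default_rng(0)
M, B, T = 8, 512, 4  # inducing points, basis functions, test points
phi_z = rng.standard_normal((M, B)) / np.sqrt(B)  # stand-in for basis(Z)
phi_x = rng.standard_normal((T, B)) / np.sqrt(B)  # stand-in for basis(Xnew)

u = rng.standard_normal(M)            # target function values at Z
fz = phi_z @ rng.standard_normal(B)   # prior function values at Z

# Minimum-norm weights satisfying phi_z @ w = u - fz (M << B, full row rank)
w, *_ = np.linalg.lstsq(phi_z, u - fz, rcond=None)
update_at_x = phi_x @ w               # evaluate the update at new inputs

assert np.allclose(phi_z @ w, u - fz, atol=1e-8)
```
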