├── tests ├── __init__.py ├── test_mvn_logp.py ├── test_transition_KLs.py ├── test_KL.py └── test_sampling_schemes.py ├── .gitignore ├── kink_function_triptych.png ├── GPt ├── __init__.py ├── utils.py ├── KL.py ├── gpssm_multiseq.py ├── transitions.py ├── emissions.py ├── gpssm_models.py ├── ssm.py └── gpssm.py ├── README.md └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .ipynb_checkpoints 3 | .DS_STORE 4 | *.pyc 5 | -------------------------------------------------------------------------------- /kink_function_triptych.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ialong/GPt/HEAD/kink_function_triptych.png -------------------------------------------------------------------------------- /GPt/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Alessandro Davide Ialongo (@ialong) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .ssm import * 16 | from .gpssm_models import * 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GPt 2 | A library for Recurrent Gaussian Process Models based on GPflow and TensorFlow. 3 | 4 | It implements all the inference methods contained in this paper: 5 | 6 | A.D. Ialongo, M. van der Wilk, J. Hensman, C.E. Rasmussen. [Overcoming Mean-Field Approximations in Recurrent Gaussian Process Models](https://arxiv.org/pdf/1906.05828.pdf). In *ICML*, 2019. 7 | 8 | ![kink_function_triptych](kink_function_triptych.png) 9 | 10 | ### Setup 11 | Install TensorFlow. 12 | 13 | Clone GPflow from https://github.com/ialong/GPflow. Select the `custom_multioutput` branch. 14 | 15 | Follow the instructions to install GPflow. 16 | 17 | Clone GPt. 18 | 19 | Example code and models are in the `examples/` directory. 20 | 21 | ### Running tests 22 | `python -m unittest discover` 23 | 24 | ## Applications and Support 25 | We encourage the use of this code for applications (both in the private and public sectors). 26 | 27 | Please tell us about your project by sending an email to the address below. Generally, for support feel free to email or open an issue on GitHub. 28 | 29 | `alex` `.` `ialongo` `at` `gmail` 30 | -------------------------------------------------------------------------------- /GPt/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Alessandro Davide Ialongo (@ialong) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | 19 | 20 | def block_indices(T, D): 21 | """ 22 | Returns the indices for diagonal and (upper) off-diagonal 23 | blocks in an unravelled matrix. Use as: 24 | diag_blocks = matrix.reshape(-1)[diag_inds].reshape(T,D,D) 25 | offdiag_blocks = matrix.reshape(-1)[offdiag_inds].reshape(T-1,D,D) 26 | """ 27 | TD = T * D 28 | inds = np.tile(np.arange(D), [TD, 1]) 29 | inds += np.arange(TD ** 2, step=TD)[:, None] 30 | diag_inds = inds + np.repeat(np.arange(TD, step=D), D)[:, None] 31 | offdiag_inds = diag_inds[:-D] + D 32 | return diag_inds.ravel(), offdiag_inds.ravel() 33 | 34 | 35 | def extract_cov_blocks(Xchol, T, D, return_off_diag_blocks=False): 36 | Xcov = tf.reshape(tf.matmul(Xchol, Xchol, transpose_b=True), [-1]) 37 | diag_inds, offdiag_inds = block_indices(T, D) 38 | diag_blocks = tf.reshape(tf.gather(Xcov, diag_inds), [T, D, D]) 39 | if return_off_diag_blocks: 40 | offdiag_blocks = tf.reshape(tf.gather(Xcov, offdiag_inds), [T - 1, D, D]) 41 | return diag_blocks, offdiag_blocks 42 | return diag_blocks 43 | -------------------------------------------------------------------------------- /GPt/KL.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Alessandro Davide Ialongo (@ialong) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import tensorflow as tf 17 | 18 | 19 | def KL(mu_diff, Q_chol, *, P_chol=None, P=None): 20 | """ 21 | :param mu_diff: (DxM) or [D] 22 | :param Q_chol: (DxMxM) or [DxD] 23 | :param P_chol: (None or M or MxM or DxMxM) or [None or DxD or D] 24 | :param P: (None or M or MxM or DxMxM) or [None or DxD or D] 25 | :return: scalar KL 26 | """ 27 | mu_ndims = mu_diff.shape.ndims 28 | assert mu_ndims is not None 29 | white = P_chol is None and P is None 30 | if not white: 31 | P_ndims = (P_chol if P is None else P).shape.ndims 32 | assert P_ndims is not None 33 | 34 | if white: 35 | trace = Q_chol 36 | mahalanobis = mu_diff 37 | elif P_ndims == 1: 38 | P_sqrt = tf.sqrt(tf.abs(P)) if P_chol is None else P_chol 39 | trace = Q_chol / (P_sqrt[:, None] if mu_ndims == 1 else P_sqrt[None, :, None]) 40 | mahalanobis = mu_diff / (P_sqrt if mu_ndims == 1 else P_sqrt[None, :]) 41 | log_det_P = 2. 
* tf.reduce_sum(tf.log(tf.abs(P_chol))) if P is None else tf.reduce_sum(tf.log(P)) 42 | if mu_ndims == 2: 43 | log_det_P *= tf.cast(tf.shape(mu_diff)[0], P_sqrt.dtype) 44 | else: 45 | D = tf.shape(mu_diff)[0] 46 | P_chol = tf.cholesky(P) if P_chol is None else P_chol # DxMxM or MxM or DxD 47 | tile_P = (P_ndims == 2) and (mu_ndims == 2) 48 | 49 | P_chol_full = tf.tile(P_chol[None, ...], [D, 1, 1]) if tile_P else P_chol # DxMxM or DxD 50 | trace = tf.matrix_triangular_solve(P_chol_full, Q_chol, lower=True) # DxMxM or DxD 51 | 52 | _mu_diff = mu_diff[:, :, None] if P_ndims == 3 else \ 53 | (tf.transpose(mu_diff) if mu_ndims == 2 else mu_diff[:, None]) # DxMx1 or MxD or Dx1 54 | mahalanobis = tf.matrix_triangular_solve(P_chol, _mu_diff, lower=True) # DxMx1 or MxD or Dx1 55 | 56 | log_det_P = 2. * tf.reduce_sum(tf.log(tf.abs(tf.matrix_diag_part(P_chol)))) 57 | if tile_P: 58 | log_det_P *= tf.cast(D, P_chol.dtype) 59 | 60 | trace = tf.reduce_sum(tf.square(trace)) 61 | mahalanobis = tf.reduce_sum(tf.square(mahalanobis)) 62 | constant = tf.cast(tf.size(mu_diff, out_type=tf.int64), dtype=mu_diff.dtype) 63 | log_det_Q = 2. * tf.reduce_sum(tf.log(tf.abs(tf.matrix_diag_part(Q_chol)))) 64 | double_KL = trace + mahalanobis - constant - log_det_Q 65 | 66 | if not white: 67 | double_KL += log_det_P 68 | 69 | return 0.5 * double_KL 70 | 71 | 72 | def KL_samples(mu_diff, Q_chol, P_chol=None): 73 | """ 74 | :param mu_diff: NxSxD or NxD 75 | :param Q_chol: NxDxD or NxD 76 | :param P_chol: None or DxD or D 77 | :return: N 78 | """ 79 | D = tf.shape(mu_diff)[-1] 80 | assert mu_diff.shape.ndims is not None 81 | assert Q_chol.shape.ndims is not None 82 | diag_Q = Q_chol.shape.ndims == 2 83 | 84 | white = P_chol is None 85 | if not white: 86 | assert P_chol.shape.ndims is not None 87 | diag_P = P_chol.shape.ndims == 1 88 | 89 | if white: 90 | trace = Q_chol 91 | mahalanobis = mu_diff 92 | elif diag_P: 93 | trace = Q_chol / (P_chol if diag_Q else P_chol[:, None]) 94 | mahalanobis = mu_diff / P_chol 95 | log_det_P = 2. * tf.reduce_sum(tf.log(tf.abs(P_chol))) 96 | else: 97 | N = tf.shape(mu_diff)[0] 98 | trace = tf.matrix_triangular_solve(tf.tile(P_chol[None, ...], [N, 1, 1]), 99 | tf.matrix_diag(Q_chol) if diag_Q else Q_chol, lower=True) # NxDxD 100 | 101 | if mu_diff.shape.ndims == 2: 102 | mahalanobis = tf.matrix_triangular_solve(P_chol, tf.transpose(mu_diff), lower=True) # DxN 103 | else: 104 | mahalanobis = tf.transpose(tf.reshape(mu_diff, [-1, D])) 105 | mahalanobis = tf.matrix_triangular_solve(P_chol, mahalanobis, lower=True) # Dx(N*S) 106 | mahalanobis = tf.reshape(mahalanobis, [D, N, -1]) # DxNxS 107 | 108 | log_det_P = 2. * tf.reduce_sum(tf.log(tf.abs(tf.matrix_diag_part(P_chol)))) 109 | 110 | if white or diag_P: 111 | mahalanobis = tf.reduce_sum(tf.square(mahalanobis), -1) 112 | else: 113 | mahalanobis = tf.reduce_sum(tf.square(mahalanobis), 0) 114 | if mu_diff.shape.ndims == 3: 115 | mahalanobis = tf.reduce_mean(mahalanobis, -1) 116 | 117 | trace = tf.square(trace) 118 | if (not diag_Q) or (diag_Q and not white and not diag_P): 119 | trace = tf.reduce_sum(trace, -1) 120 | trace = tf.reduce_sum(trace, -1) 121 | 122 | constant = tf.cast(D, dtype=mu_diff.dtype) 123 | log_det_Q = 2. 
* tf.reduce_sum(tf.log(tf.abs( 124 | Q_chol if diag_Q else tf.matrix_diag_part(Q_chol))), -1) 125 | double_KL = trace + mahalanobis - constant - log_det_Q 126 | 127 | if not white: 128 | double_KL += log_det_P 129 | 130 | return 0.5 * double_KL 131 | -------------------------------------------------------------------------------- /tests/test_mvn_logp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Alessandro Davide Ialongo (@ialong) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np 17 | from numpy.testing import assert_allclose 18 | from numpy.random import randn as mvn 19 | from scipy.stats import multivariate_normal 20 | from numpy.linalg import cholesky 21 | import tensorflow as tf 22 | import gpflow as gp 23 | from gpflow.test_util import GPflowTestCase 24 | from gpflow.logdensities import mvn_logp, diag_mvn_logp 25 | 26 | 27 | def compare_logps(sess, mvn_fn, x, mu, L): 28 | cov_sp = np.eye(x.shape[-1]) if L is None else (L @ L.T if L.ndim == 2 else np.diag(L ** 2.)) 29 | if mu.ndim == 1: 30 | sp_logp = multivariate_normal.logpdf(x=x, mean=mu, cov=cov_sp) 31 | elif x.ndim == 2: # x is TxD and mu is TxD 32 | sp_logp = np.zeros(x.shape[0]) 33 | for t in range(x.shape[0]): 34 | sp_logp[t] = multivariate_normal.logpdf(x=x[t], mean=mu[t], cov=cov_sp) 35 | elif mu.ndim == 2: # x is NxTxD and mu is TxD 36 | sp_logp = np.zeros(x.shape[:-1]) 37 | for n in range(x.shape[0]): 38 | for t in range(x.shape[1]): 39 | sp_logp[n, t] = multivariate_normal.logpdf(x=x[n, t], mean=mu[t], cov=cov_sp) 40 | elif mu.ndim == 3: # x is NxTxD and mu is NxTxD 41 | sp_logp = np.zeros(x.shape[:-1]) 42 | for n in range(x.shape[0]): 43 | for t in range(x.shape[1]): 44 | sp_logp[n, t] = multivariate_normal.logpdf(x=x[n, t], mean=mu[n, t], cov=cov_sp) 45 | 46 | d = x - mu 47 | d = d if mvn_fn is diag_mvn_logp \ 48 | else d[:, None] if d.ndim == 1 \ 49 | else d.T if d.ndim == 2 \ 50 | else np.transpose(d, [2, 0, 1]) 51 | 52 | d_tf = tf.placeholder(gp.settings.float_type, shape=d.shape if d.ndim == 2 else None) 53 | L_tf = None if L is None else tf.placeholder(gp.settings.float_type) 54 | feed_dict = {d_tf: d} 55 | if L is not None: feed_dict[L_tf] = L 56 | gp_logp = sess.run(mvn_fn(d_tf, L_tf), feed_dict) 57 | 58 | assert_allclose(sp_logp, gp_logp) 59 | 60 | 61 | class MvnLogPTest(GPflowTestCase): 62 | def __init__(self, *args, **kwargs): 63 | super().__init__(*args, **kwargs) 64 | np.random.seed(0) 65 | tf.set_random_seed(0) 66 | self.S, self.T, self.D = 13, 11, 4 67 | 68 | def prepare_L(self): 69 | L = mvn(self.D, self.D) 70 | L = cholesky(L @ L.T) 71 | return L 72 | 73 | def test_mvn_logp(self): 74 | L = self.prepare_L() 75 | with self.test_context() as sess: 76 | compare_logps(sess, mvn_logp, mvn(self.D), mvn(self.D), L) 77 | compare_logps(sess, mvn_logp, mvn(self.D), mvn(self.D), None) 78 | compare_logps(sess, mvn_logp, mvn(self.T, self.D), mvn(self.D), L) 79 | compare_logps(sess, mvn_logp, 
mvn(self.T, self.D), mvn(self.D), None) 80 | compare_logps(sess, mvn_logp, mvn(self.T, self.D), mvn(self.T, self.D), L) 81 | compare_logps(sess, mvn_logp, mvn(self.T, self.D), mvn(self.T, self.D), None) 82 | compare_logps(sess, mvn_logp, mvn(self.S, self.T, self.D), mvn(self.D), L) 83 | compare_logps(sess, mvn_logp, mvn(self.S, self.T, self.D), mvn(self.D), None) 84 | compare_logps(sess, mvn_logp, mvn(self.S, self.T, self.D), mvn(self.T, self.D), L) 85 | compare_logps(sess, mvn_logp, mvn(self.S, self.T, self.D), mvn(self.T, self.D), None) 86 | compare_logps(sess, mvn_logp, mvn(self.S, self.T, self.D), mvn(self.S, self.T, self.D), L) 87 | compare_logps(sess, mvn_logp, mvn(self.S, self.T, self.D), mvn(self.S, self.T, self.D), None) 88 | 89 | def test_diag_mvn_logp(self): 90 | L_diag = np.diag(self.prepare_L()) 91 | with self.test_context() as sess: 92 | 93 | compare_logps(sess, diag_mvn_logp, mvn(self.D), mvn(self.D), L_diag) 94 | compare_logps(sess, diag_mvn_logp, mvn(self.D), mvn(self.D), None) 95 | compare_logps(sess, diag_mvn_logp, mvn(self.T, self.D), mvn(self.D), L_diag) 96 | compare_logps(sess, diag_mvn_logp, mvn(self.T, self.D), mvn(self.D), None) 97 | compare_logps(sess, diag_mvn_logp, mvn(self.T, self.D), mvn(self.T, self.D), L_diag) 98 | compare_logps(sess, diag_mvn_logp, mvn(self.T, self.D), mvn(self.T, self.D), None) 99 | compare_logps(sess, diag_mvn_logp, mvn(self.S, self.T, self.D), mvn(self.D), L_diag) 100 | compare_logps(sess, diag_mvn_logp, mvn(self.S, self.T, self.D), mvn(self.D), None) 101 | compare_logps(sess, diag_mvn_logp, mvn(self.S, self.T, self.D), mvn(self.T, self.D), L_diag) 102 | compare_logps(sess, diag_mvn_logp, mvn(self.S, self.T, self.D), mvn(self.T, self.D), None) 103 | compare_logps(sess, diag_mvn_logp, mvn(self.S, self.T, self.D), mvn(self.S, self.T, self.D), L_diag) 104 | compare_logps(sess, diag_mvn_logp, mvn(self.S, self.T, self.D), mvn(self.S, self.T, self.D), None) 105 | 106 | 107 | if __name__ == '__main__': 108 | tf.test.main() -------------------------------------------------------------------------------- /tests/test_transition_KLs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Alessandro Davide Ialongo (@ialong) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
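# --- Editor's illustrative sketch; not part of the original test file. ---
# The tests below compare GPSSM._build_transition_KLs against Monte-Carlo
# estimates built with GPt.KL.KL_samples. They rely on the identity
# KL(q||p) = E_q[log q(x) - log p(x)] for Gaussians, which the small
# self-contained NumPy check below (hypothetical helper, assumed sample size)
# reproduces for two full-covariance Gaussians.
import numpy as np  # duplicated here only so the sketch is self-contained

def _mc_vs_closed_form_kl(D=3, n_samples=200000, seed=0):
    rng = np.random.RandomState(seed)
    m1, m2 = rng.randn(D), rng.randn(D)
    A, B = rng.randn(D, D), rng.randn(D, D)
    S1, S2 = A @ A.T + np.eye(D), B @ B.T + np.eye(D)
    S1inv, S2inv = np.linalg.inv(S1), np.linalg.inv(S2)
    # closed-form KL(N(m1,S1) || N(m2,S2))
    closed = 0.5 * (np.trace(S2inv @ S1) + (m2 - m1) @ S2inv @ (m2 - m1)
                    - D + np.log(np.linalg.det(S2) / np.linalg.det(S1)))
    # Monte-Carlo estimate: E_q[log q(x) - log p(x)] with x ~ N(m1, S1);
    # the constant -0.5*D*log(2*pi) cancels between log_q and log_p.
    x = rng.multivariate_normal(m1, S1, size=n_samples)
    log_q = -0.5 * np.sum((x - m1) @ S1inv * (x - m1), -1) - 0.5 * np.log(np.linalg.det(S1))
    log_p = -0.5 * np.sum((x - m2) @ S2inv * (x - m2), -1) - 0.5 * np.log(np.linalg.det(S2))
    return closed, np.mean(log_q - log_p)  # should agree up to Monte-Carlo error (~1e-2)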
14 | 15 | 16 | import numpy as np 17 | from numpy.testing import assert_allclose 18 | from numpy.random import randn as mvn 19 | from numpy.linalg import cholesky 20 | import tensorflow as tf 21 | import gpflow as gp 22 | from gpflow.test_util import GPflowTestCase 23 | from gpflow import mean_functions as mean_fns 24 | from gpflow.conditionals import conditional 25 | from GPt.gpssm import GPSSM 26 | from GPt.KL import KL_samples 27 | 28 | 29 | class TransitionKLsTest(GPflowTestCase): 30 | def __init__(self, *args, **kwargs): 31 | super().__init__(*args, **kwargs) 32 | self.seed = 0 33 | np.random.seed(self.seed) 34 | tf.set_random_seed(self.seed) 35 | self.T, self.D, self.E = 11, 3, 2 36 | self.n_samples, self.n_ind_pts = int(1e5), 4 37 | self.white = True 38 | 39 | def prepare(self): 40 | Y = np.random.randn(self.T, self.D) 41 | Q_diag = np.random.randn(self.E) ** 2. 42 | kern = [gp.kernels.RBF(self.E, ARD=True) for _ in range(self.E)] 43 | for k in kern: k.lengthscales = np.random.rand(self.E) 44 | for k in kern: k.variance = np.random.rand() 45 | Z = np.random.randn(self.E, self.n_ind_pts, self.E) 46 | mean_fn = mean_fns.Linear(np.random.randn(self.E, self.E), np.random.randn(self.E)) 47 | Umu = np.random.randn(self.E, self.n_ind_pts) 48 | Ucov_chol = np.random.randn(self.E, self.n_ind_pts, self.n_ind_pts) 49 | Ucov_chol = np.linalg.cholesky(np.matmul(Ucov_chol, np.transpose(Ucov_chol, [0, 2, 1]))) 50 | qx1_mu = np.random.randn(self.E) 51 | qx1_cov = np.random.randn(self.E, self.E) 52 | qx1_cov = qx1_cov @ qx1_cov.T 53 | As = np.random.randn(self.T-1, self.E) 54 | bs = np.random.randn(self.T-1, self.E) 55 | Ss = np.random.randn(self.T-1, self.E) ** 2. 56 | m = GPSSM(self.E, Y, inputs=None, emissions=None, px1_mu=None, px1_cov=None, 57 | kern=kern, Z=Z, n_ind_pts=None, mean_fn=mean_fn, 58 | Q_diag=Q_diag, Umu=Umu, Ucov_chol=Ucov_chol, 59 | qx1_mu=qx1_mu, qx1_cov=qx1_cov, As=As, bs=bs, Ss=Ss, n_samples=self.n_samples, seed=self.seed) 60 | _ = m.compute_log_likelihood() 61 | return m 62 | 63 | def test_transition_KLs_MC(self): 64 | with self.test_context() as sess: 65 | shape = [self.T - 1, self.n_samples, self.E] 66 | X_samples = tf.placeholder(gp.settings.float_type, shape=shape) 67 | feed_dict = {X_samples: np.random.randn(*shape)} 68 | 69 | m = self.prepare() 70 | f_mus, f_vars = conditional(tf.reshape(X_samples, [-1, self.E]), 71 | m.Z, m.kern, m.Umu.constrained_tensor, white=self.white, 72 | q_sqrt=m.Ucov_chol.constrained_tensor) 73 | f_mus += m.mean_fn(tf.reshape(X_samples, [-1, self.E])) 74 | 75 | gpssm_KLs = m._build_transition_KLs(tf.reshape(f_mus, [m.T - 1, m.n_samples, m.latent_dim]), 76 | tf.reshape(f_vars, [m.T - 1, m.n_samples, m.latent_dim])) 77 | 78 | f_samples = f_mus + tf.sqrt(f_vars) * tf.random_normal( 79 | [(self.T - 1) * self.n_samples, self.E], dtype=gp.settings.float_type, seed=self.seed) 80 | 81 | q_mus = m.As.constrained_tensor[:, None, :] * tf.reshape(f_samples, shape) \ 82 | + m.bs.constrained_tensor[:, None, :] 83 | q_mus = tf.reshape(q_mus, [-1, self.E]) 84 | q_covs = tf.reshape(tf.tile( 85 | m.S_chols.constrained_tensor[:, None, :], [1, self.n_samples, 1]), [-1, self.E]) 86 | mc_KLs = KL_samples(q_mus - f_samples, Q_chol=q_covs, P_chol=m.Q_sqrt.constrained_tensor) 87 | mc_KLs = tf.reduce_mean(tf.reshape(mc_KLs, shape[:-1]), -1) 88 | 89 | assert_allclose(*sess.run([gpssm_KLs, mc_KLs], feed_dict=feed_dict), rtol=0.5*1e-2) 90 | 91 | def test_transition_KLs_extra_trace(self): 92 | with self.test_context() as sess: 93 | shape = [self.T - 1, self.n_samples, 
self.E] 94 | X_samples = tf.placeholder(gp.settings.float_type, shape=shape) 95 | feed_dict = {X_samples: np.random.randn(*shape)} 96 | 97 | m = self.prepare() 98 | f_mus, f_vars = conditional(tf.reshape(X_samples, [-1, self.E]), 99 | m.Z, m.kern, m.Umu.constrained_tensor, white=self.white, 100 | q_sqrt=m.Ucov_chol.constrained_tensor) 101 | f_mus += m.mean_fn(tf.reshape(X_samples, [-1, self.E])) 102 | 103 | gpssm_KLs = m._build_transition_KLs(tf.reshape(f_mus, [m.T - 1, m.n_samples, m.latent_dim]), 104 | tf.reshape(f_vars, [m.T - 1, m.n_samples, m.latent_dim])) 105 | 106 | q_mus = m.As.constrained_tensor[:, None, :] * tf.reshape(f_mus, shape) \ 107 | + m.bs.constrained_tensor[:, None, :] 108 | q_mus = tf.reshape(q_mus, [-1, self.E]) 109 | q_covs = tf.reshape(tf.tile( 110 | m.S_chols.constrained_tensor[:, None, :], [1, self.n_samples, 1]), [-1, self.E]) 111 | trace_KLs = KL_samples(q_mus - f_mus, Q_chol=q_covs, P_chol=m.Q_sqrt.constrained_tensor) 112 | trace_KLs = tf.reduce_mean(tf.reshape(trace_KLs, shape[:-1]), -1) 113 | 114 | trace_KLs += 0.5 * tf.reduce_mean(tf.reduce_sum( 115 | (tf.square(m.As.constrained_tensor - 1.)[:, None, :] * tf.reshape(f_vars, shape)) 116 | / tf.square(m.Q_sqrt.constrained_tensor), -1), -1) 117 | 118 | assert_allclose(*sess.run([gpssm_KLs, trace_KLs], feed_dict=feed_dict)) 119 | 120 | def test_factorized_transition_KLs(self): 121 | def KL_sampled_mu_and_Q_diag_P(mu_diff, Q_chol, P_chol): 122 | """ 123 | :param mu_diff: NxSxD 124 | :param Q_chol: NxSxD 125 | :param P_chol: D 126 | :return: N 127 | """ 128 | D = tf.shape(mu_diff)[-1] 129 | assert mu_diff.shape.ndims is not None 130 | assert Q_chol.shape.ndims is not None 131 | assert P_chol.shape.ndims is not None 132 | 133 | mahalanobis = mu_diff / P_chol 134 | mahalanobis = tf.reduce_sum(tf.square(mahalanobis), -1) 135 | mahalanobis = tf.reduce_mean(mahalanobis, -1) 136 | 137 | trace = Q_chol / P_chol 138 | trace = tf.reduce_sum(tf.square(trace), -1) 139 | trace = tf.reduce_mean(trace, -1) 140 | 141 | constant = tf.cast(D, dtype=mu_diff.dtype) 142 | log_det_P = 2. * tf.reduce_sum(tf.log(tf.abs(P_chol))) 143 | log_det_Q = 2. 
* tf.reduce_mean(tf.reduce_sum(tf.log(tf.abs(Q_chol)), -1), -1) 144 | double_KL = trace + mahalanobis - constant + log_det_P - log_det_Q 145 | return 0.5 * double_KL 146 | 147 | with self.test_context() as sess: 148 | m = self.prepare() 149 | with gp.params_as_tensors_for(m): 150 | _, f_mus, f_vars, xcov_chols = sess.run(m._build_linear_time_q_sample( 151 | return_f_moments=True, return_x_cov_chols=True, sample_f=False, sample_u=False)) 152 | 153 | gpssm_KLs = sess.run(m._build_transition_KLs(tf.constant(f_mus), tf.constant(f_vars))) 154 | 155 | diff_term = tf.reduce_sum(tf.reduce_mean(tf.constant(f_vars), -2) * m.As / tf.square(m.Q_sqrt), -1) 156 | diff_term += tf.reduce_sum(tf.log(tf.abs(m.S_chols)), 1) 157 | diff_term -= tf.reduce_sum(tf.reduce_mean(tf.log(tf.abs(tf.constant(xcov_chols))), -2), -1) 158 | 159 | gpssm_KLs += sess.run(diff_term) 160 | 161 | gpssm_factorized_KLs = sess.run(m._build_factorized_transition_KLs( 162 | tf.constant(f_mus), tf.constant(f_vars), tf.constant(xcov_chols))) 163 | 164 | assert_allclose(gpssm_KLs, gpssm_factorized_KLs) 165 | 166 | gpssm_factorized_KLs_2 = sess.run(KL_sampled_mu_and_Q_diag_P( 167 | m.As[:, None, :] * f_mus + m.bs[:, None, :] - f_mus, 168 | tf.constant(xcov_chols), 169 | m.Q_sqrt)) 170 | gpssm_factorized_KLs_2 += 0.5 * np.mean(np.sum(f_vars / np.square(sess.run(m.Q_sqrt)), -1), -1) 171 | 172 | assert_allclose(gpssm_factorized_KLs, gpssm_factorized_KLs_2) 173 | 174 | 175 | if __name__ == '__main__': 176 | tf.test.main() 177 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /tests/test_KL.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Alessandro Davide Ialongo (@ialong) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np 17 | from numpy.testing import assert_allclose 18 | from numpy.random import randn as mvn 19 | from numpy.linalg import cholesky 20 | import tensorflow as tf 21 | import gpflow as gp 22 | from gpflow.test_util import GPflowTestCase 23 | from gpflow.kullback_leiblers import gauss_kl 24 | from GPt.KL import KL, KL_samples 25 | 26 | 27 | FLOAT_TYPE = gp.settings.float_type 28 | 29 | 30 | def choleskify(mats): 31 | ret = [] 32 | for m in mats: 33 | if m.ndim == 1: 34 | m = np.abs(m) 35 | else: 36 | m = cholesky(m @ (m.T if m.ndim == 2 else np.transpose(m, [0, 2, 1]))) 37 | ret.append(m) 38 | return ret 39 | 40 | 41 | def compare_KLs(sess, feed_dict, mu, Q_chol, P_chols): 42 | mu_gpflow = tf.transpose(mu) if mu.shape.ndims == 2 else mu[:, None] 43 | Q_chol_gpflow = Q_chol if Q_chol.shape.ndims == 3 else Q_chol[None, ...] 44 | 45 | KL_gpflow = sess.run(gauss_kl(q_mu=mu_gpflow, q_sqrt=Q_chol_gpflow, K=None), feed_dict=feed_dict) 46 | KL_gpt = sess.run(KL(mu_diff=mu, Q_chol=Q_chol, P_chol=None, P=None), feed_dict=feed_dict) 47 | assert_allclose(KL_gpflow, KL_gpt) 48 | 49 | for P_chol in P_chols: 50 | P_ndims = P_chol.shape.ndims 51 | P = tf.square(P_chol) if P_ndims == 1 else tf.matmul(P_chol, P_chol, transpose_b=True) 52 | 53 | KL_gpflow = sess.run(gauss_kl(q_mu=mu_gpflow, q_sqrt=Q_chol_gpflow, 54 | K=tf.diag(P) if P_ndims == 1 else P), feed_dict=feed_dict) 55 | 56 | KL_gpt = sess.run(KL(mu_diff=mu, Q_chol=Q_chol, P_chol=P_chol, P=None), feed_dict=feed_dict) 57 | assert_allclose(KL_gpflow, KL_gpt) 58 | KL_gpt = sess.run(KL(mu_diff=mu, Q_chol=Q_chol, P_chol=None, P=P), feed_dict=feed_dict) 59 | assert_allclose(KL_gpflow, KL_gpt) 60 | KL_gpt = sess.run(KL(mu_diff=mu, Q_chol=Q_chol, P_chol=P_chol, P=P), feed_dict=feed_dict) 61 | assert_allclose(KL_gpflow, KL_gpt) 62 | 63 | 64 | class KLTest(GPflowTestCase): 65 | def __init__(self, *args, **kwargs): 66 | super().__init__(*args, **kwargs) 67 | np.random.seed(0) 68 | tf.set_random_seed(0) 69 | self.D, self.M, self.n_samples = 11, 53, 7 70 | 71 | def prepare(self): 72 | mus = [mvn(self.D), mvn(self.M, self.D), mvn(self.D, self.M), mvn(self.M, self.n_samples, self.D)] 73 | Q_chols = [mvn(self.D, self.D), mvn(self.D, self.M, self.M), mvn(self.M, self.D, self.D)] 74 | Q_chols = choleskify(Q_chols) 75 | Q_chols.append(np.abs(mvn(self.M, self.D))) 76 | P_chols = [mvn(self.D), mvn(self.D, self.D), mvn(self.M), mvn(self.M, self.M), mvn(self.D, self.M, self.M)] 77 | P_chols = choleskify(P_chols) 78 | P_chols.append(np.abs(mvn(self.D, self.M))) 79 | mus = {a.shape: a for a in mus} 80 | Q_chols = {a.shape: a for a in Q_chols} 81 | P_chols = {a.shape: a for a in P_chols} 82 | return mus, Q_chols, P_chols 83 | 84 | def get_feed_dict(self, mus_tf, Qs_tf, Ps_tf): 85 | shape = lambda ph: tuple(np.array(ph.shape, dtype=int)) 86 | mus, Qs, Ps = self.prepare() 87 | feed_dict = dict() 88 | for mu_tf in mus_tf: 89 | feed_dict[mu_tf] = mus[shape(mu_tf)] 90 | for Q_tf in Qs_tf: 91 | feed_dict[Q_tf] = Qs[shape(Q_tf)] 92 | for P_tf in Ps_tf: 93 | feed_dict[P_tf] = Ps[shape(P_tf)] 94 | return feed_dict 95 
| 96 | def test_KL_mu_D_Q_DxD(self): 97 | with self.test_context() as sess: 98 | mu = tf.placeholder(FLOAT_TYPE, shape=(self.D,)) 99 | Q_chol = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.D)) 100 | P_chol_diag = tf.placeholder(FLOAT_TYPE, shape=(self.D)) 101 | P_chol = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.D)) 102 | 103 | feed_dict = self.get_feed_dict([mu], [Q_chol], [P_chol_diag, P_chol]) 104 | 105 | compare_KLs(sess, feed_dict, mu, Q_chol, [P_chol_diag, P_chol]) 106 | 107 | def test_KL_mu_MxD_Q_DxMxM(self): 108 | with self.test_context() as sess: 109 | mu = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.M)) 110 | Q_chol = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.M, self.M)) 111 | P_chol_1D = tf.placeholder(FLOAT_TYPE, shape=(self.M)) 112 | P_chol_2D = tf.placeholder(FLOAT_TYPE, shape=(self.M, self.M)) 113 | P_chol_3D = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.M, self.M)) 114 | 115 | feed_dict = self.get_feed_dict([mu], [Q_chol], [P_chol_1D, P_chol_2D, P_chol_3D]) 116 | 117 | compare_KLs(sess, feed_dict, mu, Q_chol, [P_chol_1D, P_chol_2D, P_chol_3D]) 118 | 119 | def test_KL_samples_mu_2D(self): 120 | with self.test_context() as sess: 121 | mu = tf.placeholder(FLOAT_TYPE, shape=(self.M, self.D)) 122 | Q_chol_2D = tf.placeholder(FLOAT_TYPE, shape=(self.M, self.D)) 123 | Q_chol_3D = tf.placeholder(FLOAT_TYPE, shape=(self.M, self.D, self.D)) 124 | P_chol_1D = tf.placeholder(FLOAT_TYPE, shape=(self.D)) 125 | P_chol_2D = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.D)) 126 | 127 | feed_dict = self.get_feed_dict([mu], [Q_chol_2D, Q_chol_3D], [P_chol_1D, P_chol_2D]) 128 | 129 | KL_s_1 = sess.run(tf.reduce_sum(KL_samples(mu, Q_chol_2D, None)), feed_dict) 130 | KL_s_2 = sess.run(tf.reduce_sum(KL_samples(mu, Q_chol_2D, P_chol_1D)), feed_dict) 131 | KL_s_3 = sess.run(tf.reduce_sum(KL_samples(mu, Q_chol_2D, P_chol_2D)), feed_dict) 132 | KL_s_4 = sess.run(tf.reduce_sum(KL_samples(mu, Q_chol_3D, None)), feed_dict) 133 | KL_s_5 = sess.run(tf.reduce_sum(KL_samples(mu, Q_chol_3D, P_chol_1D)), feed_dict) 134 | KL_s_6 = sess.run(tf.reduce_sum(KL_samples(mu, Q_chol_3D, P_chol_2D)), feed_dict) 135 | 136 | KL_1 = sess.run(KL(mu, tf.matrix_diag(Q_chol_2D), P_chol=None), feed_dict) 137 | KL_2 = sess.run(KL(mu, tf.matrix_diag(Q_chol_2D), P_chol=tf.diag(P_chol_1D)), feed_dict) 138 | KL_3 = sess.run(KL(mu, tf.matrix_diag(Q_chol_2D), P_chol=P_chol_2D), feed_dict) 139 | KL_4 = sess.run(KL(mu, Q_chol_3D, P_chol=None), feed_dict) 140 | KL_5 = sess.run(KL(mu, Q_chol_3D, P_chol=tf.diag(P_chol_1D)), feed_dict) 141 | KL_6 = sess.run(KL(mu, Q_chol_3D, P_chol=P_chol_2D), feed_dict) 142 | 143 | assert_allclose(KL_s_1, KL_1) 144 | assert_allclose(KL_s_2, KL_2) 145 | assert_allclose(KL_s_3, KL_3) 146 | assert_allclose(KL_s_4, KL_4) 147 | assert_allclose(KL_s_5, KL_5) 148 | assert_allclose(KL_s_6, KL_6) 149 | 150 | def test_KL_samples_mu_3D(self): 151 | with self.test_context() as sess: 152 | mu_3D = tf.placeholder(FLOAT_TYPE, shape=(self.M, self.n_samples, self.D)) 153 | Q_chol_2D = tf.placeholder(FLOAT_TYPE, shape=(self.M, self.D)) 154 | Q_chol_3D = tf.placeholder(FLOAT_TYPE, shape=(self.M, self.D, self.D)) 155 | P_chol_1D = tf.placeholder(FLOAT_TYPE, shape=(self.D)) 156 | P_chol_2D = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.D)) 157 | 158 | feed_dict = self.get_feed_dict([mu_3D], [Q_chol_2D, Q_chol_3D], [P_chol_1D, P_chol_2D]) 159 | 160 | KL_s_1 = sess.run(KL_samples(mu_3D, Q_chol_2D, None), feed_dict) 161 | KL_s_2 = sess.run(KL_samples(mu_3D, Q_chol_2D, P_chol_1D), feed_dict) 162 | KL_s_3 = 
sess.run(KL_samples(mu_3D, Q_chol_2D, P_chol_2D), feed_dict) 163 | KL_s_4 = sess.run(KL_samples(mu_3D, Q_chol_3D, None), feed_dict) 164 | KL_s_5 = sess.run(KL_samples(mu_3D, Q_chol_3D, P_chol_1D), feed_dict) 165 | KL_s_6 = sess.run(KL_samples(mu_3D, Q_chol_3D, P_chol_2D), feed_dict) 166 | 167 | KL_mu_only_arg = lambda Q_chol, P_chol: lambda mu: KL_samples(mu, Q_chol, P_chol=P_chol) 168 | map_schema = lambda Q_chol, P_chol: \ 169 | tf.reduce_mean(tf.map_fn(KL_mu_only_arg(Q_chol, P_chol), tf.transpose(mu_3D, [1,0,2])), 0) 170 | 171 | KL_map_1 = sess.run(map_schema(Q_chol_2D, None), feed_dict) 172 | KL_map_2 = sess.run(map_schema(Q_chol_2D, P_chol_1D), feed_dict) 173 | KL_map_3 = sess.run(map_schema(Q_chol_2D, P_chol_2D), feed_dict) 174 | KL_map_4 = sess.run(map_schema(Q_chol_3D, None), feed_dict) 175 | KL_map_5 = sess.run(map_schema(Q_chol_3D, P_chol_1D), feed_dict) 176 | KL_map_6 = sess.run(map_schema(Q_chol_3D, P_chol_2D), feed_dict) 177 | 178 | assert_allclose(KL_s_1, KL_map_1) 179 | assert_allclose(KL_s_2, KL_map_2) 180 | assert_allclose(KL_s_3, KL_map_3) 181 | assert_allclose(KL_s_4, KL_map_4) 182 | assert_allclose(KL_s_5, KL_map_5) 183 | assert_allclose(KL_s_6, KL_map_6) 184 | 185 | def test_whitening(self): 186 | with self.test_context() as sess: 187 | mu = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.M)) 188 | Q_chol = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.M, self.M)) 189 | P_chol = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.M, self.M)) 190 | 191 | feed_dict = self.get_feed_dict([mu], [Q_chol], [P_chol]) 192 | 193 | KL_black = sess.run(KL(mu, Q_chol, P_chol=P_chol), feed_dict) 194 | KL_white = sess.run(KL(tf.matrix_triangular_solve(P_chol, mu[:, :, None], lower=True)[..., 0], 195 | tf.matrix_triangular_solve(P_chol, Q_chol, lower=True)), feed_dict) 196 | 197 | assert_allclose(KL_black, KL_white) 198 | 199 | def test_KL_x1_multiseq(self): 200 | with self.test_context() as sess: 201 | mu = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.M)) 202 | Q_chol = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.M, self.M)) 203 | P_chol_1D = tf.placeholder(FLOAT_TYPE, shape=(self.M)) 204 | P_chol_2D = tf.placeholder(FLOAT_TYPE, shape=(self.M, self.M)) 205 | P_chol_3D_diag = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.M)) 206 | P_chol_3D = tf.placeholder(FLOAT_TYPE, shape=(self.D, self.M, self.M)) 207 | 208 | feed_dict = self.get_feed_dict([mu], [Q_chol], [P_chol_1D, P_chol_2D, P_chol_3D_diag, P_chol_3D]) 209 | 210 | KL_1 = sess.run(KL(mu, Q_chol, P_chol=None), feed_dict) 211 | KL_2 = sess.run(KL(mu, Q_chol, P_chol=P_chol_1D), feed_dict) 212 | KL_3 = sess.run(KL(mu, Q_chol, P_chol=P_chol_2D), feed_dict) 213 | KL_4 = sess.run(KL(mu, Q_chol, P_chol=tf.matrix_diag(P_chol_3D_diag)), feed_dict) 214 | KL_5 = sess.run(KL(mu, Q_chol, P_chol=P_chol_3D), feed_dict) 215 | 216 | KL_map_1 = sess.run(tf.map_fn(lambda a: KL(a[0], a[1], P_chol=None), 217 | (mu, Q_chol), (FLOAT_TYPE)), feed_dict) 218 | KL_map_2 = sess.run(tf.map_fn(lambda a: KL(a[0], a[1], P_chol=P_chol_1D), 219 | (mu, Q_chol), (FLOAT_TYPE)), feed_dict) 220 | KL_map_3 = sess.run(tf.map_fn(lambda a: KL(a[0], a[1], P_chol=P_chol_2D), 221 | (mu, Q_chol), (FLOAT_TYPE)), feed_dict) 222 | KL_map_4 = sess.run(tf.map_fn(lambda a: KL(a[0], a[1], P_chol=a[2]), 223 | (mu, Q_chol, P_chol_3D_diag), (FLOAT_TYPE)), feed_dict) 224 | KL_map_5 = sess.run(tf.map_fn(lambda a: KL(a[0], a[1], P_chol=a[2]), 225 | (mu, Q_chol, P_chol_3D), (FLOAT_TYPE)), feed_dict) 226 | 227 | assert_allclose(KL_1, KL_map_1.sum()) 228 | assert_allclose(KL_2, 
KL_map_2.sum()) 229 | assert_allclose(KL_3, KL_map_3.sum()) 230 | assert_allclose(KL_4, KL_map_4.sum()) 231 | assert_allclose(KL_5, KL_map_5.sum()) 232 | 233 | 234 | if __name__ == '__main__': 235 | tf.test.main() -------------------------------------------------------------------------------- /GPt/gpssm_multiseq.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Alessandro Davide Ialongo (@ialong) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | import gpflow as gp 19 | from gpflow import Param, ParamList, params_as_tensors 20 | from gpflow import transforms as gtf 21 | from .KL import KL, KL_samples 22 | from .gpssm import GPSSM 23 | 24 | 25 | class GPSSM_MultipleSequences(GPSSM): 26 | """Equivalent to GPSSM but for data which comes as many (potentially variable-length) independent sequences.""" 27 | def __init__(self, 28 | latent_dim, 29 | Y, 30 | inputs=None, 31 | emissions=None, 32 | px1_mu=None, px1_cov=None, 33 | kern=None, 34 | Z=None, n_ind_pts=100, 35 | mean_fn=None, 36 | Q_diag=None, 37 | Umu=None, Ucov_chol=None, 38 | qx1_mu=None, qx1_cov=None, 39 | As=None, bs=None, Ss=None, 40 | n_samples=100, 41 | batch_size=None, 42 | chunking=False, 43 | seed=None, 44 | parallel_iterations=10, 45 | jitter=gp.settings.numerics.jitter_level, 46 | name=None): 47 | 48 | super().__init__(latent_dim, Y[0], inputs=None if inputs is None else inputs[0], emissions=emissions, 49 | px1_mu=px1_mu, px1_cov=None, kern=kern, Z=Z, n_ind_pts=n_ind_pts, 50 | mean_fn=mean_fn, Q_diag=Q_diag, Umu=Umu, Ucov_chol=Ucov_chol, 51 | qx1_mu=qx1_mu, qx1_cov=None, As=None, bs=None, Ss=False if Ss is False else None, 52 | n_samples=n_samples, seed=seed, parallel_iterations=parallel_iterations, 53 | jitter=jitter, name=name) 54 | 55 | self.T = [Y_s.shape[0] for Y_s in Y] 56 | self.T_tf = tf.constant(self.T, dtype=gp.settings.int_type) 57 | self.max_T = max(self.T) 58 | self.sum_T = float(sum(self.T)) 59 | self.n_seq = len(self.T) 60 | self.batch_size = batch_size 61 | self.chunking = chunking 62 | 63 | if self.batch_size is None: 64 | self.Y = ParamList(Y, trainable=False) 65 | else: 66 | _Y = np.stack([np.concatenate([Ys, np.zeros((self.max_T - len(Ys), self.obs_dim))]) for Ys in Y]) 67 | self.Y = Param(_Y, trainable=False) 68 | 69 | if inputs is not None: 70 | if self.batch_size is None: 71 | self.inputs = ParamList(inputs, trainable=False) 72 | else: 73 | desired_length = self.max_T if self.chunking else self.max_T - 1 74 | _inputs = [np.concatenate([inputs[s], np.zeros((desired_length - len(inputs[s]), self.input_dim))]) 75 | for s in range(self.n_seq)] # pad the inputs 76 | self.inputs = Param(_inputs, trainable=False) 77 | 78 | if qx1_mu is None: 79 | self.qx1_mu = Param(np.zeros((self.n_seq, self.latent_dim))) 80 | 81 | self.qx1_cov_chol = Param(np.tile(np.eye(self.latent_dim)[None, ...], [self.n_seq, 1, 1]) if qx1_cov is None 82 | else 
np.linalg.cholesky(qx1_cov), 83 | transform=gtf.LowerTriangular(self.latent_dim, num_matrices=self.n_seq)) 84 | 85 | 86 | _As = [np.ones((T_s - 1, self.latent_dim)) for T_s in self.T] if As is None else As 87 | _bs = [np.zeros((T_s - 1, self.latent_dim)) for T_s in self.T] if bs is None else bs 88 | if Ss is not False: 89 | _S_chols = [np.tile(self.Q_sqrt.value.copy()[None, ...], [T_s - 1, 1]) for T_s in self.T] if Ss is None \ 90 | else [np.sqrt(S) if S.ndim == 2 else np.linalg.cholesky(S) for S in Ss] 91 | 92 | if self.batch_size is None: 93 | self.As = ParamList(_As) 94 | self.bs = ParamList(_bs) 95 | if Ss is not False: 96 | self.S_chols = ParamList([Param(Sc, transform=gtf.positive if Sc.ndim == 2 else 97 | gtf.LowerTriangular(self.latent_dim, num_matrices=Sc.shape[0])) for Sc in _S_chols]) 98 | else: 99 | _As = np.stack([np.concatenate([_A, np.zeros((self.max_T - len(_A) - 1, *_A.shape[1:]))]) for _A in _As]) 100 | _bs = np.stack([np.concatenate([_b, np.zeros((self.max_T - len(_b) - 1, self.latent_dim))]) for _b in _bs]) 101 | self.As = Param(_As) 102 | self.bs = Param(_bs) 103 | if Ss is not False: 104 | _S_chols = [np.concatenate([_S, np.zeros((self.max_T - len(_S) - 1, *_S.shape[1:]))]) 105 | for _S in _S_chols] 106 | _S_chols = np.stack(_S_chols) 107 | self.S_chols = Param(_S_chols, transform=gtf.positive if _S_chols.ndim == 3 else \ 108 | gtf.LowerTriangular(self.latent_dim, num_matrices=(self.n_seq, self.max_T - 1))) 109 | 110 | self.multi_diag_px1_cov = False 111 | if isinstance(px1_cov, list): # different prior for each sequence 112 | _x1_cov = np.stack(px1_cov) 113 | _x1_cov = np.sqrt(_x1_cov) if _x1_cov.ndim == 2 else np.linalg.cholesky(_x1_cov) 114 | _transform = None if _x1_cov.ndim == 2 else gtf.LowerTriangular(self.latent_dim, num_matrices=self.n_seq) 115 | self.multi_diag_px1_cov = _x1_cov.ndim == 2 116 | elif isinstance(px1_cov, np.ndarray): # same prior for each sequence 117 | assert px1_cov.ndim < 3 118 | _x1_cov = np.sqrt(px1_cov) if px1_cov.ndim == 1 else np.linalg.cholesky(px1_cov) 119 | _transform = None if px1_cov.ndim == 1 else gtf.LowerTriangular(self.latent_dim, squeeze=True) 120 | 121 | self.px1_cov_chol = None if px1_cov is None else Param(_x1_cov, trainable=False, transform=_transform) 122 | 123 | if self.chunking: 124 | px1_mu_check = len(self.px1_mu.shape) == 1 125 | px1_cov_check_1 = not self.multi_diag_px1_cov 126 | px1_cov_check_2 = self.px1_cov_chol is None or len(self.px1_cov_chol.shape) < 3 127 | assert px1_mu_check and px1_cov_check_1 and px1_cov_check_2, \ 128 | 'Only one prior over x1 allowed for chunking' 129 | 130 | @params_as_tensors 131 | def _build_likelihood(self): 132 | batch_indices = None if self.batch_size is None else \ 133 | tf.random_shuffle(tf.range(self.n_seq), seed=self.seed)[:self.batch_size] 134 | 135 | X_samples, fs = self._build_sample(batch_indices=batch_indices) 136 | emissions = self._build_emissions(X_samples, batch_indices=batch_indices) 137 | KL_X = self._build_KL_X(fs, batch_indices=batch_indices) 138 | KL_U = self._build_KL_U() 139 | KL_x1 = self._build_KL_x1(batch_indices=batch_indices) 140 | return emissions - KL_X - KL_U - KL_x1 141 | 142 | @params_as_tensors 143 | def _build_sample(self, batch_indices=None): 144 | Lm = tf.cholesky(self.Kzz) 145 | 146 | X_samples, fs = [], [] 147 | if self.chunking: f_stitch = [] 148 | for s in range(self.n_seq if batch_indices is None else self.batch_size): 149 | b_s = s if batch_indices is None else batch_indices[s] 150 | T_s = self.T[s] if batch_indices is None else 
self.T_tf[b_s] 151 | _A, _b, _S_chol = self.As[b_s], self.bs[b_s], self.S_chols[b_s] 152 | 153 | if self.chunking: 154 | T_s, _A, _b, _S_chol = tf.cond( 155 | tf.equal(b_s, self.n_seq - 1), 156 | lambda: (T_s, _A, _b, _S_chol), 157 | lambda: 158 | (T_s + 1, 159 | tf.concat([_A, tf.zeros((1, *_A.shape[1:]), dtype=gp.settings.float_type)], 0), 160 | tf.concat([_b, tf.zeros((1, self.latent_dim), dtype=gp.settings.float_type)], 0), 161 | tf.concat([_S_chol, tf.ones((1, *_S_chol.shape[1:]), dtype=gp.settings.float_type)], 0)) 162 | ) 163 | 164 | X_sample, *f = self.sample_fn(T=T_s, inputs=None if self.inputs is None else self.inputs[b_s], 165 | qx1_mu=self.qx1_mu[b_s], qx1_cov_chol=self.qx1_cov_chol[b_s], 166 | As=_A, bs=_b, S_chols=_S_chol, Lm=Lm, 167 | **self.sample_kwargs) 168 | if self.chunking: 169 | X_sample = tf.cond(tf.equal(b_s, self.n_seq - 1), lambda: X_sample, lambda: X_sample[:-1]) 170 | f_stitch.append([_f[-1] for _f in f]) 171 | f = [tf.cond(tf.equal(b_s, self.n_seq - 1), lambda: _f, lambda: _f[:-1]) for _f in f] 172 | 173 | X_samples.append(X_sample) 174 | fs.append(f) 175 | 176 | if self.chunking: fs = [fs, f_stitch] 177 | return X_samples, fs 178 | 179 | @params_as_tensors 180 | def _build_emissions(self, X_samples, batch_indices=None): 181 | emissions = 0. 182 | for s in range(self.n_seq if batch_indices is None else self.batch_size): 183 | b_s = s if batch_indices is None else batch_indices[s] 184 | _Y = self.Y[s] if batch_indices is None else self.Y[b_s, :self.T_tf[b_s]] 185 | 186 | emissions += tf.reduce_sum(tf.reduce_mean( 187 | self.emissions.logp(X_samples[s], _Y[:, None, :]), -1)) 188 | 189 | if batch_indices is not None: 190 | sum_T_minibatch = tf.cast(tf.reduce_sum(tf.gather(self.T_tf, batch_indices)), gp.settings.float_type) 191 | emissions *= self.sum_T / sum_T_minibatch 192 | return emissions 193 | 194 | @params_as_tensors 195 | def _build_KL_X(self, fs, batch_indices=None): 196 | if self.chunking: fs, f_stitch = fs 197 | 198 | KL_X = 0. 199 | for s in range(self.n_seq if batch_indices is None else self.batch_size): 200 | b_s = s if batch_indices is None else batch_indices[s] 201 | T_s = self.T_tf[b_s] 202 | _A = self.As[s] if batch_indices is None else self.As[b_s, :T_s - 1] 203 | _b = self.bs[s] if batch_indices is None else self.bs[b_s, :T_s - 1] 204 | _S_chol = self.S_chols[s] if batch_indices is None else self.S_chols[b_s, :T_s - 1] 205 | 206 | KL_X += tf.reduce_sum(self.KL_fn(*fs[s], As=_A, bs=_b, S_chols=_S_chol)) 207 | 208 | if self.chunking: 209 | def KL_stitch(f_stitch_s, b_s): 210 | kl = KL_samples((f_stitch_s[0] - self.qx1_mu[b_s + 1])[None, ...], 211 | self.qx1_cov_chol[b_s + 1][None, ...], self.Q_sqrt)[0] 212 | if len(f_stitch_s) > 1: 213 | kl += 0.5 * tf.reduce_mean(tf.reduce_sum(f_stitch_s[1] / tf.square(self.Q_sqrt), -1)) 214 | return kl 215 | 216 | if isinstance(b_s, int): 217 | if s < self.n_seq - 1: 218 | KL_X += KL_stitch(f_stitch[s], b_s) 219 | else: 220 | KL_X += tf.cond(tf.equal(b_s, self.n_seq - 1), 221 | lambda: tf.constant(0., dtype=gp.settings.float_type), 222 | lambda: KL_stitch(f_stitch[s], b_s)) 223 | 224 | if batch_indices is not None: 225 | sum_T_minibatch = tf.cast(tf.reduce_sum(tf.gather(self.T_tf, batch_indices)), gp.settings.float_type) 226 | if self.chunking: 227 | KL_X *= (self.sum_T - 1.) 
/ tf.cond(tf.reduce_any(tf.equal(batch_indices, self.n_seq - 1)), 228 | lambda: sum_T_minibatch - 1., lambda: sum_T_minibatch) 229 | else: 230 | KL_X *= (self.sum_T - self.n_seq) / (sum_T_minibatch - self.batch_size) 231 | return KL_X 232 | 233 | @params_as_tensors 234 | def _build_KL_x1(self, batch_indices=None): 235 | """ 236 | qx1_mu: SxE 237 | qx1_cov_chol: SxExE 238 | px1_mu: E or SxE 239 | px1_cov_chol: None or E or ExE or SxE or SxExE 240 | """ 241 | _P_chol = self.px1_cov_chol if not self.multi_diag_px1_cov else tf.matrix_diag(self.px1_cov_chol) 242 | if self.chunking: 243 | _px1_mu = self.px1_mu 244 | _qx1_mu = self.qx1_mu[0] 245 | _qx1_cov_chol = self.qx1_cov_chol[0] 246 | elif batch_indices is None: 247 | _px1_mu = self.px1_mu 248 | _qx1_mu = self.qx1_mu 249 | _qx1_cov_chol = self.qx1_cov_chol 250 | else: 251 | _px1_mu = tf.gather(self.px1_mu, batch_indices) if self.px1_mu.shape.ndims == 2 else self.px1_mu 252 | _qx1_mu = tf.gather(self.qx1_mu, batch_indices) 253 | _qx1_cov_chol = tf.gather(self.qx1_cov_chol, batch_indices) 254 | _P_chol = None if self.px1_cov_chol is None else \ 255 | (_P_chol if _P_chol.shape.ndims < 3 else tf.gather(_P_chol, batch_indices)) 256 | 257 | KL_x1 = KL(_qx1_mu - _px1_mu, _qx1_cov_chol, P_chol=_P_chol) 258 | 259 | if batch_indices is not None and not self.chunking: 260 | KL_x1 *= float(self.n_seq) / float(self.batch_size) 261 | 262 | return KL_x1 263 | -------------------------------------------------------------------------------- /GPt/transitions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Alessandro Davide Ialongo (@ialong) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
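# --- Editor's note (illustrative; not part of the original module). ---
# This module provides the transition models p(x_{t+1} | x_t) used by the
# state-space models in ssm.py / gpssm.py:
#   * BaseGaussianTransitions: x_{t+1} = g(x_t, input_t) + eps, eps ~ N(0, Q),
#     with the mean function g supplied by subclasses via conditional_mean().
#   * GPTransitions: a sparse variational (whitened) GP prior over g, with
#     inducing points Z, inducing-output posterior (Umu, Ucov_chol) and a
#     per-dimension process-noise scale Q_sqrt.
# A rough usage sketch, assuming the object is built inside a GPflow graph
# (the variable names here are hypothetical):
#
#   trans = GPTransitions(dim=2, n_ind_pts=20)
#   mu, var = trans.conditional(X)   # X: [N, 2] states -> moments of x_{t+1}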
14 | 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | import gpflow as gp 19 | from gpflow import Param, params_as_tensors 20 | from gpflow import settings as gps 21 | from gpflow import mean_functions as mean_fns 22 | from gpflow.conditionals import conditional, Kuu 23 | from gpflow import transforms as gtf 24 | import gpflow.multioutput.kernels as mk 25 | import gpflow.multioutput.features as mf 26 | from gpflow.logdensities import mvn_logp, diag_mvn_logp 27 | from .KL import KL 28 | 29 | 30 | class BaseGaussianTransitions(gp.Parameterized): 31 | def __init__(self, dim, input_dim=0, Q=None, name=None): 32 | super().__init__(name=name) 33 | self.OBSERVATIONS_AS_INPUT = False 34 | self.dim = dim 35 | self.input_dim = input_dim 36 | if Q is None or Q.ndim == 2: 37 | self.Qchol = Param(np.eye(self.dim) if Q is None else np.linalg.cholesky(Q), 38 | gtf.LowerTriangular(self.dim, squeeze=True)) 39 | elif Q.ndim == 1: 40 | self.Qchol = Param(Q ** 0.5) 41 | 42 | @params_as_tensors 43 | def conditional_mean(self, X, inputs=None): 44 | raise NotImplementedError 45 | 46 | @params_as_tensors 47 | def conditional_variance(self, X, inputs=None): 48 | if self.Qchol.shape.ndims == 2: 49 | Q = tf.matmul(self.Qchol, self.Qchol, transpose_b=True) 50 | tile_Q = [1, 1] 51 | else: 52 | Q = tf.square(self.Qchol) 53 | tile_Q = [1] 54 | if X.shape.ndims == 3: 55 | return tf.tile(Q[None, None, ...], [tf.shape(X)[0], tf.shape(X)[1], *tile_Q]) 56 | else: 57 | return tf.tile(Q[None, ...], [tf.shape(X)[0], *tile_Q]) 58 | 59 | @params_as_tensors 60 | def conditional(self, X, inputs=None): 61 | return self.conditional_mean(X, inputs=inputs), self.conditional_variance(X, inputs=inputs) 62 | 63 | @params_as_tensors 64 | def logp(self, X, inputs=None): 65 | d = X[..., 1:, :] - self.conditional_mean(X[..., :-1, :], inputs=inputs) 66 | if self.Qchol.shape.ndims == 2: 67 | dim_perm = [2, 0, 1] if X.shape.ndims == 3 else [1, 0] 68 | return mvn_logp(tf.transpose(d, dim_perm), self.Qchol) 69 | elif self.Qchol.shape.ndims == 1: 70 | return diag_mvn_logp(d, self.Qchol) 71 | 72 | def sample_conditional(self, N): 73 | session = self.enquire_session() 74 | x_tf = tf.placeholder(gp.settings.float_type, shape=[N, self.dim]) 75 | input_tf = None if self.input_dim == 0 else tf.placeholder(gp.settings.float_type, 76 | shape=[1, self.input_dim]) 77 | mu_op = self.conditional_mean(x_tf, inputs=input_tf) 78 | Qchol = self.Qchol.value.copy() 79 | 80 | def sample_conditional_fn(x, input=None): 81 | feed_dict = {x_tf: x} 82 | if input is not None: feed_dict[input_tf] = input[None, :] 83 | mu = session.run(mu_op, feed_dict=feed_dict) 84 | if Qchol.ndim == 1: 85 | noise_samples = np.random.randn(*x.shape) * Qchol 86 | else: 87 | noise_samples = np.random.randn(*x.shape) @ Qchol.T 88 | return mu + noise_samples 89 | 90 | return sample_conditional_fn 91 | 92 | @params_as_tensors 93 | def variational_expectations(self, Xmu, Xcov, inputs=None): 94 | raise NotImplementedError 95 | 96 | 97 | class GPTransitions(gp.Parameterized): 98 | def __init__(self, dim, input_dim=0, kern=None, Z=None, n_ind_pts=100, 99 | mean_fn=None, Q_diag=None, Umu=None, Ucov_chol=None, 100 | jitter=gps.numerics.jitter_level, name=None): 101 | super().__init__(name=name) 102 | self.OBSERVATIONS_AS_INPUT = False 103 | self.dim = dim 104 | self.input_dim = input_dim 105 | self.jitter = jitter 106 | 107 | self.Q_sqrt = Param(np.ones(self.dim) if Q_diag is None else Q_diag ** 0.5, transform=gtf.positive) 108 | 109 | self.n_ind_pts = n_ind_pts if Z is None else 
(Z[0].shape[-2] if isinstance(Z, list) else Z.shape[-2]) 110 | 111 | if isinstance(Z, np.ndarray) and Z.ndim == 2: 112 | self.Z = mf.SharedIndependentMof(gp.features.InducingPoints(Z)) 113 | else: 114 | Z_list = [np.random.randn(self.n_ind_pts, self.dim + self.input_dim) 115 | for _ in range(self.dim)] if Z is None else [z for z in Z] 116 | self.Z = mf.SeparateIndependentMof([gp.features.InducingPoints(z) for z in Z_list]) 117 | 118 | if isinstance(kern, gp.kernels.Kernel): 119 | self.kern = mk.SharedIndependentMok(kern, self.dim) 120 | else: 121 | kern_list = kern or [gp.kernels.Matern32(self.dim + self.input_dim, ARD=True) for _ in range(self.dim)] 122 | self.kern = mk.SeparateIndependentMok(kern_list) 123 | 124 | self.mean_fn = mean_fn or mean_fns.Identity(self.dim) 125 | self.Umu = Param(np.zeros((self.dim, self.n_ind_pts)) if Umu is None else Umu) # Lm^-1(Umu - m(Z)) 126 | transform = gtf.LowerTriangular(self.n_ind_pts, num_matrices=self.dim, squeeze=False) 127 | self.Ucov_chol = Param(np.tile(np.eye(self.n_ind_pts)[None, ...], [self.dim, 1, 1]) 128 | if Ucov_chol is None else Ucov_chol, transform=transform) # Lm^-1(Ucov_chol) 129 | self._Kzz = None 130 | 131 | @property 132 | def Kzz(self): 133 | if self._Kzz is None: 134 | self._Kzz = Kuu(self.Z, self.kern, jitter=self.jitter) # (E x) x M x M 135 | return self._Kzz 136 | 137 | @params_as_tensors 138 | def conditional_mean(self, X, inputs=None, Lm=None): 139 | return self.conditional(X, inputs=inputs, add_noise=False, Lm=Lm)[0] 140 | 141 | @params_as_tensors 142 | def conditional_variance(self, X, inputs=None, add_noise=True, Lm=None): 143 | return self.conditional(X, inputs=inputs, add_noise=add_noise, Lm=Lm)[1] 144 | 145 | @params_as_tensors 146 | def conditional(self, X, inputs=None, add_noise=True, Lm=None): 147 | N = tf.shape(X)[0] 148 | if X.shape.ndims == 3: 149 | X_in = X if inputs is None else tf.concat([X, tf.tile(inputs[None, :, :], [N, 1, 1])], -1) 150 | X_in = tf.reshape(X_in, [-1, self.dim + self.input_dim]) 151 | else: 152 | X_in = X if inputs is None else tf.concat([X, tf.tile(inputs[None, :], [N, 1])], -1) 153 | mu, var = conditional(X_in, self.Z, self.kern, self.Umu, q_sqrt=self.Ucov_chol, white=True, Lm=Lm) 154 | n_mean_inputs = self.mean_fn.input_dim if hasattr(self.mean_fn, "input_dim") else self.dim 155 | mu += self.mean_fn(X_in[:, :n_mean_inputs]) 156 | 157 | if X.shape.ndims == 3: 158 | T = tf.shape(X)[1] 159 | mu = tf.reshape(mu, [N, T, self.dim]) 160 | var = tf.reshape(var, [N, T, self.dim]) 161 | 162 | if add_noise: 163 | var += self.Q_sqrt ** 2. 
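# Q_sqrt holds the square roots of the diagonal process-noise covariance Q, so the line
# above adds diag(Q) to the GP predictive variance whenever add_noise=True.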
164 | return mu, var 165 | 166 | @params_as_tensors 167 | def logp(self, X, inputs=None, subtract_KL_U=True): 168 | T = tf.shape(X)[-2] 169 | mu, var = self.conditional(X[..., :-1, :], inputs=inputs, add_noise=False) # N x (T-1) x E or (T-1) x E 170 | logp = diag_mvn_logp(X[..., 1:, :] - mu, self.Q_sqrt) 171 | trace = tf.reduce_sum(var / tf.square(self.Q_sqrt), -1) 172 | ret_value = logp - 0.5 * trace 173 | if subtract_KL_U: 174 | KL_U = KL(self.Umu, self.Ucov_chol) / tf.cast(T - 1, X.dtype) 175 | ret_value -= KL_U 176 | return ret_value 177 | 178 | def sample_conditional(self, N): 179 | session = self.enquire_session() 180 | Lm = tf.constant(session.run(tf.cholesky(self.Kzz))) 181 | x_tf = tf.placeholder(gp.settings.float_type, shape=[N, self.dim]) 182 | input_tf = None if self.input_dim == 0 else tf.placeholder(gp.settings.float_type, 183 | shape=[self.input_dim]) 184 | mu_op, var_op = self.conditional(x_tf, inputs=input_tf, add_noise=True, Lm=Lm) 185 | 186 | def sample_conditional_fn(x, input=None): 187 | feed_dict = {x_tf: x} 188 | if input is not None: feed_dict[input_tf] = input 189 | mu, var = session.run([mu_op, var_op], feed_dict=feed_dict) 190 | return mu + np.sqrt(var) * np.random.randn(*x.shape) 191 | return sample_conditional_fn 192 | 193 | @params_as_tensors 194 | def variational_expectations(self, Xmu, Xcov, inputs=None): 195 | raise NotImplementedError 196 | 197 | 198 | class QuadraticPeriodicTransitions(BaseGaussianTransitions): 199 | def __init__(self, dim, input_dim=None, A=None, B=None, C=None, D=None, Q=None, name=None): 200 | _input_dim = input_dim or dim 201 | _Q = np.eye(dim) * np.sqrt(10.) if Q is None else Q 202 | super().__init__(dim=dim, input_dim=_input_dim, Q=_Q, name=name) 203 | self.A = Param(np.eye(self.dim) * 0.5 if A is None else A) 204 | self.B = Param(np.eye(self.dim) * 25. if B is None else B) 205 | self.C = Param(np.eye(self.dim) * 8.0 if C is None else C) 206 | self.D = Param(np.eye(self.dim, self.input_dim) * 1.2 if D is None else D) 207 | 208 | @params_as_tensors 209 | def conditional_mean(self, X, inputs): 210 | if X.shape.ndims == 3: 211 | _X = tf.reshape(X, [-1, tf.shape(X)[-1]]) # (n_samples*(T-1))xD 212 | Xmu = tf.matmul(_X, self.A, transpose_b=True) + \ 213 | tf.matmul(_X, self.B, transpose_b=True) / (1. + tf.square(_X)) 214 | Xmu = tf.reshape(Xmu, tf.shape(X)) 215 | else: 216 | Xmu = tf.matmul(X, self.A, transpose_b=True) + \ 217 | tf.matmul(X, self.B, transpose_b=True) / (1. 
+ tf.square(X)) 218 | Xmu += tf.matmul(tf.cos(tf.matmul(inputs, self.D, transpose_b=True)), self.C, transpose_b=True) 219 | return Xmu # (T-1)xD or n_samplesx(T-1)xD 220 | 221 | 222 | class GARCHParametricTransitions(BaseGaussianTransitions): 223 | def __init__(self, latent_dim, input_dim, A=None, B=None, C=None, d=None, Q=None, name=None): 224 | _Q = np.eye(latent_dim) * 0.2 if Q is None else Q 225 | super().__init__(dim=latent_dim, input_dim=input_dim, Q=_Q, name=name) 226 | self.OBSERVATIONS_AS_INPUT = True 227 | self.A = Param(np.eye(self.dim) * 0.2 if A is None else A) 228 | self.B = Param(np.eye(self.dim, self.input_dim) * (-0.2) if B is None else B) 229 | self.C = Param(np.eye(self.dim, self.input_dim) * 0.1 if C is None else C) 230 | self.d = Param(np.zeros(self.dim) + 0 if d is None else d) 231 | 232 | @params_as_tensors 233 | def conditional_mean(self, X, inputs): 234 | if X.shape.ndims == 3: 235 | Xmu = tf.matmul(tf.reshape(X, [-1, tf.shape(X)[-1]]), self.A, transpose_b=True) # (n_samples*(T-1))xD 236 | Xmu = tf.reshape(Xmu, tf.shape(X)) 237 | else: 238 | Xmu = tf.matmul(X, self.A, transpose_b=True) 239 | Xmu += tf.matmul(inputs, self.B, transpose_b=True) \ 240 | + tf.matmul(tf.square(inputs), self.C, transpose_b=True) + self.d 241 | return Xmu # (T-1)xD or n_samplesx(T-1)xD 242 | 243 | @params_as_tensors 244 | def variational_expectations(self, Xmu, Xcov, inputs): 245 | T = Xmu._shape_as_list()[0] 246 | logp = self.logp(Xmu, inputs) 247 | tiled_A = tf.tile(self.A[None, :, :], [T-1, 1, 1]) 248 | 249 | if isinstance(Xcov, tuple): 250 | trace_factor = - 2 * tf.matmul(tiled_A, Xcov[1]) 251 | Xcov_diag = Xcov[0] 252 | else: 253 | trace_factor = 0. 254 | Xcov_diag = Xcov 255 | if Xcov_diag.shape.ndims == 2: 256 | trace_factor = tf.matrix_diag(Xcov_diag[1:]) \ 257 | + tf.matmul(self.A * Xcov_diag[:-1][:, None, :], tiled_A, transpose_b=True) 258 | else: 259 | trace_factor += Xcov_diag[1:] + tf.matmul(tf.matmul(tiled_A, Xcov_diag[:-1]), tiled_A, transpose_b=True) 260 | trace = tf.trace(tf.cholesky_solve(tf.tile(self.Qchol[None, :, :], [T-1, 1, 1]), trace_factor)) 261 | return logp - 0.5 * trace 262 | 263 | 264 | class KinkTransitions(BaseGaussianTransitions): 265 | def __init__(self, dim, a=None, b=None, c=None, D=None, Q=None, name=None): 266 | _Q = np.eye(dim) * 0.5 if Q is None else Q 267 | super().__init__(dim=dim, Q=_Q, name=name) 268 | self.a = Param(np.ones(self.dim) * 0.8 if a is None else a) 269 | self.b = Param(np.ones(self.dim) * 0.2 if b is None else b) 270 | self.c = Param(np.ones(self.dim) * 5.0 if c is None else c) 271 | self.D = Param(np.eye(self.dim) * 2.0 if D is None else D) 272 | 273 | @params_as_tensors 274 | def conditional_mean(self, X, inputs=None): 275 | if X.shape.ndims == 3: 276 | Xmu = tf.matmul(tf.reshape(X, [-1, tf.shape(X)[-1]]), self.D, transpose_b=True) # (n_samples*(T-1))xD 277 | Xmu = tf.reshape(Xmu, tf.shape(X)) 278 | else: 279 | Xmu = tf.matmul(X, self.D, transpose_b=True) # (T-1)xD 280 | Xmu = self.a + (self.b + X) * (1. - self.c / (1. + tf.exp(-Xmu))) 281 | return Xmu # (T-1)xD or n_samplesx(T-1)xD 282 | -------------------------------------------------------------------------------- /GPt/emissions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Alessandro Davide Ialongo (@ialong) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | import gpflow as gp 19 | from gpflow import settings, params_as_tensors, autoflow, kullback_leiblers 20 | from gpflow.logdensities import mvn_logp 21 | from gpflow import mean_functions as mean_fns 22 | 23 | 24 | class GaussianEmissions(gp.likelihoods.Likelihood): 25 | def __init__(self, latent_dim=None, obs_dim=None, C=None, R=None, bias=None, name=None): 26 | super().__init__(name=name) 27 | self.REQUIRE_FULL_COV = True 28 | self.latent_dim = C.shape[1] if C is not None else latent_dim 29 | 30 | if (C is None) and (R is None): 31 | self.obs_dim = obs_dim 32 | else: 33 | self.obs_dim = R.shape[0] if R is not None else C.shape[0] 34 | 35 | self.C = gp.Param(np.eye(self.obs_dim, self.latent_dim) if C is None else C) 36 | self.Rchol = gp.Param(np.eye(self.obs_dim) if R is None else np.linalg.cholesky(R), 37 | gp.transforms.LowerTriangular(self.obs_dim, squeeze=True)) 38 | self.bias = gp.Param(np.zeros(self.obs_dim) if bias is None else bias) 39 | 40 | @params_as_tensors 41 | def conditional_mean(self, X): 42 | """ 43 | :param X: latent state (T x E) or (n_samples x T x E) 44 | :return: mu(Y)|X (T x D) or (n_samples x T x D) 45 | """ 46 | if X.shape.ndims == 3: 47 | Ymu = tf.matmul(tf.reshape(X, [-1, tf.shape(X)[-1]]), self.C, 48 | transpose_b=True) + self.bias 49 | return tf.reshape(Ymu, [tf.shape(X)[0], tf.shape(X)[1], self.obs_dim]) 50 | else: 51 | return tf.matmul(X, self.C, transpose_b=True) + self.bias 52 | 53 | @params_as_tensors 54 | def conditional_variance(self, X): 55 | """ 56 | :param X: latent state (T x E) or (n_samples x T x E) 57 | :return: cov(Y)|X (T x D x D) or (n_samples x T x D x D) 58 | """ 59 | R = tf.matmul(self.Rchol, self.Rchol, transpose_b=True) 60 | if X.shape.ndims == 3: 61 | return tf.tile(R[None, None, :, :], [tf.shape(X)[0], tf.shape(X)[1], 1, 1]) 62 | else: 63 | return tf.tile(R[None, :, :], [tf.shape(X)[0], 1, 1]) 64 | 65 | @params_as_tensors 66 | def logp(self, X, Y): 67 | """ 68 | :param X: latent state (T x E) or (n_samples x T x E) 69 | :param Y: observations (T x D) 70 | :return: \log P(Y|X(n)) (T) or (n_samples x T) 71 | """ 72 | d = Y - self.conditional_mean(X) 73 | dim_perm = [2, 0, 1] if X.shape.ndims == 3 else [1, 0] 74 | return mvn_logp(tf.transpose(d, dim_perm), self.Rchol) 75 | 76 | def sample_conditional(self, X): 77 | X_in = X if X.ndim == 2 else X.reshape(-1, X.shape[-1]) 78 | noise_samples = np.random.randn(X_in.shape[0], self.obs_dim) @ self.Rchol.value.T 79 | Y = X_in @ self.C.value.T + self.bias.value + noise_samples 80 | if X.ndim != 2: 81 | Y = Y.reshape(*X.shape[:-1], self.obs_dim) 82 | return Y 83 | 84 | @params_as_tensors 85 | def predict_mean_and_var(self, Xmu, Xcov): 86 | assert Xcov.shape.ndims >= 2 87 | _Xcov = Xcov if Xcov.shape.ndims == 3 else tf.matrix_diag(Xcov) 88 | Ymu_pred = self.conditional_mean(Xmu) 89 | C_batch = tf.tile(tf.expand_dims(self.C, 0), [tf.shape(_Xcov)[0], 1, 1]) 90 | Ycov_pred = tf.matmul(self.Rchol, self.Rchol, transpose_b=True) \ 91 | + tf.matmul(C_batch, tf.matmul(_Xcov, C_batch, transpose_b=True)) 
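# For these linear-Gaussian emissions the predictive moments are exact:
# E[Y] = C @ Xmu + bias and Cov[Y] = R + C @ Xcov @ C^T.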
92 | return Ymu_pred, Ycov_pred 93 | 94 | @params_as_tensors 95 | def variational_expectations(self, Xmu, Xcov, Y): 96 | assert Xcov.shape.ndims >= 2 97 | _Xcov = Xcov if Xcov.shape.ndims == 3 else tf.matrix_diag(Xcov) 98 | logdet = 2. * tf.reduce_sum(tf.log(tf.abs(tf.diag_part(self.Rchol)))) 99 | d = Y - self.conditional_mean(Xmu) 100 | quad = tf.reduce_sum(tf.square(tf.matrix_triangular_solve(self.Rchol, tf.transpose(d), lower=True)), 0) # T 101 | Ctr_Rinv_C = tf.matmul(self.C, tf.cholesky_solve(self.Rchol, self.C), transpose_a=True) 102 | tr = tf.reduce_sum(Ctr_Rinv_C * _Xcov, [1, 2]) # T 103 | return -0.5 * (self.obs_dim * np.log(2. * np.pi) + logdet + quad + tr) 104 | 105 | @autoflow((settings.float_type,), (settings.float_type,)) 106 | def compute_predictive_mean_and_var(self, Xmu, Xcov): 107 | return self.predict_mean_and_var(Xmu, Xcov) 108 | 109 | @autoflow((settings.float_type,), (settings.float_type,), (settings.float_type,)) 110 | def compute_variational_expectations(self, Xmu, Xcov, Y): 111 | return self.variational_expectations(Xmu, Xcov, Y) 112 | 113 | 114 | class SingleGPEmissions(gp.likelihoods.Likelihood): 115 | def __init__(self, latent_dim, Z, mean_function=None, kern=None, likelihood=None, name=None): 116 | super().__init__(name=name) 117 | self.latent_dim = latent_dim 118 | self.obs_dim = 1 119 | self.n_ind_pts = Z.shape[0] 120 | 121 | self.mean_function = mean_function or mean_fns.Zero(output_dim=self.obs_dim) 122 | self.kern = kern or gp.kernels.RBF(self.latent_dim, ARD=True) 123 | self.likelihood = likelihood or gp.likelihoods.Gaussian() 124 | self.Z = gp.features.InducingPoints(Z) 125 | self.Umu = gp.Param(np.zeros((self.n_ind_pts, self.latent_dim))) # (Lm^-1)(Umu - m(Z)) 126 | self.Ucov_chol = gp.Param(np.tile(np.eye(self.n_ind_pts)[None, ...], [self.obs_dim, 1, 1]), 127 | transform=gp.transforms.LowerTriangular( 128 | self.n_ind_pts, num_matrices=self.obs_dim, squeeze=False)) # (Lm^-1)Lu 129 | 130 | @params_as_tensors 131 | def conditional(self, X, add_observation_noise=True): 132 | """ 133 | :param X: latent state (... x E) 134 | :return: mu(Y)|X (... x D) and var(Y)|X (... x D) 135 | """ 136 | in_shape = tf.shape(X) 137 | out_shape = tf.concat([in_shape[:-1], [self.obs_dim]]) 138 | _X = tf.reshape(X, [-1, self.latent_dim]) 139 | mu, var = gp.conditionals.conditional(_X, self.Z, self.kern, self.Umu, q_sqrt=self.Ucov_chol, 140 | full_cov=False, white=True, full_output_cov=False) 141 | mu += self.mean_function(_X) 142 | if add_observation_noise: 143 | var += self.likelihood.variance 144 | return tf.reshape(mu, out_shape), tf.reshape(var, out_shape) 145 | 146 | 147 | @params_as_tensors 148 | def conditional_mean(self, X): 149 | """ 150 | :param X: latent state (... x E) 151 | :return: mu(Y)|X (... x D) 152 | """ 153 | return self.conditional(X)[0] 154 | 155 | @params_as_tensors 156 | def conditional_variance(self, X): 157 | """ 158 | :param X: latent state (... x E) 159 | :return: var(Y)|X (... 
x D) 160 | """ 161 | return self.conditional(X)[1] 162 | 163 | @params_as_tensors 164 | def logp(self, X, Y): 165 | """ 166 | :param X: latent state (n_samples x T x E) 167 | :param Y: observations (n_samples x T x D) 168 | :return: variational lower bound on \log P(Y|X) (n_samples x T) 169 | """ 170 | KL = kullback_leiblers.gauss_kl(self.Umu, self.Ucov_chol, None) # () 171 | fmean, fvar = self.conditional(X, add_observation_noise=False) # (n_samples x T x D) and (n_samples x T x D) 172 | var_exp = tf.reduce_sum(self.likelihood.variational_expectations(fmean, fvar, Y), -1) # (n_samples x T) 173 | return var_exp - KL / tf.cast(tf.shape(X)[1], gp.settings.float_type) 174 | 175 | 176 | class PolarToCartesianEmissions(GaussianEmissions): 177 | def __init__(self, R=None, name=None): 178 | obs_dim = 2 179 | R_init = np.eye(obs_dim) * 0.1 ** 2 if R is None else R 180 | super().__init__(latent_dim=obs_dim, obs_dim=obs_dim, 181 | C=np.eye(obs_dim), R=R_init, bias=np.zeros(obs_dim), name=name) 182 | self.C.trainable = False 183 | self.bias.trainable = False 184 | 185 | @params_as_tensors 186 | def conditional_mean(self, X): 187 | return tf.stack([tf.cos(X[..., 0] + 3 / 2 * np.pi), 188 | tf.sin(X[..., 0] + 3 / 2 * np.pi)], -1) 189 | 190 | def sample_conditional(self, X): 191 | conditional_mean = np.stack([ 192 | np.cos(X[..., 0] + 3 / 2 * np.pi), 193 | np.sin(X[..., 0] + 3 / 2 * np.pi)], -1) 194 | flat_noise = np.random.randn(np.prod(X.shape[:-1]), self.obs_dim) 195 | noise_samples = (flat_noise @ self.Rchol.value.T).reshape(*X.shape) 196 | return conditional_mean + noise_samples 197 | 198 | @params_as_tensors 199 | def predict_mean_and_var(self, Xmu, Xcov): 200 | raise NotImplementedError 201 | 202 | @params_as_tensors 203 | def variational_expectations(self, Xmu, Xcov, Y): 204 | raise NotImplementedError 205 | 206 | 207 | class SquaringEmissions(GaussianEmissions): 208 | def __init__(self, obs_dim, latent_dim=None, C=None, R=None, name=None): 209 | super().__init__(latent_dim=latent_dim or obs_dim, 210 | obs_dim=obs_dim, 211 | C=C, R=R, bias=np.zeros(obs_dim), name=name) 212 | self.bias.trainable = False 213 | 214 | @params_as_tensors 215 | def conditional_mean(self, X): 216 | return super().conditional_mean(tf.square(X)) 217 | 218 | def sample_conditional(self, X): 219 | super().sample_conditional(np.square(X)) 220 | 221 | @params_as_tensors 222 | def predict_mean_and_var(self, Xmu, Xcov): 223 | raise NotImplementedError 224 | 225 | @params_as_tensors 226 | def variational_expectations(self, Xmu, Xcov, Y): 227 | raise NotImplementedError 228 | 229 | 230 | class VolatilityEmissions(gp.likelihoods.Likelihood): 231 | """ 232 | The volatility likelihood is a zero mean Gaussian likelihood with varying noise: 233 | p(y|x) = N(y| 0, \exp(x)) 234 | 235 | :param inv_link: the link function that is applied to the inputs, it defaults to `tf.exp` 236 | :type inv_link: a basic TensorFlow function 237 | """ 238 | def __init__(self, inv_link=tf.exp, name=None): 239 | super().__init__(name=name) 240 | self.REQUIRE_FULL_COV = False 241 | self.inv_link = inv_link 242 | 243 | def conditional_mean(self, X): 244 | return tf.zeros_like(X) 245 | 246 | def conditional_variance(self, X): 247 | return self.inv_link(X) 248 | 249 | def logp(self, X, Y): 250 | return gp.logdensities.gaussian(Y, self.conditional_mean(X), self.conditional_variance(X)) 251 | 252 | def sample_conditional(self, X): 253 | if self.inv_link is tf.exp: 254 | return np.exp(0.5 * X) * np.random.randn(*X.shape) 255 | else: 256 | raise 
NotImplementedError('Currently only the exponential link function is supported') 257 | 258 | def predict_mean_and_var(self, Xmu, Xcov): 259 | assert Xcov.shape.ndims >= 2 260 | Xvar = Xcov if Xcov.shape.ndims == 2 else tf.matrix_diag_part(Xcov) 261 | mu = self.conditional_mean(Xmu) 262 | if self.inv_link is tf.exp: 263 | var = tf.exp(Xmu + Xvar / 2.0) 264 | return mu, var 265 | else: 266 | raise NotImplementedError('Currently only the exponential link function is supported') 267 | 268 | def variational_expectations(self, Xmu, Xcov, Y): 269 | """ 270 | _NormDist(Xmu, Xcov) 271 | :param Xmu: Latent function means (TxD) 272 | :param Xcov: Latent function variances (TxDxD) or (TxD) 273 | :param Y: Observations (TxD) 274 | :return: expectations (T) 275 | """ 276 | assert Xcov.shape.ndims >= 2 277 | Xvar = Xcov if Xcov.shape.ndims == 2 else tf.matrix_diag_part(Xcov) 278 | if self.inv_link is tf.exp: 279 | return -0.5 * tf.reduce_sum( 280 | np.log(2 * np.pi) + Xmu + tf.square(Y) * tf.exp(-Xmu + Xvar / 2.) 281 | , 1) 282 | else: 283 | raise NotImplementedError('Currently only the exponential link function is supported') 284 | 285 | 286 | class PriceAndVolatilityEmissions(VolatilityEmissions): 287 | """ 288 | This is a Volatility likelihood with a non-zero mean and varying noise: 289 | p(y|x_1, x_2) = N(y| w * x_1 + b, \exp(x_2)) 290 | 291 | :param inv_link: the link function that is applied to the inputs, it defaults to `tf.exp` 292 | :type inv_link: a basic TensorFlow function 293 | """ 294 | def __init__(self, inv_link=tf.exp, w=1., b=0., name=None): 295 | super().__init__(inv_link, name=name) 296 | self.w = gp.Param(w) 297 | self.b = gp.Param(b) 298 | 299 | @params_as_tensors 300 | def conditional_mean(self, X): 301 | return self.w * X[..., 0:1] + self.b 302 | 303 | def conditional_variance(self, X): 304 | return self.inv_link(X[..., 1:2]) 305 | 306 | def sample_conditional(self, X): 307 | if self.inv_link is tf.exp: 308 | return np.exp(0.5 * X[..., 1:2]) * np.random.randn(*X.shape[:-1], 1) \ 309 | + self.w.value * X[..., 0:1] + self.b.value 310 | else: 311 | raise NotImplementedError('Currently only the exponential link function is supported') 312 | 313 | @params_as_tensors 314 | def predict_mean_and_var(self, Xmu, Xcov): 315 | if Xcov.shape.ndims == 3: 316 | _Xcov = tf.identity(Xcov) 317 | else: 318 | _Xcov = tf.matrix_diag(Xcov) 319 | 320 | Ymu_pred = self.conditional_mean(Xmu) 321 | if self.inv_link is tf.exp: 322 | sigma_1 = _Xcov[:, 0, 0][:, None] 323 | sigma_2 = _Xcov[:, 1, 1][:, None] 324 | Yvar_pred = sigma_1 * self.w ** 2. + tf.exp(Xmu[:, 1:2] + sigma_2 / 2.) 325 | return Ymu_pred, Yvar_pred 326 | else: 327 | raise NotImplementedError('Currently only the exponential link function is supported') 328 | 329 | @params_as_tensors 330 | def variational_expectations(self, Xmu, Xcov, Y): 331 | """ 332 | _NormDist(Xmu, Xcov) 333 | :param Xmu: Latent function means (Tx2) 334 | :param Xcov: Latent function variances (Tx2x2) or (Tx2) 335 | :param Y: Observations (Tx1) 336 | :return: expectations (T) 337 | """ 338 | if Xcov.shape.ndims == 3: 339 | _Xcov = tf.identity(Xcov) 340 | else: 341 | _Xcov = tf.matrix_diag(Xcov) 342 | 343 | if self.inv_link is tf.exp: 344 | sigma_1 = _Xcov[:, 0, 0][:, None] 345 | sigma_2 = _Xcov[:, 1, 1][:, None] 346 | cross = _Xcov[:, 0, 1][:, None] 347 | return -0.5 * ( 348 | np.log(2 * np.pi) + Xmu[:, 1:2] 349 | + (tf.square(Y - self.b - self.w * (Xmu[:, 0:1] - cross)) + sigma_1 * self.w ** 2.) 
350 | * tf.exp(-Xmu[:, 1:2] + sigma_2 / 2.)) 351 | else: 352 | raise NotImplementedError('Currently only the exponential link function is supported') 353 | -------------------------------------------------------------------------------- /GPt/gpssm_models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Alessandro Davide Ialongo (@ialong) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import tensorflow as tf 17 | from gpflow import params_as_tensors 18 | from gpflow import settings as gps 19 | from .gpssm import GPSSM 20 | from .gpssm_multiseq import GPSSM_MultipleSequences 21 | from .ssm import SSM_SG, SSM_SG_MultipleSequences 22 | from .transitions import GPTransitions 23 | 24 | 25 | class GPSSM_VCDT(GPSSM): 26 | def __init__(self, *args, **kwargs): 27 | super().__init__(*args, **kwargs) 28 | self.sample_fn = self._build_linear_time_q_sample 29 | self.sample_kwargs = {'return_f_moments': True, 30 | 'return_x_cov_chols': True, 31 | 'sample_u': True} 32 | self.KL_fn = self._build_factorized_transition_KLs 33 | 34 | 35 | class GPSSM_FactorizedLinear(SSM_SG): 36 | def __init__(self, latent_dim, Y, inputs=None, emissions=None, 37 | px1_mu=None, px1_cov=None, 38 | kern=None, Z=None, n_ind_pts=100, 39 | mean_fn=None, 40 | Q_diag=None, 41 | Umu=None, Ucov_chol=None, 42 | Xmu=None, Xchol=None, 43 | n_samples=100, seed=None, 44 | jitter=gps.numerics.jitter_level, name=None): 45 | 46 | transitions = GPTransitions(latent_dim, 47 | input_dim=0 if inputs is None else inputs.shape[1], 48 | kern=kern, Z=Z, n_ind_pts=n_ind_pts, 49 | mean_fn=mean_fn, Q_diag=Q_diag, 50 | Umu=Umu, Ucov_chol=Ucov_chol, 51 | jitter=jitter, 52 | name=None if name is None else name + '/transitions') 53 | 54 | super().__init__(latent_dim, Y, transitions, 55 | T_latent=None, inputs=inputs, emissions=emissions, 56 | px1_mu=px1_mu, px1_cov=px1_cov, Xmu=Xmu, Xchol=Xchol, 57 | n_samples=n_samples, 58 | seed=seed, name=name) 59 | 60 | 61 | class GPSSM_FactorizedNonLinear(GPSSM): 62 | def __init__(self, *args, **kwargs): 63 | super().__init__(*args, **kwargs) 64 | self.sample_fn = self._build_linear_time_q_sample 65 | self.sample_kwargs = {'return_f_moments': True, 66 | 'return_x_cov_chols': True, 67 | 'sample_u': False} 68 | self.KL_fn = self._build_factorized_transition_KLs 69 | 70 | 71 | class GPSSM_Parametric(GPSSM): 72 | """ 73 | Corresponds to doing inference in a parametric model with prior: p(f(X)|u)p(u). 74 | This method can often outperform VCDT as it pays no price for being unable to 75 | approximate the non-parametric posterior and, for sufficiently large numbers of inducing 76 | points, it can provide a fit which is closer to that of the non-parametric, cubic time method. 77 | It can thus be useful to check if the ELBO value achieved by this method is similar to the one 78 | the cubic time method would give, for the same model and variational parameters. 
79 | """ 80 | def __init__(self, *args, **kwargs): 81 | super().__init__(*args, **kwargs) 82 | self.sample_fn = self._build_linear_time_q_sample 83 | self.sample_kwargs = {'return_f_moments': True, 84 | 'sample_u': True} 85 | self.KL_fn = self._build_transition_KLs 86 | 87 | 88 | class GPSSM_Cubic(GPSSM): 89 | """Full non-parametric prior and posterior (with cubic time sampling).""" 90 | def __init__(self, *args, **kwargs): 91 | super().__init__(*args, **kwargs) 92 | self.sample_fn = self._build_cubic_time_q_sample 93 | self.sample_kwargs = {'return_f_moments': True, 94 | 'return_f': False, 95 | 'sample_u': False} 96 | self.KL_fn = self._build_transition_KLs 97 | 98 | 99 | # ===== Methods where the posterior transitions are fixed to the prior (A=I, b=0, S=Q) ===== # 100 | 101 | 102 | class PRSSM(GPSSM): 103 | def __init__(self, *args, **kwargs): 104 | if 'As' in kwargs.keys(): kwargs.pop('As') 105 | if 'bs' in kwargs.keys(): kwargs.pop('bs') 106 | kwargs['Ss'] = False 107 | super().__init__(*args, **kwargs) 108 | self.As.trainable = False 109 | self.bs.trainable = False 110 | self.sample_fn = self._build_linear_time_q_sample 111 | self.sample_kwargs = {'return_f_moments': False, 112 | 'sample_u': False} 113 | self.KL_fn = lambda *fs: tf.constant(0., dtype=gps.float_type) 114 | 115 | @property 116 | def S_chols(self): 117 | if self._S_chols is None: 118 | self._S_chols = tf.ones((self.T - 1, self.latent_dim), dtype=gps.float_type) * self.Q_sqrt 119 | return self._S_chols 120 | 121 | 122 | class GPSSM_PPT(GPSSM): 123 | """ 124 | PPT = Parametric, Prior Transitions 125 | Effectively the same as PRSSM but with the correct sampling scheme: 126 | explicit sampling and conditioning of the inducing outputs u. 127 | Also A=I, b=0, S=Q, i.e. the posterior transitions are fixed to the prior. 128 | Beware that this still corresponds to doing inference w.r.t. a parametric prior p(f(X)|u)p(u). 129 | """ 130 | def __init__(self, *args, **kwargs): 131 | if 'As' in kwargs.keys(): kwargs.pop('As') 132 | if 'bs' in kwargs.keys(): kwargs.pop('bs') 133 | kwargs['Ss'] = False 134 | super().__init__(*args, **kwargs) 135 | self.As.trainable = False 136 | self.bs.trainable = False 137 | self.sample_fn = self._build_linear_time_q_sample 138 | self.sample_kwargs = {'return_f_moments': False, 139 | 'sample_u': True} 140 | self.KL_fn = lambda *fs: tf.constant(0., dtype=gps.float_type) 141 | 142 | @property 143 | def S_chols(self): 144 | if self._S_chols is None: 145 | self._S_chols = tf.ones((self.T - 1, self.latent_dim), dtype=gps.float_type) * self.Q_sqrt 146 | return self._S_chols 147 | 148 | 149 | class GPSSM_VPT(GPSSM): 150 | """ 151 | VPT = VCDT, Prior Transitions. 152 | VCDT inference method but with posterior transitions fixed to the prior (A=I, b=0, S=Q) as in PRSSM. 
153 | """ 154 | def __init__(self, *args, **kwargs): 155 | if 'As' in kwargs.keys(): kwargs.pop('As') 156 | if 'bs' in kwargs.keys(): kwargs.pop('bs') 157 | kwargs['Ss'] = False 158 | super().__init__(*args, **kwargs) 159 | self.As.trainable = False 160 | self.bs.trainable = False 161 | self.sample_fn = self._build_linear_time_q_sample 162 | self.sample_kwargs = {'return_f_moments': True, 163 | 'return_x_cov_chols': True, 164 | 'sample_u': True} 165 | self.KL_fn = self._build_factorized_transition_KLs 166 | 167 | @property 168 | def S_chols(self): 169 | if self._S_chols is None: 170 | self._S_chols = tf.ones((self.T - 1, self.latent_dim), dtype=gps.float_type) * self.Q_sqrt 171 | return self._S_chols 172 | 173 | 174 | class GPSSM_CPT(GPSSM): 175 | """ 176 | CPT = Cubic sampling, Prior Transitions. 177 | Full non-parametric prior and posterior (with cubic time sampling), 178 | but with posterior transitions fixed to the prior (A=I, b=0, S=Q) as in PRSSM. 179 | """ 180 | def __init__(self, *args, **kwargs): 181 | if 'As' in kwargs.keys(): kwargs.pop('As') 182 | if 'bs' in kwargs.keys(): kwargs.pop('bs') 183 | kwargs['Ss'] = False 184 | super().__init__(*args, **kwargs) 185 | self.As.trainable = False 186 | self.bs.trainable = False 187 | self.sample_fn = self._build_cubic_time_q_sample 188 | self.sample_kwargs = {'return_f_moments': False, 189 | 'return_f': False, 190 | 'sample_u': False} 191 | self.KL_fn = lambda *fs: tf.constant(0., dtype=gps.float_type) 192 | 193 | @property 194 | def S_chols(self): 195 | if self._S_chols is None: 196 | self._S_chols = tf.ones((self.T - 1, self.latent_dim), dtype=gps.float_type) * self.Q_sqrt 197 | return self._S_chols 198 | 199 | 200 | # ========================= Multiple Sequences (MS) data - same models as above ========================= # 201 | 202 | 203 | class GPSSM_MS_VCDT(GPSSM_MultipleSequences): 204 | def __init__(self, *args, **kwargs): 205 | super().__init__(*args, **kwargs) 206 | self.sample_fn = self._build_linear_time_q_sample 207 | self.sample_kwargs = {'return_f_moments': True, 208 | 'return_x_cov_chols': True, 209 | 'sample_u': True} 210 | self.KL_fn = self._build_factorized_transition_KLs 211 | 212 | 213 | class GPSSM_MS_FactorizedLinear(SSM_SG_MultipleSequences): 214 | def __init__(self, latent_dim, Y, inputs=None, emissions=None, 215 | px1_mu=None, px1_cov=None, 216 | kern=None, Z=None, n_ind_pts=100, 217 | mean_fn=None, 218 | Q_diag=None, 219 | Umu=None, Ucov_chol=None, 220 | Xmu=None, Xchol=None, 221 | n_samples=100, batch_size=None, seed=None, 222 | jitter=gps.numerics.jitter_level, name=None): 223 | 224 | transitions = GPTransitions(latent_dim, 225 | input_dim=0 if inputs is None else inputs[0].shape[1], 226 | kern=kern, Z=Z, n_ind_pts=n_ind_pts, 227 | mean_fn=mean_fn, Q_diag=Q_diag, 228 | Umu=Umu, Ucov_chol=Ucov_chol, 229 | jitter=jitter, 230 | name=None if name is None else name + '/transitions') 231 | 232 | super().__init__(latent_dim, Y, transitions, 233 | T_latent=None, inputs=inputs, emissions=emissions, 234 | px1_mu=px1_mu, px1_cov=px1_cov, Xmu=Xmu, Xchol=Xchol, 235 | n_samples=n_samples, batch_size=batch_size, 236 | seed=seed, name=name) 237 | 238 | 239 | class GPSSM_MS_FactorizedNonLinear(GPSSM_MultipleSequences): 240 | def __init__(self, *args, **kwargs): 241 | super().__init__(*args, **kwargs) 242 | self.sample_fn = self._build_linear_time_q_sample 243 | self.sample_kwargs = {'return_f_moments': True, 244 | 'return_x_cov_chols': True, 245 | 'sample_u': False} 246 | self.KL_fn = self._build_factorized_transition_KLs 247 
| 248 | 249 | class GPSSM_MS_Parametric(GPSSM_MultipleSequences): 250 | def __init__(self, *args, **kwargs): 251 | super().__init__(*args, **kwargs) 252 | self.sample_fn = self._build_linear_time_q_sample 253 | self.sample_kwargs = {'return_f_moments': True, 254 | 'sample_u': True} 255 | self.KL_fn = self._build_transition_KLs 256 | 257 | 258 | class GPSSM_MS_Cubic(GPSSM_MultipleSequences): 259 | def __init__(self, *args, **kwargs): 260 | super().__init__(*args, **kwargs) 261 | self.sample_fn = self._build_cubic_time_q_sample 262 | self.sample_kwargs = {'return_f_moments': True, 263 | 'return_f': False, 264 | 'sample_u': False} 265 | self.KL_fn = self._build_transition_KLs 266 | 267 | 268 | # ===== Methods where the posterior transitions are fixed to the prior (A=I, b=0, S=Q) ===== # 269 | 270 | 271 | class PRSSM_MS(GPSSM_MultipleSequences): 272 | def __init__(self, *args, **kwargs): 273 | if 'As' in kwargs.keys(): kwargs.pop('As') 274 | if 'bs' in kwargs.keys(): kwargs.pop('bs') 275 | kwargs['Ss'] = False 276 | super().__init__(*args, **kwargs) 277 | self.As.trainable = False 278 | self.bs.trainable = False 279 | self.sample_fn = self._build_linear_time_q_sample 280 | self.sample_kwargs = {'return_f_moments': False, 281 | 'sample_u': False} 282 | 283 | @property 284 | def S_chols(self): 285 | if self._S_chols is None: 286 | if self.batch_size is None: 287 | self._S_chols = [tf.ones((self.T[s] - 1, self.latent_dim), 288 | dtype=gps.float_type) * self.Q_sqrt for s in range(self.n_seq)] 289 | else: 290 | self._S_chols = tf.ones((self.n_seq, self.max_T - 1, self.latent_dim), 291 | dtype=gps.float_type) * self.Q_sqrt 292 | return self._S_chols 293 | 294 | @params_as_tensors 295 | def _build_KL_X(self, fs, batch_indices=None): 296 | return tf.constant(0., dtype=gps.float_type) 297 | 298 | 299 | class GPSSM_MS_PPT(GPSSM_MultipleSequences): 300 | def __init__(self, *args, **kwargs): 301 | if 'As' in kwargs.keys(): kwargs.pop('As') 302 | if 'bs' in kwargs.keys(): kwargs.pop('bs') 303 | kwargs['Ss'] = False 304 | super().__init__(*args, **kwargs) 305 | self.As.trainable = False 306 | self.bs.trainable = False 307 | self.sample_fn = self._build_linear_time_q_sample 308 | self.sample_kwargs = {'return_f_moments': False, 309 | 'sample_u': True} 310 | 311 | @property 312 | def S_chols(self): 313 | if self._S_chols is None: 314 | if self.batch_size is None: 315 | self._S_chols = [tf.ones((self.T[s] - 1, self.latent_dim), 316 | dtype=gps.float_type) * self.Q_sqrt for s in range(self.n_seq)] 317 | else: 318 | self._S_chols = tf.ones((self.n_seq, self.max_T - 1, self.latent_dim), 319 | dtype=gps.float_type) * self.Q_sqrt 320 | return self._S_chols 321 | 322 | @params_as_tensors 323 | def _build_KL_X(self, fs, batch_indices=None): 324 | return tf.constant(0., dtype=gps.float_type) 325 | 326 | 327 | class GPSSM_MS_VPT(GPSSM_MultipleSequences): 328 | def __init__(self, *args, **kwargs): 329 | if 'As' in kwargs.keys(): kwargs.pop('As') 330 | if 'bs' in kwargs.keys(): kwargs.pop('bs') 331 | kwargs['Ss'] = False 332 | super().__init__(*args, **kwargs) 333 | self.As.trainable = False 334 | self.bs.trainable = False 335 | self.sample_fn = self._build_linear_time_q_sample 336 | self.sample_kwargs = {'return_f_moments': True, 337 | 'return_x_cov_chols': True, 338 | 'sample_u': True} 339 | self.KL_fn = self._build_factorized_transition_KLs 340 | 341 | @property 342 | def S_chols(self): 343 | if self._S_chols is None: 344 | if self.batch_size is None: 345 | self._S_chols = [tf.ones((self.T[s] - 1, self.latent_dim), 
346 | dtype=gps.float_type) * self.Q_sqrt for s in range(self.n_seq)] 347 | else: 348 | self._S_chols = tf.ones((self.n_seq, self.max_T - 1, self.latent_dim), 349 | dtype=gps.float_type) * self.Q_sqrt 350 | return self._S_chols 351 | 352 | 353 | class GPSSM_MS_CPT(GPSSM_MultipleSequences): 354 | def __init__(self, *args, **kwargs): 355 | if 'As' in kwargs.keys(): kwargs.pop('As') 356 | if 'bs' in kwargs.keys(): kwargs.pop('bs') 357 | kwargs['Ss'] = False 358 | super().__init__(*args, **kwargs) 359 | self.As.trainable = False 360 | self.bs.trainable = False 361 | self.sample_fn = self._build_cubic_time_q_sample 362 | self.sample_kwargs = {'return_f_moments': False, 363 | 'return_f': False, 364 | 'sample_u': False} 365 | 366 | @property 367 | def S_chols(self): 368 | if self._S_chols is None: 369 | if self.batch_size is None: 370 | self._S_chols = [tf.ones((self.T[s] - 1, self.latent_dim), 371 | dtype=gps.float_type) * self.Q_sqrt for s in range(self.n_seq)] 372 | else: 373 | self._S_chols = tf.ones((self.n_seq, self.max_T - 1, self.latent_dim), 374 | dtype=gps.float_type) * self.Q_sqrt 375 | return self._S_chols 376 | 377 | @params_as_tensors 378 | def _build_KL_X(self, fs, batch_indices=None): 379 | return tf.constant(0., dtype=gps.float_type) 380 | -------------------------------------------------------------------------------- /tests/test_sampling_schemes.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Alessandro Davide Ialongo (@ialong) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np 17 | from numpy.testing import assert_allclose 18 | import tensorflow as tf 19 | from tensorflow_probability import distributions as tfd 20 | import gpflow as gp 21 | from gpflow.test_util import GPflowTestCase 22 | from gpflow import mean_functions as mean_fns 23 | from gpflow.conditionals import conditional, Kuu, Kuf 24 | from GPt.gpssm import GPSSM 25 | 26 | 27 | def general_prepare(self): 28 | Y = np.random.randn(self.T, self.D) 29 | inputs = np.random.randn(self.T - 1, self.input_dim) if self.input_dim > 0 else None 30 | Q_diag = np.random.randn(self.E) ** 2. 31 | kern = [gp.kernels.RBF(self.E + self.input_dim, ARD=True) for _ in range(self.E)] 32 | for k in kern: k.lengthscales = np.random.rand(self.E + self.input_dim) * 2. 
33 | for k in kern: k.variance = np.random.rand() 34 | Z = np.random.randn(self.E, self.n_ind_pts, self.E + self.input_dim) 35 | mean_fn = mean_fns.Linear(np.random.randn(self.E, self.E), np.random.randn(self.E)) 36 | Umu = np.random.randn(self.E, self.n_ind_pts) 37 | Ucov_chol = np.random.randn(self.E, self.n_ind_pts, self.n_ind_pts) 38 | Ucov_chol = np.linalg.cholesky(np.matmul(Ucov_chol, np.transpose(Ucov_chol, [0, 2, 1]))) 39 | qx1_mu = np.random.randn(self.E) 40 | qx1_cov = np.random.randn(self.E, self.E) 41 | qx1_cov = qx1_cov @ qx1_cov.T 42 | As = np.random.randn(self.T - 1, self.E) 43 | bs = np.random.randn(self.T - 1, self.E) 44 | Ss = np.random.randn(self.T - 1, self.E) ** 2. 45 | m = GPSSM(self.E, Y, inputs=inputs, emissions=None, px1_mu=None, px1_cov=None, 46 | kern=kern, Z=Z, n_ind_pts=None, mean_fn=mean_fn, 47 | Q_diag=Q_diag, Umu=Umu, Ucov_chol=Ucov_chol, 48 | qx1_mu=qx1_mu, qx1_cov=qx1_cov, As=As, bs=bs, Ss=Ss, n_samples=self.n_samples, seed=self.seed) 49 | _ = m.compute_log_likelihood() 50 | return m 51 | 52 | 53 | class FactorizedSamplingTest(GPflowTestCase): 54 | def __init__(self, *args, **kwargs): 55 | super().__init__(*args, **kwargs) 56 | self.seed = 0 57 | np.random.seed(self.seed) 58 | tf.set_random_seed(self.seed) 59 | self.T, self.D, self.E, self.input_dim = 11, 3, 2, 0 60 | self.n_samples, self.n_ind_pts = int(1e3), 4 61 | self.white = True 62 | 63 | def prepare(self): 64 | return general_prepare(self) 65 | 66 | def test_X_samples(self): 67 | with self.test_context() as sess: 68 | shape = [self.T, self.n_samples, self.E] 69 | 70 | m = self.prepare() 71 | 72 | qe_samples = tfd.MultivariateNormalDiag(loc=tf.zeros(shape[1:], dtype=gp.settings.float_type)) 73 | qe_samples = sess.run(qe_samples.sample(self.T, seed=self.seed)) 74 | X_tmin1 = tf.placeholder(gp.settings.float_type, shape=shape[1:]) 75 | Kzz = sess.run(Kuu(m.Z, m.kern, jitter=gp.settings.numerics.jitter_level)) 76 | Kzz_inv = np.linalg.inv(np.linalg.cholesky(Kzz)) if self.white else np.linalg.inv(Kzz) # E x M x M 77 | X_samples_np = np.zeros(shape) 78 | X_samples_np[0] = m.qx1_mu.value + qe_samples[0] @ m.qx1_cov_chol.value.T 79 | for t in range(self.T-1): 80 | Kzx = sess.run(Kuf(m.Z, m.kern, X_tmin1), feed_dict={X_tmin1:X_samples_np[t]}) # E x M x N 81 | Kxx = sess.run(m.kern.Kdiag(X_tmin1, full_output_cov=False), feed_dict={X_tmin1:X_samples_np[t]}) # N x E 82 | mean_x = sess.run(m.mean_fn(X_tmin1), feed_dict={X_tmin1:X_samples_np[t]}) 83 | Kzz_invKzx = np.matmul(Kzz_inv, Kzx) # E x M x N 84 | mu = mean_x + np.sum(Kzz_invKzx * m.Umu.value[..., None], 1).T # N x E 85 | mu = m.As.value[t] * mu + m.bs.value[t] 86 | if self.white: 87 | cov = np.matmul(np.transpose(m.Ucov_chol.value, [0, 2, 1]), Kzz_invKzx) 88 | cov = np.sum(np.square(cov) - np.square(Kzz_invKzx), 1) 89 | else: 90 | cov = np.matmul(m.Ucov_chol.value, np.transpose(m.Ucov_chol.value, [0, 2, 1])) - Kzz # E x M x M 91 | cov = np.sum(np.matmul(cov, Kzz_invKzx) * Kzz_invKzx, 1) 92 | cov = Kxx + cov.T 93 | cov = np.square(m.As.value[t]) * cov + np.square(m.S_chols.value[t]) # N x E 94 | X_samples_np[t+1] = mu + qe_samples[t+1] * np.sqrt(cov) 95 | 96 | X_samples_tf = sess.run(m._build_linear_time_q_sample(sample_u=False))[0] 97 | 98 | assert_allclose(X_samples_tf, X_samples_np) 99 | 100 | def test_X_F_samples(self): 101 | with self.test_context() as sess: 102 | shape = [self.T, self.n_samples, self.E] 103 | 104 | m = self.prepare() 105 | 106 | qe_samples = tfd.MultivariateNormalDiag(loc=tf.zeros(shape[1:], dtype=gp.settings.float_type)) 107 | 
qe_samples_X = sess.run(qe_samples.sample(self.T, seed=self.seed)) 108 | qe_samples_F = sess.run(qe_samples.sample(self.T-1, seed=self.seed)) 109 | X_tmin1 = tf.placeholder(gp.settings.float_type, shape=shape[1:]) 110 | Kzz = sess.run(Kuu(m.Z, m.kern, jitter=gp.settings.numerics.jitter_level)) 111 | Kzz_inv = np.linalg.inv(np.linalg.cholesky(Kzz)) if self.white else np.linalg.inv(Kzz) # E x M x M 112 | X_samples_np = np.zeros(shape) 113 | X_samples_np[0] = m.qx1_mu.value + qe_samples_X[0] @ m.qx1_cov_chol.value.T 114 | F_samples_np = np.zeros([self.T-1] + shape[1:]) 115 | for t in range(self.T-1): 116 | Kzx = sess.run(Kuf(m.Z, m.kern, X_tmin1), feed_dict={X_tmin1:X_samples_np[t]}) # E x M x N 117 | Kxx = sess.run(m.kern.Kdiag(X_tmin1, full_output_cov=False), feed_dict={X_tmin1:X_samples_np[t]}) # N x E 118 | mean_x = sess.run(m.mean_fn(X_tmin1), feed_dict={X_tmin1:X_samples_np[t]}) 119 | Kzz_invKzx = np.matmul(Kzz_inv, Kzx) # E x M x N 120 | mu = mean_x + np.sum(Kzz_invKzx * m.Umu.value[..., None], 1).T # N x E 121 | if self.white: 122 | cov = np.matmul(np.transpose(m.Ucov_chol.value, [0, 2, 1]), Kzz_invKzx) 123 | cov = np.sum(np.square(cov) - np.square(Kzz_invKzx), 1) 124 | else: 125 | cov = np.matmul(m.Ucov_chol.value, np.transpose(m.Ucov_chol.value, [0, 2, 1])) - Kzz # E x M x M 126 | cov = np.sum(np.matmul(cov, Kzz_invKzx) * Kzz_invKzx, 1) 127 | cov = Kxx + cov.T 128 | F_samples_np[t] = mu + qe_samples_F[t] * np.sqrt(cov) 129 | 130 | x_mu = m.As.value[t] * F_samples_np[t] + m.bs.value[t] 131 | X_samples_np[t+1] = x_mu + qe_samples_X[t+1] * m.S_chols.value[t] 132 | 133 | X_samples_tf, F_samples_tf = sess.run(m._build_linear_time_q_sample(sample_f=True, sample_u=False)) 134 | 135 | assert_allclose(X_samples_tf, X_samples_np) 136 | assert_allclose(F_samples_tf, F_samples_np) 137 | 138 | def test_f_moments(self): 139 | with self.test_context() as sess: 140 | m = self.prepare() 141 | X_samples, F_samples, f_mus, f_vars = sess.run( 142 | m._build_linear_time_q_sample(return_f_moments=True, sample_f=True, sample_u=False)) 143 | f_mus_batch, f_vars_batch = conditional(tf.reshape(X_samples[:-1], [-1, self.E]), 144 | m.Z, m.kern, m.Umu.constrained_tensor, white=self.white, 145 | q_sqrt=m.Ucov_chol.constrained_tensor) 146 | f_mus_batch += m.mean_fn(tf.reshape(X_samples[:-1], [-1, self.E])) 147 | 148 | f_mus_batch = sess.run(f_mus_batch).reshape(self.T - 1, self.n_samples, self.E) 149 | f_vars_batch = sess.run(f_vars_batch).reshape(self.T - 1, self.n_samples, self.E) 150 | 151 | assert_allclose(f_mus, f_mus_batch) 152 | assert_allclose(f_vars, f_vars_batch) 153 | 154 | X_samples_2, f_mus_2, f_vars_2 = sess.run( 155 | m._build_linear_time_q_sample(return_f_moments=True, sample_f=False, sample_u=False)) 156 | f_mus_batch_2, f_vars_batch_2 = conditional(tf.reshape(X_samples_2[:-1], [-1, self.E]), 157 | m.Z, m.kern, m.Umu.constrained_tensor, white=self.white, 158 | q_sqrt=m.Ucov_chol.constrained_tensor) 159 | f_mus_batch_2 += m.mean_fn(tf.reshape(X_samples_2[:-1], [-1, self.E])) 160 | 161 | f_mus_batch_2 = sess.run(f_mus_batch_2).reshape(self.T - 1, self.n_samples, self.E) 162 | f_vars_batch_2 = sess.run(f_vars_batch_2).reshape(self.T - 1, self.n_samples, self.E) 163 | 164 | assert_allclose(f_mus_2, f_mus_batch_2) 165 | assert_allclose(f_vars_2, f_vars_batch_2) 166 | 167 | 168 | class uDependenceSamplingTest(GPflowTestCase): 169 | def __init__(self, *args, **kwargs): 170 | super().__init__(*args, **kwargs) 171 | self.seed = 0 172 | np.random.seed(self.seed) 173 | tf.set_random_seed(self.seed) 174 | 
self.T, self.D, self.E, self.input_dim = 11, 3, 2, 0 175 | self.n_samples, self.n_ind_pts = int(1e3), 4 176 | self.white = True 177 | 178 | def prepare(self): 179 | return general_prepare(self) 180 | 181 | def test_X_samples(self): 182 | with self.test_context() as sess: 183 | shape = [self.T, self.n_samples, self.E] 184 | 185 | m = self.prepare() 186 | 187 | qe_samples = tfd.MultivariateNormalDiag(loc=tf.zeros(shape[1:], dtype=gp.settings.float_type)) 188 | qe_samples = sess.run(qe_samples.sample(self.T, seed=self.seed)) 189 | U_samples_np = sess.run(tfd.MultivariateNormalDiag(loc=tf.zeros( 190 | [self.E, self.n_ind_pts, self.n_samples], dtype=gp.settings.float_type)).sample(seed=self.seed)) 191 | U_samples_np = m.Umu.value[:, :, None] + np.matmul(m.Ucov_chol.value, U_samples_np) 192 | 193 | X_tmin1 = tf.placeholder(gp.settings.float_type, shape=shape[1:]) 194 | Kzz = sess.run(Kuu(m.Z, m.kern, jitter=gp.settings.numerics.jitter_level)) 195 | Kzz_inv = np.linalg.inv(np.linalg.cholesky(Kzz)) if self.white else np.linalg.inv(Kzz) # E x M x M 196 | X_samples_np = np.zeros(shape) 197 | X_samples_np[0] = m.qx1_mu.value + qe_samples[0] @ m.qx1_cov_chol.value.T 198 | for t in range(self.T-1): 199 | Kzx = sess.run(Kuf(m.Z, m.kern, X_tmin1), feed_dict={X_tmin1:X_samples_np[t]}) # E x M x N 200 | Kxx = sess.run(m.kern.Kdiag(X_tmin1, full_output_cov=False), feed_dict={X_tmin1:X_samples_np[t]}) # N x E 201 | mean_x = sess.run(m.mean_fn(X_tmin1), feed_dict={X_tmin1:X_samples_np[t]}) 202 | Kzz_invKzx = np.matmul(Kzz_inv, Kzx) # E x M x N 203 | mu = mean_x + np.sum(Kzz_invKzx * U_samples_np, 1).T # N x E 204 | mu = m.As.value[t] * mu + m.bs.value[t] 205 | if self.white: 206 | cov = np.sum(np.square(Kzz_invKzx), 1) 207 | else: 208 | cov = np.sum(Kzz_invKzx * Kzx, 1) 209 | cov = Kxx - cov.T 210 | cov = np.square(m.As.value[t]) * cov + np.square(m.S_chols.value[t]) # N x E 211 | X_samples_np[t+1] = mu + qe_samples[t+1] * np.sqrt(cov) 212 | 213 | X_samples_tf, U_samples_tf = sess.run(m._build_linear_time_q_sample(sample_u=True, return_u=True)) 214 | 215 | assert_allclose(X_samples_tf, X_samples_np) 216 | assert_allclose(U_samples_tf, U_samples_np) 217 | 218 | def test_X_F_samples(self): 219 | with self.test_context() as sess: 220 | shape = [self.T, self.n_samples, self.E] 221 | 222 | m = self.prepare() 223 | 224 | qe_samples = tfd.MultivariateNormalDiag(loc=tf.zeros(shape[1:], dtype=gp.settings.float_type)) 225 | qe_samples_X = sess.run(qe_samples.sample(self.T, seed=self.seed)) 226 | qe_samples_F = sess.run(qe_samples.sample(self.T-1, seed=self.seed)) 227 | U_samples_np = sess.run(tfd.MultivariateNormalDiag(loc=tf.zeros( 228 | [self.E, self.n_ind_pts, self.n_samples], dtype=gp.settings.float_type)).sample(seed=self.seed)) 229 | U_samples_np = m.Umu.value[:, :, None] + np.matmul(m.Ucov_chol.value, U_samples_np) 230 | 231 | X_tmin1 = tf.placeholder(gp.settings.float_type, shape=shape[1:]) 232 | Kzz = sess.run(Kuu(m.Z, m.kern, jitter=gp.settings.numerics.jitter_level)) 233 | Kzz_inv = np.linalg.inv(np.linalg.cholesky(Kzz)) if self.white else np.linalg.inv(Kzz) # E x M x M 234 | X_samples_np = np.zeros(shape) 235 | X_samples_np[0] = m.qx1_mu.value + qe_samples_X[0] @ m.qx1_cov_chol.value.T 236 | F_samples_np = np.zeros([self.T-1] + shape[1:]) 237 | for t in range(self.T-1): 238 | Kzx = sess.run(Kuf(m.Z, m.kern, X_tmin1), feed_dict={X_tmin1:X_samples_np[t]}) # E x M x N 239 | Kxx = sess.run(m.kern.Kdiag(X_tmin1, full_output_cov=False), feed_dict={X_tmin1:X_samples_np[t]}) # N x E 240 | mean_x = 
sess.run(m.mean_fn(X_tmin1), feed_dict={X_tmin1:X_samples_np[t]}) 241 | Kzz_invKzx = np.matmul(Kzz_inv, Kzx) # E x M x N 242 | mu = mean_x + np.sum(Kzz_invKzx * U_samples_np, 1).T # N x E 243 | if self.white: 244 | cov = np.sum(np.square(Kzz_invKzx), 1) 245 | else: 246 | cov = np.sum(Kzz_invKzx * Kzx, 1) 247 | cov = Kxx - cov.T 248 | F_samples_np[t] = mu + qe_samples_F[t] * np.sqrt(cov) 249 | 250 | x_mu = m.As.value[t] * F_samples_np[t] + m.bs.value[t] 251 | X_samples_np[t+1] = x_mu + qe_samples_X[t+1] * m.S_chols.value[t] 252 | 253 | X_samples_tf, F_samples_tf, U_samples_tf = sess.run( 254 | m._build_linear_time_q_sample(sample_f=True, sample_u=True, return_u=True)) 255 | 256 | assert_allclose(X_samples_tf, X_samples_np) 257 | assert_allclose(F_samples_tf, F_samples_np) 258 | assert_allclose(U_samples_tf, U_samples_np) 259 | 260 | def test_f_moments(self): 261 | with self.test_context() as sess: 262 | m = self.prepare() 263 | X_samples, F_samples, f_mus, f_vars, U_samples = sess.run( 264 | m._build_linear_time_q_sample(return_f_moments=True, sample_f=True, sample_u=True, return_u=True)) 265 | 266 | X_samples_2, f_mus_2, f_vars_2, U_samples_2 = sess.run( 267 | m._build_linear_time_q_sample(return_f_moments=True, sample_f=False, sample_u=True, return_u=True)) 268 | 269 | def single_t_moments(X, U_samples): 270 | f_mu, f_var = conditional(X, m.Z, m.kern, tf.constant(U_samples, dtype=gp.settings.float_type), 271 | q_sqrt=None, white=self.white) 272 | f_mu += m.mean_fn(X) 273 | return f_mu, f_var 274 | 275 | f_mus_batch, f_vars_batch = sess.run( 276 | tf.map_fn(lambda X: single_t_moments(X, U_samples), 277 | tf.constant(X_samples[:-1], dtype=gp.settings.float_type), 278 | dtype=(gp.settings.float_type, gp.settings.float_type))) 279 | 280 | f_mus_batch_2, f_vars_batch_2 = sess.run( 281 | tf.map_fn(lambda X: single_t_moments(X, U_samples_2), 282 | tf.constant(X_samples_2[:-1], dtype=gp.settings.float_type), 283 | dtype=(gp.settings.float_type, gp.settings.float_type))) 284 | 285 | assert_allclose(f_mus, f_mus_batch) 286 | assert_allclose(f_vars, f_vars_batch) 287 | 288 | assert_allclose(f_mus_2, f_mus_batch_2) 289 | assert_allclose(f_vars_2, f_vars_batch_2) 290 | 291 | 292 | class JointSamplingTest(GPflowTestCase): 293 | def __init__(self, *args, **kwargs): 294 | super().__init__(*args, **kwargs) 295 | self.seed = 0 296 | np.random.seed(self.seed) 297 | tf.set_random_seed(self.seed) 298 | self.T, self.D, self.E, self.input_dim = 6, 4, 3, 0 299 | self.n_samples, self.n_ind_pts = int(1e1), 10 300 | self.white = True 301 | 302 | def prepare(self): 303 | return general_prepare(self) 304 | 305 | def test_joint_samples(self): 306 | with self.test_context() as sess: 307 | shape = [self.T, self.n_samples, self.E] 308 | 309 | m = self.prepare() 310 | 311 | white_samples_X = tfd.MultivariateNormalDiag( 312 | loc=tf.zeros((self.n_samples, self.T, self.E), dtype=gp.settings.float_type)).sample(seed=self.seed) 313 | white_samples_X = np.transpose(sess.run(white_samples_X), [1, 0, 2]) 314 | 315 | white_samples_F = tfd.MultivariateNormalDiag( 316 | loc=tf.zeros((self.n_samples, self.T - 1, self.E), dtype=gp.settings.float_type)).sample(seed=self.seed) 317 | white_samples_F = np.transpose(sess.run(white_samples_F), [1, 0, 2]) 318 | 319 | X_buff = tf.placeholder(gp.settings.float_type, shape=[None, self.E]) 320 | Kzz = sess.run(Kuu(m.Z, m.kern, jitter=gp.settings.numerics.jitter_level)) 321 | Kzz_inv = np.linalg.inv(np.linalg.cholesky(Kzz)) if self.white else np.linalg.inv(Kzz) # E x M x M 322 | X_samples_np = 
np.zeros(shape) 323 | F_samples_np = np.zeros([self.T - 1] + shape[1:]) 324 | f_mus_np = np.zeros([self.T - 1] + shape[1:]) 325 | f_vars_np = np.zeros([self.T - 1] + shape[1:]) 326 | 327 | X_samples_np[0] = white_samples_X[0] @ m.qx1_cov_chol.value.T + m.qx1_mu.value 328 | 329 | Kzx = sess.run(Kuf(m.Z, m.kern, X_buff), feed_dict={X_buff: X_samples_np[0]}) # E x M x N 330 | Kxx = sess.run(m.kern.Kdiag(X_buff, full_output_cov=False), feed_dict={X_buff: X_samples_np[0]}) # N x E 331 | mean_x = sess.run(m.mean_fn(X_buff), feed_dict={X_buff: X_samples_np[0]}) # N x E 332 | Kzz_invKzx = np.matmul(Kzz_inv, Kzx) # E x M x N 333 | f_mus_np[0] = mean_x + np.sum(Kzz_invKzx * m.Umu.value[..., None], 1).T # N x E 334 | if self.white: 335 | f_var = np.matmul(np.transpose(m.Ucov_chol.value, [0, 2, 1]), Kzz_invKzx) 336 | f_var = np.sum(np.square(f_var) - np.square(Kzz_invKzx), 1) 337 | else: 338 | f_var = np.matmul(m.Ucov_chol.value, np.transpose(m.Ucov_chol.value, [0, 2, 1])) - Kzz # E x M x M 339 | f_var = np.sum(np.matmul(f_var, Kzz_invKzx) * Kzz_invKzx, 1) 340 | f_vars_np[0] = Kxx + f_var.T 341 | F_samples_np[0] = f_mus_np[0] + white_samples_F[0] * np.sqrt(f_vars_np[0]) 342 | X_samples_np[1] = m.As.value[0] * F_samples_np[0] + m.bs.value[0] + white_samples_X[1] * m.S_chols.value[0] 343 | 344 | def single_sample_f_cond(X, F): 345 | feed_dict = {X_buff: X} 346 | Kzx = sess.run(Kuf(m.Z, m.kern, X_buff), feed_dict=feed_dict) # E x M x t+1 347 | Kxx = sess.run(m.kern.K(X_buff, full_output_cov=False), feed_dict=feed_dict) # E x t+1 x t+1 348 | mean_x = sess.run(m.mean_fn(X_buff), feed_dict=feed_dict) # t+1 x E 349 | Kzz_invKzx = np.matmul(Kzz_inv, Kzx) # E x M x t+1 350 | f_mu_joint = mean_x + np.sum(Kzz_invKzx * m.Umu.value[..., None], 1).T # t+1 x E 351 | if self.white: 352 | f_cov_joint = np.matmul(np.transpose(m.Ucov_chol.value, [0, 2, 1]), Kzz_invKzx) # E x M x t+1 353 | f_cov_joint = np.matmul(np.transpose(f_cov_joint, [0, 2, 1]), f_cov_joint) # E x t+1 x t+1 354 | f_cov_joint -= np.matmul(np.transpose(Kzz_invKzx, [0, 2, 1]), Kzz_invKzx) # E x t+1 x t+1 355 | else: 356 | f_cov_joint = np.matmul(m.Ucov_chol.value, np.transpose(m.Ucov_chol.value, [0, 2, 1])) - Kzz # E x M x M 357 | f_cov_joint = np.matmul(np.matmul(np.transpose(Kzz_invKzx, [0, 2, 1]), f_cov_joint), Kzz_invKzx) # E x t+1 x t+1 358 | f_cov_joint = Kxx + f_cov_joint # E x t+1 x t+1 359 | 360 | C_F_inv_C_F_ft = np.linalg.solve(f_cov_joint[:, :-1, :-1], f_cov_joint[:, :-1, -1:None])[:, :, 0] # E x t 361 | F_min_Fmu = F - f_mu_joint[:-1] 362 | f_mu = f_mu_joint[-1] + np.sum(C_F_inv_C_F_ft * F_min_Fmu.T, -1) # E 363 | f_var = f_cov_joint[:, -1, -1] - np.sum(C_F_inv_C_F_ft * f_cov_joint[:, :-1, -1], -1) # E 364 | return f_mu, f_var 365 | 366 | for t in range(1, self.T-1): 367 | for n in range(self.n_samples): 368 | f_mus_np[t, n], f_vars_np[t, n] = single_sample_f_cond(X_samples_np[:t+1, n], F_samples_np[:t, n]) 369 | 370 | F_samples_np[t, n] = f_mus_np[t, n] + white_samples_F[t, n] * np.sqrt(f_vars_np[t, n]) 371 | 372 | X_samples_np[t+1, n] = m.As.value[t] * F_samples_np[t, n] + m.bs.value[t] \ 373 | + white_samples_X[t+1, n] * m.S_chols.value[t] 374 | 375 | X_samples_tf, F_samples_tf, f_mus_tf, f_vars_tf = sess.run( 376 | m._build_cubic_time_q_sample(return_f_moments=True, sample_u=False, add_jitter=False)) 377 | 378 | assert_allclose(X_samples_tf, X_samples_np) 379 | assert_allclose(F_samples_tf, F_samples_np) 380 | assert_allclose(f_mus_tf, f_mus_np) 381 | assert_allclose(f_vars_tf, f_vars_np) 382 | 383 | def 
test_joint_samples_sample_u(self): 384 | with self.test_context() as sess: 385 | shape = [self.T, self.n_samples, self.E] 386 | 387 | m = self.prepare() 388 | 389 | white_samples_X = tfd.MultivariateNormalDiag( 390 | loc=tf.zeros((self.n_samples, self.T, self.E), dtype=gp.settings.float_type)).sample(seed=self.seed) 391 | white_samples_X = np.transpose(sess.run(white_samples_X), [1, 0, 2]) 392 | 393 | white_samples_F = tfd.MultivariateNormalDiag( 394 | loc=tf.zeros((self.n_samples, self.T - 1, self.E), dtype=gp.settings.float_type)).sample(seed=self.seed) 395 | white_samples_F = np.transpose(sess.run(white_samples_F), [1, 0, 2]) 396 | 397 | U_samples_np = sess.run(tfd.MultivariateNormalDiag( 398 | loc=tf.zeros((self.E, self.n_ind_pts, self.n_samples), dtype=gp.settings.float_type) 399 | ).sample(seed=self.seed)) 400 | U_samples_np = m.Umu.value[:, :, None] + np.matmul(m.Ucov_chol.value, U_samples_np) 401 | 402 | Kzz = sess.run(Kuu(m.Z, m.kern, jitter=gp.settings.numerics.jitter_level)) 403 | 404 | if self.white: 405 | Kzz_chol = np.linalg.cholesky(Kzz) 406 | U_samples_np = np.matmul(Kzz_chol, U_samples_np) 407 | 408 | X_buff = tf.placeholder(gp.settings.float_type, shape=[None, self.E]) 409 | 410 | X_samples_np = np.zeros(shape) 411 | F_samples_np = np.zeros([self.T - 1] + shape[1:]) 412 | f_mus_np = np.zeros([self.T - 1] + shape[1:]) 413 | f_vars_np = np.zeros([self.T - 1] + shape[1:]) 414 | 415 | X_samples_np[0] = white_samples_X[0] @ m.qx1_cov_chol.value.T + m.qx1_mu.value 416 | 417 | def single_sample_f_cond(K, X, F, U): 418 | feed_dict = {X_buff: X} 419 | Kzx = sess.run(Kuf(m.Z, m.kern, X_buff[-1:]), feed_dict=feed_dict)[:, :, 0] # E x M 420 | Kxx = sess.run(m.kern.K(X_buff, X_buff[-1:], full_output_cov=False), feed_dict=feed_dict)[:, :, 0] # E x (t+1) 421 | 422 | K_vector = np.concatenate([Kzx, Kxx], -1) # E x (M+t+1) 423 | mean_x = sess.run(m.mean_fn(X_buff), feed_dict=feed_dict) 424 | UF = (F - mean_x[:-1]).T 425 | UF = np.concatenate([U, UF], -1) # E x (M+t) 426 | 427 | Kinv_UF_Kvec = np.linalg.solve(K, np.stack([UF, K_vector[:, :-1]], -1)) 428 | f_mu_f_var = np.sum(K_vector[:, :-1, None] * Kinv_UF_Kvec, -2) 429 | 430 | f_mu = mean_x[-1] + f_mu_f_var[:, 0] 431 | f_var = K_vector[:, -1] - f_mu_f_var[:, 1] 432 | 433 | K = np.concatenate([K, K_vector[:, :-1, None]], -1) # E x (M+t) x (M+t+1) 434 | K = np.concatenate([K, K_vector[:, None, :]], -2) # E x (M+t+1) x (M+t+1) 435 | 436 | return K, f_mu, f_var 437 | 438 | for n in range(self.n_samples): 439 | K = Kzz 440 | for t in range(self.T - 1): 441 | K, f_mus_np[t, n], f_vars_np[t, n] = single_sample_f_cond( 442 | K, X_samples_np[:t+1, n], F_samples_np[:t, n], U_samples_np[:, :, n]) 443 | 444 | F_samples_np[t, n] = f_mus_np[t, n] + white_samples_F[t, n] * np.sqrt(f_vars_np[t, n]) 445 | 446 | X_samples_np[t + 1, n] = m.As.value[t] * F_samples_np[t, n] + m.bs.value[t] \ 447 | + white_samples_X[t + 1, n] * m.S_chols.value[t] 448 | 449 | if self.white: 450 | U_samples_np = np.linalg.solve(Kzz_chol, U_samples_np) 451 | 452 | X_samples_tf, F_samples_tf, f_mus_tf, f_vars_tf, U_samples_tf = sess.run( 453 | m._build_cubic_time_q_sample(return_f_moments=True, sample_u=True, return_u=True, add_jitter=False)) 454 | 455 | assert_allclose(X_samples_tf, X_samples_np) 456 | assert_allclose(F_samples_tf, F_samples_np) 457 | assert_allclose(f_mus_tf, f_mus_np) 458 | assert_allclose(f_vars_tf, f_vars_np) 459 | assert_allclose(U_samples_tf, U_samples_np) 460 | 461 | 462 | if __name__ == '__main__': 463 | tf.test.main() 464 | 
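# Note: the NumPy reference computations in the tests above all instantiate the
# standard sparse-GP posterior conditional. The helper below is an illustrative,
# self-contained sketch of that conditional for a single latent dimension (the
# name and signature are ours, not part of the test suite; the prior mean
# function is omitted here and added separately, as in the tests).
import numpy as np


def sparse_gp_conditional(Kzz, Kzx, kxx_diag, u_mu, u_chol, white=True):
    """Mean and variance of f at N test points under q(u) = N(u_mu, u_chol @ u_chol.T).

    Kzz: (M, M) inducing covariance, Kzx: (M, N) cross-covariance,
    kxx_diag: (N,) prior marginal variances, u_mu: (M,), u_chol: (M, M).
    With white=True, u_mu and u_chol parameterise whitened inducing outputs,
    matching the `self.white = True` setting used in these tests.
    """
    if white:
        Lzz = np.linalg.cholesky(Kzz)
        A = np.linalg.solve(Lzz, Kzx)  # Lzz^{-1} Kzx, shape (M, N)
        var = kxx_diag + np.sum((u_chol.T @ A) ** 2, 0) - np.sum(A ** 2, 0)
    else:
        A = np.linalg.solve(Kzz, Kzx)  # Kzz^{-1} Kzx, shape (M, N)
        var = kxx_diag + np.sum(((u_chol @ u_chol.T - Kzz) @ A) * A, 0)
    mean = A.T @ u_mu
    return mean, var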
-------------------------------------------------------------------------------- /GPt/ssm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Alessandro Davide Ialongo (@ialong) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | from tensorflow_probability import distributions as tfd 19 | import gpflow as gp 20 | from gpflow import settings, params_as_tensors, autoflow 21 | from gpflow import transforms as gtf 22 | from gpflow.logdensities import mvn_logp, sum_mvn_logp, diag_mvn_logp 23 | from .transitions import GPTransitions 24 | from .emissions import GaussianEmissions 25 | from .utils import extract_cov_blocks 26 | from .KL import KL 27 | 28 | 29 | class SSM(gp.models.Model): 30 | """State-Space Model base class. Used for sampling, no built-in inference.""" 31 | def __init__(self, X_init, Y, transitions, inputs=None, emissions=None, px1_mu=None, px1_cov=None, name=None): 32 | super().__init__(name=name) 33 | self.T_latent, self.latent_dim = X_init.shape 34 | self.T, self.obs_dim = Y.shape 35 | 36 | self.transitions = transitions 37 | self.emissions = emissions or GaussianEmissions(self.latent_dim, self.obs_dim) 38 | 39 | self.X = gp.Param(X_init) 40 | self.Y = gp.Param(Y, trainable=False) 41 | self.inputs = None if inputs is None else gp.Param(inputs, trainable=False) 42 | 43 | self.px1_mu = gp.Param(np.zeros(self.latent_dim) if px1_mu is None else px1_mu, trainable=False) 44 | self.px1_cov_chol = gp.Param(np.eye(self.latent_dim) if px1_cov is None 45 | else np.linalg.cholesky(px1_cov), trainable=False, 46 | transform=gtf.LowerTriangular(self.latent_dim, squeeze=True)) 47 | 48 | @params_as_tensors 49 | def _build_likelihood(self): 50 | log_px1 = sum_mvn_logp((self.X[0] - self.px1_mu)[:, None], self.px1_cov_chol) 51 | inputs = self.Y[:-1] if self.transitions.OBSERVATIONS_AS_INPUT else self.inputs 52 | log_pX = tf.reduce_sum(self.transitions.logp(self.X, inputs)) 53 | log_pY = tf.reduce_sum(self.emissions.logp(self.X[:self.T], self.Y)) 54 | return log_px1 + log_pX + log_pY 55 | 56 | def sample(self, T, N=1, x0_samples=None, inputs=None): 57 | N = N if x0_samples is None else x0_samples.shape[0] 58 | T = T if x0_samples is None else T + 1 59 | X = np.zeros((N, T, self.latent_dim)) 60 | Y = np.zeros((N, T, self.obs_dim)) 61 | 62 | tr_sample_conditional = self.transitions.sample_conditional(N) 63 | 64 | if x0_samples is None: 65 | X[:, 0] = self.px1_mu.value + np.random.randn(N, self.latent_dim) @ self.px1_cov_chol.value.T 66 | else: 67 | X[:, 0] = x0_samples 68 | Y[:, 0] = self.emissions.sample_conditional(X[:, 0]) 69 | 70 | for t in range(T - 1): 71 | if self.transitions.OBSERVATIONS_AS_INPUT: 72 | input = Y[:, t] 73 | elif inputs is None: 74 | input = None if self.inputs is None else self.inputs.value[t] 75 | else: 76 | input = inputs[t] 77 | X[:, t + 1] = tr_sample_conditional(X[:, t], input=input) 78 | Y[:, t + 1] = 
self.emissions.sample_conditional(X[:, t + 1]) 79 | 80 | return X, Y 81 | 82 | 83 | class SSM_AG(SSM): 84 | """ 85 | Analytic inference Gaussian State-Space Model. The variational posterior over the states q(X) is Gaussian. 86 | The variational lower bound is computed and optimized in closed form. 87 | """ 88 | def __init__(self, latent_dim, Y, transitions, 89 | T_latent=None, inputs=None, emissions=None, 90 | px1_mu=None, px1_cov=None, Xmu=None, Xchol=None, name=None): 91 | 92 | _Xmu = np.zeros((T_latent or Y.shape[0], latent_dim)) if Xmu is None else Xmu 93 | super().__init__(_Xmu, Y, transitions, inputs, emissions, px1_mu, px1_cov, name=name) 94 | 95 | _Xchol = np.eye(self.T_latent * self.latent_dim) if Xchol is None else Xchol 96 | if _Xchol.ndim == 1: 97 | self.Xchol = gp.Param(_Xchol) 98 | else: 99 | chol_transform = gtf.LowerTriangular(self.T_latent * self.latent_dim if _Xchol.ndim == 2 100 | else self.latent_dim, 101 | num_matrices=1 if _Xchol.ndim == 2 else self.T_latent, 102 | squeeze=_Xchol.ndim == 2) 103 | self.Xchol = gp.Param(_Xchol, transform=chol_transform) 104 | 105 | @params_as_tensors 106 | def _build_likelihood(self): 107 | transitions = self._build_transition_expectations() 108 | emissions = self._build_emission_expectations() 109 | entropy = self._build_entropy() 110 | x1_cross_entropy = self._build_x1_cross_entropy() 111 | return transitions + emissions + entropy - x1_cross_entropy 112 | 113 | @params_as_tensors 114 | def _build_transition_expectations(self): 115 | inputs = self.Y[:-1] if self.transitions.OBSERVATIONS_AS_INPUT else self.inputs 116 | if self.Xchol.shape.ndims == 1: 117 | Xcov = tf.reshape(tf.square(self.Xchol), [self.T_latent, self.latent_dim]) 118 | elif self.Xchol.shape.ndims == 2: 119 | Xcov = extract_cov_blocks(self.Xchol, self.T_latent, self.latent_dim, return_off_diag_blocks=True) 120 | elif self.Xchol.shape.ndims == 3: 121 | Xcov = tf.matmul(self.Xchol, self.Xchol, transpose_b=True) 122 | return tf.reduce_sum(self.transitions.variational_expectations(self.X, Xcov, inputs)) 123 | 124 | @params_as_tensors 125 | def _build_emission_expectations(self): 126 | if self.Xchol.shape.ndims == 1: 127 | Xcov = tf.reshape(tf.square(self.Xchol[:self.T * self.latent_dim]), [self.T, self.latent_dim]) # TxD 128 | 129 | elif self.Xchol.shape.ndims == 2: 130 | Xcutoff = self.T * self.latent_dim 131 | if self.emissions.REQUIRE_FULL_COV: 132 | Xcov = extract_cov_blocks(self.Xchol[:Xcutoff, :Xcutoff], self.T, self.latent_dim) 133 | else: 134 | Xcov = tf.reshape(tf.reduce_sum( 135 | tf.square(self.Xchol[:Xcutoff, :Xcutoff]), 1), [self.T, self.latent_dim]) # TxD 136 | 137 | elif self.Xchol.shape.ndims == 3: 138 | if self.emissions.REQUIRE_FULL_COV: 139 | Xcov = tf.matmul(self.Xchol[:self.T], self.Xchol[:self.T], transpose_b=True) 140 | else: 141 | Xcov = tf.reduce_sum(tf.square(self.Xchol[:self.T]), 2) # TxD 142 | 143 | return tf.reduce_sum(self.emissions.variational_expectations(self.X[:self.T], Xcov, self.Y)) 144 | 145 | @params_as_tensors 146 | def _build_entropy(self): 147 | const = 0.5 * self.T_latent * self.latent_dim * (1. + np.log(2. 
* np.pi)) 148 | if self.Xchol.shape.ndims == 1: 149 | logdet = tf.reduce_sum(tf.log(tf.abs(self.Xchol))) 150 | else: 151 | logdet = tf.reduce_sum(tf.log(tf.abs(tf.matrix_diag_part(self.Xchol)))) 152 | return const + logdet 153 | 154 | @params_as_tensors 155 | def _build_x1_cross_entropy(self): 156 | logp = sum_mvn_logp((self.X[0] - self.px1_mu)[:, None], self.px1_cov_chol) 157 | if self.Xchol.shape.ndims == 1: 158 | qx1_cov_chol = tf.matrix_diag(self.Xchol[:self.latent_dim]) 159 | elif self.Xchol.shape.ndims == 2: 160 | qx1_cov_chol = self.Xchol[:self.latent_dim, :self.latent_dim] 161 | elif self.Xchol.shape.ndims == 3: 162 | qx1_cov_chol = self.Xchol[0] 163 | p_cov_inv_q_cov = tf.matrix_triangular_solve(self.px1_cov_chol, qx1_cov_chol, lower=True) 164 | trace = tf.reduce_sum(tf.square(p_cov_inv_q_cov)) 165 | return 0.5 * trace - logp 166 | 167 | # autoflow methods: 168 | 169 | @autoflow() 170 | def compute_transition_expectations(self): 171 | return self._build_transition_expectations() 172 | 173 | @autoflow() 174 | def compute_emission_expectations(self): 175 | return self._build_emission_expectations() 176 | 177 | @autoflow() 178 | def compute_entropy(self): 179 | return self._build_entropy() 180 | 181 | @autoflow() 182 | def compute_x1_cross_entropy(self): 183 | return self._build_x1_cross_entropy() 184 | 185 | 186 | class SSM_SG(SSM_AG): 187 | """ 188 | Stochastic inference Gaussian State-Space Model. The variational posterior over the states q(X) is Gaussian. 189 | The variational lower bound is evaluated and optimized by sampling from the posterior q(X). 190 | """ 191 | def __init__(self, latent_dim, Y, transitions, 192 | T_latent=None, inputs=None, emissions=None, 193 | px1_mu=None, px1_cov=None, Xmu=None, Xchol=None, 194 | n_samples=100, seed=None, name=None): 195 | super().__init__(latent_dim, Y, transitions, T_latent, inputs, emissions, 196 | px1_mu, px1_cov, Xmu, Xchol, name=name) 197 | self.n_samples = n_samples 198 | self.seed = seed 199 | self._qx = None 200 | 201 | @property 202 | def qx(self): 203 | if self._qx is None: 204 | if self.Xchol.shape.ndims == 1: 205 | self._qx = tfd.MultivariateNormalDiag( 206 | loc=tf.reshape(self.X, [-1]), scale_diag=self.Xchol) 207 | else: 208 | self._qx = tfd.MultivariateNormalTriL( 209 | loc=self.X if self.Xchol.shape.ndims == 3 else tf.reshape(self.X, [-1]), 210 | scale_tril=self.Xchol) 211 | return self._qx 212 | 213 | @params_as_tensors 214 | def _build_likelihood(self): 215 | qx_samples = self._build_sample_qx() 216 | transitions = self._build_transition_expectations(qx_samples) 217 | emissions = self._build_emission_expectations(qx_samples) 218 | entropy = super()._build_entropy() 219 | x1_cross_entropy = super()._build_x1_cross_entropy() 220 | return transitions + emissions + entropy - x1_cross_entropy 221 | 222 | @params_as_tensors 223 | def _build_sample_qx(self, n_samples=None): 224 | qx_samples = self.qx.sample(n_samples or self.n_samples, seed=self.seed) 225 | if self.Xchol.shape.ndims < 3: 226 | return tf.reshape(qx_samples, [-1, self.T_latent, self.latent_dim]) 227 | return qx_samples 228 | 229 | @params_as_tensors 230 | def _build_transition_expectations(self, qx_samples): 231 | inputs = self.Y[:-1] if self.transitions.OBSERVATIONS_AS_INPUT else self.inputs 232 | logp = self.transitions.logp(qx_samples, inputs) 233 | return tf.reduce_mean(tf.reduce_sum(logp, 1)) 234 | 235 | @params_as_tensors 236 | def _build_emission_expectations(self, qx_samples): 237 | logp = self.emissions.logp(qx_samples[:, :self.T], self.Y) 238 | return 
tf.reduce_mean(tf.reduce_sum(logp, 1)) 239 | 240 | @params_as_tensors 241 | def _build_stochastic_entropy(self, qx_samples): 242 | return - tf.reduce_mean(self._build_density_evaluation(qx_samples)) 243 | 244 | @params_as_tensors 245 | def _build_stochastic_x1_cross_entropy(self, qx1_samples): 246 | return - tf.reduce_mean(mvn_logp( 247 | tf.transpose(qx1_samples - self.px1_mu), self.px1_cov_chol)) 248 | 249 | @params_as_tensors 250 | def _build_density_evaluation(self, qx_samples): 251 | if self.Xchol.shape.ndims < 3: 252 | return self.qx.log_prob( 253 | tf.reshape(qx_samples, [-1, self.T_latent * self.latent_dim])) 254 | return tf.reduce_sum(self.qx.log_prob(qx_samples), -1) 255 | 256 | # autoflow methods: 257 | 258 | @autoflow((settings.int_type, [])) 259 | def sample_qx(self, n_samples=None): 260 | return self._build_sample_qx(n_samples or self.n_samples) 261 | 262 | @autoflow((settings.float_type,)) 263 | def evaluate_sample_density(self, qx_samples): 264 | qx_samples = tf.reshape(qx_samples, [-1, self.T_latent, self.latent_dim]) 265 | return self._build_density_evaluation(qx_samples) 266 | 267 | @autoflow() 268 | def compute_transition_expectations(self): 269 | qx_samples = self._build_sample_qx(self.n_samples) 270 | return self._build_transition_expectations(qx_samples) 271 | 272 | @autoflow() 273 | def compute_emission_expectations(self): 274 | qx_samples = self._build_sample_qx(self.n_samples) 275 | return self._build_emission_expectations(qx_samples) 276 | 277 | @autoflow() 278 | def compute_stochastic_entropy(self): 279 | qx_samples = self._build_sample_qx(self.n_samples) 280 | return self._build_stochastic_entropy(qx_samples) 281 | 282 | @autoflow() 283 | def compute_stochastic_x1_cross_entropy(self): 284 | qx_samples = self._build_sample_qx(self.n_samples) 285 | return self._build_stochastic_x1_cross_entropy(qx_samples[:, 0]) 286 | 287 | @autoflow((settings.float_type,)) 288 | def compute_variational_bound_from_samples(self, qx_samples): 289 | qx_samples = tf.reshape(qx_samples, [-1, self.T_latent, self.latent_dim]) 290 | transitions = self._build_transition_expectations(qx_samples) 291 | emissions = self._build_emission_expectations(qx_samples) 292 | entropy = super()._build_entropy() 293 | x1_cross_entropy = super()._build_x1_cross_entropy() 294 | return (transitions, emissions, entropy, -x1_cross_entropy) 295 | 296 | @autoflow((settings.float_type,)) 297 | def compute_entropy_from_samples(self, qx_samples): 298 | qx_samples = tf.reshape(qx_samples, [-1, self.T_latent, self.latent_dim]) 299 | return - tf.reduce_mean(self._build_density_evaluation(qx_samples)) 300 | 301 | 302 | class SSM_SG_MultipleSequences(SSM_SG): 303 | """Equivalent to SSM_SG but for data which comes as many (potentially variable-length) independent sequences.""" 304 | def __init__(self, latent_dim, Y, transitions, 305 | T_latent=None, inputs=None, emissions=None, 306 | px1_mu=None, px1_cov=None, Xmu=None, Xchol=None, 307 | n_samples=100, batch_size=None, seed=None, name=None): 308 | 309 | super().__init__(latent_dim, Y[0], transitions, 310 | T_latent=None, inputs=None, emissions=emissions, 311 | px1_mu=px1_mu, px1_cov=None, Xmu=None, Xchol=None, 312 | n_samples=n_samples, seed=seed, name=name) 313 | 314 | self.T = [Y_s.shape[0] for Y_s in Y] 315 | self.T_latent = T_latent or self.T 316 | self.n_seq = len(self.T) 317 | self.T_tf = tf.constant(self.T, dtype=gp.settings.int_type) 318 | self.T_latent_tf = tf.constant(self.T_latent, dtype=gp.settings.int_type) 319 | self.sum_T = float(sum(self.T)) 320 | 
self.sum_T_latent = float(sum(self.T_latent)) 321 | self.batch_size = batch_size 322 | 323 | self.Y = gp.ParamList(Y, trainable=False) 324 | 325 | self.inputs = None if inputs is None else gp.ParamList(inputs, trainable=False) 326 | 327 | _Xmu = [np.zeros((T_s, self.latent_dim)) for T_s in self.T_latent] if Xmu is None else Xmu 328 | self.X = gp.ParamList(_Xmu) 329 | 330 | _Xchol = [np.eye(T_s * self.latent_dim) for T_s in self.T_latent] if Xchol is None else Xchol 331 | xc_tr = lambda xc: None if xc.ndim == 1 else gtf.LowerTriangular( 332 | xc.shape[-1], num_matrices=1 if xc.ndim == 2 else xc.shape[0], squeeze=xc.ndim == 2) 333 | self.Xchol = gp.ParamList([gp.Param(xc, transform=xc_tr(xc)) for xc in _Xchol]) 334 | 335 | self.multi_diag_px1_cov = False 336 | if isinstance(px1_cov, list): # different prior for each sequence 337 | _x1_cov = np.stack(px1_cov) 338 | _x1_cov = np.sqrt(_x1_cov) if _x1_cov.ndim == 2 else np.linalg.cholesky(_x1_cov) 339 | _transform = None if _x1_cov.ndim == 2 else gtf.LowerTriangular(self.latent_dim, num_matrices=self.n_seq) 340 | self.multi_diag_px1_cov = _x1_cov.ndim == 2 341 | elif isinstance(px1_cov, np.ndarray): # same prior for each sequence 342 | assert px1_cov.ndim < 3 343 | _x1_cov = np.sqrt(px1_cov) if px1_cov.ndim == 1 else np.linalg.cholesky(px1_cov) 344 | _transform = None if px1_cov.ndim == 1 else gtf.LowerTriangular(self.latent_dim, squeeze=True) 345 | else: 346 | _x1_cov = np.eye(self.latent_dim) 347 | _transform = gtf.LowerTriangular(self.latent_dim, squeeze=True) 348 | 349 | self.px1_cov_chol = gp.Param(_x1_cov, trainable=False, transform=_transform) 350 | 351 | @property 352 | def qx(self): 353 | if self._qx is None: 354 | self._qx = [] 355 | for s in range(self.n_seq): 356 | if self.Xchol[s].shape.ndims == 1: 357 | self._qx.append(tfd.MultivariateNormalDiag( 358 | loc=tf.reshape(self.X[s], [-1]), scale_diag=self.Xchol[s])) 359 | else: 360 | self._qx.append(tfd.MultivariateNormalTriL( 361 | loc=self.X[s] if self.Xchol[s].shape.ndims == 3 else tf.reshape(self.X[s], [-1]), 362 | scale_tril=self.Xchol[s])) 363 | return self._qx 364 | 365 | @params_as_tensors 366 | def _build_likelihood(self): 367 | batch_indices = None if self.batch_size is None else \ 368 | tf.random_shuffle(tf.range(self.n_seq), seed=self.seed)[:self.batch_size] 369 | 370 | qx_samples = self._build_sample_qx(batch_indices=batch_indices) 371 | 372 | transitions = self._build_transition_expectations(qx_samples, batch_indices=batch_indices) 373 | emissions = self._build_emission_expectations(qx_samples, batch_indices=batch_indices) 374 | entropy = self._build_entropy(batch_indices=batch_indices) 375 | x1_cross_entropy = self._build_x1_cross_entropy(batch_indices=batch_indices) 376 | return transitions + emissions + entropy - x1_cross_entropy 377 | 378 | @params_as_tensors 379 | def _build_sample_qx(self, n_samples=None, batch_indices=None): 380 | if n_samples is None: n_samples = self.n_samples 381 | qx_samples = [] 382 | for s in range(self.n_seq if batch_indices is None else self.batch_size): 383 | b_s = s if batch_indices is None else batch_indices[s] 384 | list_of_samples = [self.qx[i].sample(n_samples, seed=self.seed) for i in range(self.n_seq)] 385 | qx_s = self.gather_from_list(list_of_samples, b_s) 386 | if self.gather_from_list(self.Xchol, b_s).shape.ndims < 3: 387 | qx_s = tf.reshape(qx_s, [-1, self.T_latent_tf[b_s], self.latent_dim]) 388 | qx_samples.append(qx_s) 389 | return qx_samples 390 | 391 | @params_as_tensors 392 | def _build_transition_expectations(self, 
qx_samples, batch_indices=None): 393 | logp_kwargs = {'subtract_KL_U': False} if isinstance(self.transitions, GPTransitions) else {} 394 | 395 | tr_expectations = 0. 396 | for s in range(self.n_seq if batch_indices is None else self.batch_size): 397 | b_s = s if batch_indices is None else batch_indices[s] 398 | inputs = self.gather_from_list(self.Y, b_s)[:-1] if self.transitions.OBSERVATIONS_AS_INPUT \ 399 | else (None if self.inputs is None else self.gather_from_list(self.inputs, b_s)) 400 | logp = self.transitions.logp(qx_samples[s], inputs, **logp_kwargs) 401 | tr_expectations += tf.reduce_mean(tf.reduce_sum(logp, 1)) 402 | 403 | if batch_indices is not None: 404 | sum_T_l_batch = tf.cast(tf.reduce_sum(tf.gather(self.T_latent_tf, batch_indices)), gp.settings.float_type) 405 | tr_expectations *= (self.sum_T_latent - self.n_seq) / (sum_T_l_batch - self.batch_size) 406 | 407 | if isinstance(self.transitions, GPTransitions): 408 | KL_U = KL(self.transitions.Umu, self.transitions.Ucov_chol) 409 | tr_expectations -= KL_U 410 | return tr_expectations 411 | 412 | @params_as_tensors 413 | def _build_emission_expectations(self, qx_samples, batch_indices=None): 414 | em_expectations = 0. 415 | for s in range(self.n_seq if batch_indices is None else self.batch_size): 416 | b_s = s if batch_indices is None else batch_indices[s] 417 | logp = self.emissions.logp(qx_samples[s][:, :self.T_tf[b_s]], self.gather_from_list(self.Y, b_s)) 418 | em_expectations += tf.reduce_mean(tf.reduce_sum(logp, 1)) 419 | 420 | if batch_indices is not None: 421 | sum_T_batch = tf.cast(tf.reduce_sum(tf.gather(self.T_tf, batch_indices)), gp.settings.float_type) 422 | em_expectations *= self.sum_T / sum_T_batch 423 | return em_expectations 424 | 425 | @params_as_tensors 426 | def _build_entropy(self, batch_indices=None): 427 | entropy = 0. 428 | for s in range(self.n_seq if batch_indices is None else self.batch_size): 429 | b_s = s if batch_indices is None else batch_indices[s] 430 | T_latent_b_s = tf.cast(self.T_latent_tf[b_s], gp.settings.float_type) 431 | const = 0.5 * T_latent_b_s * self.latent_dim * (1. + np.log(2. * np.pi)) 432 | _Xchol = self.gather_from_list(self.Xchol, b_s) 433 | if _Xchol.shape.ndims == 1: 434 | logdet = tf.reduce_sum(tf.log(tf.abs(_Xchol))) 435 | else: 436 | logdet = tf.reduce_sum(tf.log(tf.abs(tf.matrix_diag_part(_Xchol)))) 437 | entropy += const + logdet 438 | 439 | if batch_indices is not None: 440 | sum_T_l_batch = tf.cast(tf.reduce_sum(tf.gather(self.T_latent_tf, batch_indices)), gp.settings.float_type) 441 | entropy *= self.sum_T_latent / sum_T_l_batch 442 | return entropy 443 | 444 | @params_as_tensors 445 | def _build_x1_cross_entropy(self, batch_indices=None): 446 | diag_px1 = self.px1_cov_chol.shape.ndims == 1 or self.multi_diag_px1_cov 447 | shared_px1 = (self.px1_cov_chol.shape.ndims < 3) and (not self.multi_diag_px1_cov) 448 | 449 | x1_ce = 0. 
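# Each iteration of the loop below adds the Gaussian cross-entropy -E_{q(x1)}[log p(x1)]
# for one sequence: 0.5 * trace(P^{-1} Sigma_q1) - log N(q1_mu | p1_mu, P), using either
# the diagonal or the full-covariance form of the prior p(x1).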
450 | for s in range(self.n_seq if batch_indices is None else self.batch_size): 451 | b_s = s if batch_indices is None else batch_indices[s] 452 | _px1_mu = self.px1_mu if self.px1_mu.shape.ndims == 1 else self.px1_mu[b_s] 453 | _px1_cov_chol = self.px1_cov_chol if shared_px1 else self.px1_cov_chol[b_s] 454 | _qx1_mu = self.gather_from_list(self.X, b_s)[0] 455 | _qx1_cov_chol = self.gather_from_list(self.Xchol, b_s) 456 | assert _qx1_cov_chol.shape.ndims in {1, 2, 3} 457 | if _qx1_cov_chol.shape.ndims == 1: 458 | _qx1_cov_chol = _qx1_cov_chol[:self.latent_dim] 459 | _qx1_cov_chol = _qx1_cov_chol[:, None] if diag_px1 else tf.matrix_diag(_qx1_cov_chol) 460 | elif _qx1_cov_chol.shape.ndims == 2: 461 | _qx1_cov_chol = _qx1_cov_chol[:self.latent_dim, :self.latent_dim] 462 | elif _qx1_cov_chol.shape.ndims == 3: 463 | _qx1_cov_chol = _qx1_cov_chol[0] 464 | 465 | if diag_px1: 466 | logp = diag_mvn_logp(_qx1_mu - _px1_mu, _px1_cov_chol) 467 | trace = tf.reduce_sum(tf.square(_qx1_cov_chol / _px1_cov_chol[:, None])) 468 | else: 469 | logp = sum_mvn_logp((_qx1_mu - _px1_mu)[:, None], _px1_cov_chol) 470 | trace = tf.reduce_sum(tf.square( 471 | tf.matrix_triangular_solve(_px1_cov_chol, _qx1_cov_chol, lower=True))) 472 | x1_ce += 0.5 * trace - logp 473 | 474 | if batch_indices is not None: 475 | x1_ce *= float(self.n_seq) / float(self.batch_size) 476 | return x1_ce 477 | 478 | @params_as_tensors 479 | def _build_stochastic_entropy(self, qx_samples, batch_indices=None): 480 | entropy = - tf.reduce_sum(tf.reduce_mean(tf.stack(self._build_density_evaluation(qx_samples)), -1)) 481 | if batch_indices is not None: 482 | sum_T_l_batch = tf.cast(tf.reduce_sum(tf.gather(self.T_latent_tf, batch_indices)), gp.settings.float_type) 483 | entropy *= self.sum_T_latent / sum_T_l_batch 484 | return entropy 485 | 486 | @params_as_tensors 487 | def _build_stochastic_x1_cross_entropy(self, qx1_samples, batch_indices=None): 488 | diag_px1 = self.px1_cov_chol.shape.ndims == 1 or self.multi_diag_px1_cov 489 | if self.multi_diag_px1_cov or self.px1_cov_chol.shape.ndims == 3: 490 | x1_ce = 0. 
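# Monte-Carlo counterpart of the analytic x1 cross-entropy above: log p(x1) is evaluated
# at samples from q(x1) and averaged, one sequence at a time when each sequence has its
# own prior; the method returns the negated (batch-rescaled) sum.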
491 | for s in range(self.n_seq if batch_indices is None else self.batch_size): 492 | b_s = s if batch_indices is None else batch_indices[s] 493 | _px1_mu = self.px1_mu if self.px1_mu.shape.ndims == 1 else self.px1_mu[b_s] 494 | if diag_px1: 495 | _x1_ce = diag_mvn_logp(qx1_samples[s] - _px1_mu, self.px1_cov_chol[b_s]) 496 | else: 497 | _x1_ce = mvn_logp(tf.transpose(qx1_samples[s] - _px1_mu), self.px1_cov_chol[b_s]) 498 | x1_ce += tf.reduce_mean(_x1_ce) 499 | else: 500 | _px1_mu = self.px1_mu if self.px1_mu.shape.ndims == 1 else self.px1_mu[:, None, :] 501 | if diag_px1: 502 | x1_ce = diag_mvn_logp(qx1_samples - _px1_mu, self.px1_cov_chol) 503 | else: 504 | x1_ce = mvn_logp(tf.transpose(qx1_samples - _px1_mu, [2, 0, 1]), self.px1_cov_chol) 505 | x1_ce = tf.reduce_sum(tf.reduce_mean(x1_ce, -1)) 506 | 507 | if batch_indices is not None: 508 | x1_ce *= float(self.n_seq) / float(self.batch_size) 509 | return - x1_ce 510 | 511 | @params_as_tensors 512 | def _build_density_evaluation(self, qx_samples, batch_indices=None): 513 | densities = [] 514 | for s in range(self.n_seq if batch_indices is None else self.batch_size): 515 | b_s = s if batch_indices is None else batch_indices[s] 516 | if self.gather_from_list(self.Xchol, b_s).shape.ndims < 3: 517 | reshaped_samples = tf.reshape(qx_samples[s], [-1, self.T_latent[s] * self.latent_dim]) 518 | list_of_logp = [self.qx[i].log_prob(reshaped_samples) for i in range(self.n_seq)] 519 | densities.append(self.gather_from_list(list_of_logp, b_s)) 520 | else: 521 | list_of_logp = [self.qx[i].log_prob(qx_samples[s]) for i in range(self.n_seq)] 522 | densities.append(tf.reduce_sum(self.gather_from_list(list_of_logp, b_s), -1)) 523 | return densities 524 | 525 | def gather_from_list(self, obj_list, index): 526 | """ 527 | Warning: if index is not within range it returns first element of obj_list 528 | """ 529 | if isinstance(index, int): 530 | return obj_list[index] 531 | 532 | s_getter = lambda s: lambda: obj_list[s] 533 | recursive_getter = obj_list[0] 534 | for s in range(1, len(obj_list)): 535 | recursive_getter = tf.cond(tf.equal(index, s), s_getter(s), lambda: recursive_getter) 536 | return recursive_getter 537 | 538 | # autoflow methods: 539 | 540 | @autoflow() 541 | def compute_transition_expectations(self): 542 | raise NotImplementedError 543 | 544 | @autoflow() 545 | def compute_emission_expectations(self): 546 | raise NotImplementedError 547 | 548 | @autoflow() 549 | def compute_entropy(self): 550 | raise NotImplementedError 551 | 552 | @autoflow() 553 | def compute_x1_cross_entropy(self): 554 | raise NotImplementedError 555 | 556 | @autoflow((settings.int_type, [])) 557 | def sample_qx(self, n_samples=None): 558 | raise NotImplementedError 559 | 560 | @autoflow((settings.float_type,)) 561 | def evaluate_sample_density(self, qx_samples): 562 | raise NotImplementedError 563 | 564 | @autoflow() 565 | def compute_transition_expectations(self): 566 | raise NotImplementedError 567 | 568 | @autoflow() 569 | def compute_emission_expectations(self): 570 | raise NotImplementedError 571 | 572 | @autoflow() 573 | def compute_stochastic_entropy(self): 574 | raise NotImplementedError 575 | 576 | @autoflow() 577 | def compute_stochastic_x1_cross_entropy(self): 578 | raise NotImplementedError 579 | 580 | @autoflow((settings.float_type,)) 581 | def compute_variational_bound_from_samples(self, qx_samples): 582 | raise NotImplementedError 583 | 584 | @autoflow((settings.float_type,)) 585 | def compute_entropy_from_samples(self, qx_samples): 586 | raise 
NotImplementedError 587 | -------------------------------------------------------------------------------- /GPt/gpssm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 Alessandro Davide Ialongo (@ialong) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import numpy as np 17 | import tensorflow as tf 18 | import gpflow as gp 19 | import pandas as pd 20 | 21 | from tensorflow_probability import distributions as tfd 22 | 23 | from gpflow import Param, params_as_tensors 24 | from gpflow import transforms as gtf 25 | from gpflow import mean_functions as mean_fns 26 | from gpflow.conditionals import conditional 27 | from gpflow.multioutput.features import Kuu, Kuf 28 | from gpflow import settings as gps 29 | import gpflow.multioutput.kernels as mk 30 | import gpflow.multioutput.features as mf 31 | 32 | from .KL import KL, KL_samples 33 | from .emissions import GaussianEmissions 34 | 35 | 36 | class GPSSM(gp.models.Model): 37 | """Gaussian Process State-Space Model base class. Used for sampling, no built-in inference.""" 38 | def __init__(self, 39 | latent_dim, 40 | Y, 41 | inputs=None, 42 | emissions=None, 43 | px1_mu=None, px1_cov=None, 44 | kern=None, 45 | Z=None, n_ind_pts=100, 46 | mean_fn=None, 47 | Q_diag=None, 48 | Umu=None, Ucov_chol=None, 49 | qx1_mu=None, qx1_cov=None, 50 | As=None, bs=None, Ss=None, 51 | n_samples=100, 52 | seed=None, 53 | parallel_iterations=10, 54 | jitter=gps.numerics.jitter_level, 55 | name=None): 56 | 57 | super().__init__(name=name) 58 | 59 | self.latent_dim = latent_dim 60 | self.T, self.obs_dim = Y.shape 61 | self.Y = Param(Y, trainable=False) 62 | 63 | self.inputs = None if inputs is None else Param(inputs, trainable=False) 64 | self.input_dim = 0 if self.inputs is None else self.inputs.shape[1] 65 | 66 | self.qx1_mu = Param(np.zeros(self.latent_dim) if qx1_mu is None else qx1_mu) 67 | self.qx1_cov_chol = Param( 68 | np.eye(self.latent_dim) if qx1_cov is None else np.linalg.cholesky(qx1_cov), 69 | transform=gtf.LowerTriangular(self.latent_dim, squeeze=True)) 70 | 71 | self.As = Param(np.ones((self.T - 1, self.latent_dim)) if As is None else As) 72 | self.bs = Param(np.zeros((self.T - 1, self.latent_dim)) if bs is None else bs) 73 | 74 | self.Q_sqrt = Param(np.ones(self.latent_dim) if Q_diag is None else Q_diag ** 0.5, transform=gtf.positive) 75 | if Ss is False: 76 | self._S_chols = None 77 | else: 78 | self.S_chols = Param(np.tile(self.Q_sqrt.value.copy()[None, ...], [self.T - 1, 1]) if Ss is None 79 | else (np.sqrt(Ss) if Ss.ndim == 2 else np.linalg.cholesky(Ss)), 80 | transform=gtf.positive if (Ss is None or Ss.ndim == 2) 81 | else gtf.LowerTriangular(self.latent_dim, num_matrices=self.T - 1, squeeze=False)) 82 | 83 | self.emissions = emissions or GaussianEmissions(latent_dim=self.latent_dim, obs_dim=self.obs_dim) 84 | 85 | self.px1_mu = Param(np.zeros(self.latent_dim) if px1_mu is None else px1_mu, trainable=False) 86 | self.px1_cov_chol = None if 
px1_cov is None else \ 87 | Param(np.sqrt(px1_cov) if px1_cov.ndim == 1 else np.linalg.cholesky(px1_cov), trainable=False, 88 | transform=gtf.positive if px1_cov.ndim == 1 else gtf.LowerTriangular(self.latent_dim, squeeze=True)) 89 | 90 | self.n_samples = n_samples 91 | self.seed = seed 92 | self.parallel_iterations = parallel_iterations 93 | self.jitter = jitter 94 | 95 | # Inference-specific attributes (see gpssm_models.py for appropriate choices): 96 | nans = tf.constant(np.zeros((self.T, self.n_samples, self.latent_dim)) * np.nan, dtype=gps.float_type) 97 | self.sample_fn = lambda **kwargs: (nans, None) 98 | self.sample_kwargs = {} 99 | self.KL_fn = lambda *fs: tf.constant(np.nan, dtype=gps.float_type) 100 | 101 | # GP Transitions: 102 | self.n_ind_pts = n_ind_pts if Z is None else (Z[0].shape[-2] if isinstance(Z, list) else Z.shape[-2]) 103 | 104 | if isinstance(Z, np.ndarray) and Z.ndim == 2: 105 | self.Z = mf.SharedIndependentMof(gp.features.InducingPoints(Z)) 106 | else: 107 | Z_list = [np.random.randn(self.n_ind_pts, self.latent_dim + self.input_dim) 108 | for _ in range(self.latent_dim)] if Z is None else [z for z in Z] 109 | self.Z = mf.SeparateIndependentMof([gp.features.InducingPoints(z) for z in Z_list]) 110 | 111 | if isinstance(kern, gp.kernels.Kernel): 112 | self.kern = mk.SharedIndependentMok(kern, self.latent_dim) 113 | else: 114 | kern_list = kern or [gp.kernels.Matern32(self.latent_dim + self.input_dim, ARD=True) 115 | for _ in range(self.latent_dim)] 116 | self.kern = mk.SeparateIndependentMok(kern_list) 117 | 118 | self.mean_fn = mean_fn or mean_fns.Identity(self.latent_dim) 119 | self.Umu = Param(np.zeros((self.latent_dim, self.n_ind_pts)) if Umu is None else Umu) # (Lm^-1)(Umu - m(Z)) 120 | LT_transform = gtf.LowerTriangular(self.n_ind_pts, num_matrices=self.latent_dim, squeeze=False) 121 | self.Ucov_chol = Param(np.tile(np.eye(self.n_ind_pts)[None, ...], [self.latent_dim, 1, 1]) 122 | if Ucov_chol is None else Ucov_chol, transform=LT_transform) # (Lm^-1)Lu 123 | self._Kzz = None 124 | 125 | @property 126 | def Kzz(self): 127 | if self._Kzz is None: 128 | self._Kzz = Kuu(self.Z, self.kern, jitter=self.jitter) # (latent_dim x) M x M 129 | return self._Kzz 130 | 131 | @params_as_tensors 132 | def _build_likelihood(self): 133 | X_samples, *fs = self._build_sample() 134 | emissions = self._build_emissions(X_samples) 135 | KL_X = self._build_KL_X(fs) 136 | KL_U = self._build_KL_U() 137 | KL_x1 = self._build_KL_x1() 138 | return emissions - KL_X - KL_U - KL_x1 139 | 140 | @params_as_tensors 141 | def _build_sample(self): 142 | return self.sample_fn(**self.sample_kwargs) 143 | 144 | @params_as_tensors 145 | def _build_emissions(self, X_samples): 146 | emissions = self.emissions.logp(X_samples, self.Y[:, None, :]) # T x n_samples 147 | return tf.reduce_sum(tf.reduce_mean(emissions, -1)) 148 | 149 | @params_as_tensors 150 | def _build_KL_X(self, fs): 151 | return tf.reduce_sum(self.KL_fn(*fs)) 152 | 153 | @params_as_tensors 154 | def _build_KL_U(self): 155 | return KL(self.Umu, self.Ucov_chol) 156 | 157 | @params_as_tensors 158 | def _build_KL_x1(self): 159 | return KL(self.qx1_mu - self.px1_mu, self.qx1_cov_chol, P_chol=self.px1_cov_chol) 160 | 161 | @params_as_tensors 162 | def _build_transition_KLs(self, f_mus, f_vars, As=None, bs=None, S_chols=None): 163 | As = self.As if As is None else As 164 | bs = self.bs if bs is None else bs 165 | S_chols = self.S_chols if S_chols is None else S_chols 166 | 167 | const = tf.reduce_sum(tf.log(tf.square(self.Q_sqrt))) - 
self.latent_dim 168 | 169 | if As.shape.ndims == 2: 170 | mahalanobis = (As - 1.)[:, None, :] * f_mus 171 | else: 172 | mahalanobis = tf.matmul(f_mus, As - tf.eye(self.latent_dim, dtype=gps.float_type), 173 | transpose_b=True) 174 | mahalanobis += bs[:, None, :] # (T-1) x n_samples x latent_dim 175 | mahalanobis = tf.reduce_mean(tf.reduce_sum(tf.square(mahalanobis / self.Q_sqrt), -1), -1) # T - 1 176 | 177 | mean_f_var = tf.reduce_mean(f_vars, 1) 178 | 179 | if (S_chols.shape.ndims == 2) and (As.shape.ndims == 2): 180 | trace = tf.square(S_chols) + mean_f_var * tf.square(As - 1.) 181 | elif As.shape.ndims == 2: 182 | trace = tf.reduce_sum(tf.square(S_chols), -1) + mean_f_var * tf.square(As - 1.) 183 | elif S_chols.shape.ndims == 2: 184 | trace = tf.square(S_chols) + tf.reduce_sum( 185 | mean_f_var[:, None, :] * tf.square(As - tf.eye(self.latent_dim, dtype=gps.float_type)), -1) 186 | else: 187 | trace = tf.reduce_sum(tf.square(S_chols), -1) + tf.reduce_sum( 188 | mean_f_var[:, None, :] * tf.square(As - tf.eye(self.latent_dim, dtype=gps.float_type)), -1) 189 | 190 | trace = tf.reduce_sum(trace / tf.square(self.Q_sqrt), -1) # T - 1 191 | 192 | log_det_S = 2. * tf.reduce_sum(tf.log(tf.abs(S_chols if S_chols.shape.ndims == 2 193 | else tf.matrix_diag_part(S_chols))), -1) # T - 1 194 | 195 | return 0.5 * (const + mahalanobis + trace - log_det_S) # T - 1 196 | 197 | @params_as_tensors 198 | def _build_factorized_transition_KLs(self, f_mus, f_vars, x_cov_chols, As=None, bs=None, S_chols=None): 199 | As = self.As if As is None else As 200 | bs = self.bs if bs is None else bs 201 | S_chols = self.S_chols if S_chols is None else S_chols 202 | 203 | const = tf.reduce_sum(tf.log(tf.square(self.Q_sqrt))) - self.latent_dim 204 | 205 | if As.shape.ndims == 2: 206 | mahalanobis = (As - 1.)[:, None, :] * f_mus 207 | else: 208 | mahalanobis = tf.matmul(f_mus, As - tf.eye(self.latent_dim, dtype=gps.float_type), 209 | transpose_b=True) 210 | mahalanobis += bs[:, None, :] # (T-1) x n_samples x latent_dim 211 | mahalanobis = tf.reduce_mean(tf.reduce_sum(tf.square(mahalanobis / self.Q_sqrt), -1), -1) # T - 1 212 | 213 | is_diag_xcov_chol = (S_chols.shape.ndims == 2) and (As.shape.ndims == 2) 214 | 215 | if is_diag_xcov_chol: 216 | trace = f_vars + tf.square(x_cov_chols) 217 | else: 218 | trace = f_vars + tf.reduce_sum(tf.square(x_cov_chols), -1) 219 | trace = tf.reduce_mean(tf.reduce_sum(trace / tf.square(self.Q_sqrt), -1), -1) # T - 1 220 | 221 | log_det_x_covs = 2. 
* tf.reduce_mean(tf.reduce_sum(tf.log(tf.abs( 222 | x_cov_chols if is_diag_xcov_chol else tf.matrix_diag_part(x_cov_chols))), -1), -1) # T - 1 223 | 224 | return 0.5 * (const + mahalanobis + trace - log_det_x_covs) # T - 1 225 | 226 | @params_as_tensors 227 | def _build_transition_KLs_from_samples(self, F_samples, As=None, bs=None, S_chols=None): 228 | As = self.As if As is None else As 229 | bs = self.bs if bs is None else bs 230 | S_chols = self.S_chols if S_chols is None else S_chols 231 | 232 | if As.shape.ndims == 2: 233 | mu_diff = (As - 1.)[:, None, :] * F_samples 234 | else: 235 | mu_diff = tf.matmul(F_samples, As - tf.eye(self.latent_dim, dtype=gps.float_type), 236 | transpose_b=True) 237 | mu_diff += bs[:, None, :] # (T-1) x n_samples x latent_dim 238 | return KL_samples(mu_diff, S_chols, P_chol=self.Q_sqrt) 239 | 240 | @params_as_tensors 241 | def _build_linear_time_q_sample(self, return_f_moments=False, return_x_cov_chols=False, 242 | sample_f=False, sample_u=True, return_u=False, 243 | T=None, inputs=None, qx1_mu=None, qx1_cov_chol=None, x1_samples=None, 244 | As=None, bs=None, S_chols=None, Lm=None): 245 | T = self.T if T is None else T 246 | inputs = self.inputs if inputs is None else inputs 247 | qx1_mu = self.qx1_mu if qx1_mu is None else qx1_mu 248 | qx1_cov_chol = self.qx1_cov_chol if qx1_cov_chol is None else qx1_cov_chol 249 | As = self.As if As is None else As 250 | bs = self.bs if bs is None else bs 251 | S_chols = self.S_chols if S_chols is None else S_chols 252 | n_samples = self.n_samples if x1_samples is None else int(x1_samples.shape[0]) 253 | n_mean_inputs = self.mean_fn.input_dim if hasattr(self.mean_fn, "input_dim") else self.latent_dim 254 | differentiate = x1_samples is None 255 | 256 | Lm = tf.cholesky(self.Kzz) if Lm is None else Lm 257 | 258 | X_samples = tf.TensorArray(size=T, dtype=gps.float_type, clear_after_read=False, 259 | infer_shape=False, element_shape=(n_samples, self.latent_dim)) 260 | if sample_f: 261 | F_samples = tf.TensorArray(size=T - 1, dtype=gps.float_type, clear_after_read=False, 262 | infer_shape=False, element_shape=(n_samples, self.latent_dim)) 263 | if return_f_moments: 264 | f_mus = tf.TensorArray(size=T - 1, dtype=gps.float_type, clear_after_read=False, 265 | infer_shape=False, element_shape=(n_samples, self.latent_dim)) 266 | f_vars = tf.TensorArray(size=T - 1, dtype=gps.float_type, clear_after_read=False, 267 | infer_shape=False, element_shape=(n_samples, self.latent_dim)) 268 | 269 | is_diag_xcov = (S_chols.shape.ndims == 2) if sample_f else \ 270 | ((S_chols.shape.ndims == 2) and (As.shape.ndims == 2)) 271 | 272 | if return_x_cov_chols: 273 | x_cov_chols = tf.TensorArray(size=T - 1, dtype=gps.float_type, clear_after_read=False, 274 | infer_shape=False, element_shape= 275 | (n_samples, self.latent_dim) if is_diag_xcov 276 | else (n_samples, self.latent_dim, self.latent_dim)) 277 | if sample_u: 278 | U_samples = tfd.MultivariateNormalDiag(loc=tf.zeros( 279 | (self.latent_dim, self.n_ind_pts, n_samples), dtype=gps.float_type)) 280 | U_samples = U_samples.sample(seed=self.seed) 281 | U_samples = self.Umu[:, :, None] + tf.matmul(self.Ucov_chol, U_samples) 282 | 283 | white_samples = tfd.MultivariateNormalDiag(loc=tf.zeros( 284 | (n_samples, self.latent_dim), dtype=gps.float_type)) 285 | 286 | white_samples_X = white_samples.sample(T, seed=self.seed) 287 | if x1_samples is not None: 288 | X_samples = X_samples.write(0, x1_samples) 289 | else: 290 | if qx1_cov_chol.shape.ndims == 1: 291 | x1_noise = white_samples_X[0] * qx1_cov_chol 
292 | else: 293 | x1_noise = tf.matmul(white_samples_X[0], qx1_cov_chol, transpose_b=True) 294 | X_samples = X_samples.write(0, qx1_mu + x1_noise) 295 | 296 | if sample_f: white_samples_F = white_samples.sample(T - 1, seed=self.seed) 297 | 298 | def _loop_body(*args): 299 | t, X = args[:2] 300 | if sample_f: F = args[2] 301 | if return_f_moments: f_mus, f_vars = args[-3:-1] if return_x_cov_chols else args[-2:] 302 | if return_x_cov_chols: x_cov_chols = args[-1] 303 | 304 | x_t = X.read(t) # n_samples x latent_dim 305 | if inputs is not None: 306 | x_t = tf.concat([x_t, tf.tile(inputs[t][None, :], [n_samples, 1])], -1) 307 | 308 | if sample_u: 309 | f_mu, f_var = conditional(x_t, self.Z, self.kern, U_samples, q_sqrt=None, white=True, Lm=Lm) 310 | else: 311 | f_mu, f_var = conditional(x_t, self.Z, self.kern, self.Umu, q_sqrt=self.Ucov_chol, white=True, Lm=Lm) 312 | f_mu += self.mean_fn(x_t[:, :n_mean_inputs]) 313 | f_var = tf.abs(f_var) 314 | 315 | if sample_f: 316 | f_t = f_mu + tf.sqrt(f_var) * white_samples_F[t] # n_samples x latent_dim 317 | F = F.write(t, f_t) 318 | f_mu_or_t = f_t 319 | tiling = [n_samples, 1] if is_diag_xcov else [n_samples, 1, 1] 320 | x_cov_chol = tf.tile(S_chols[t][None, ...], tiling) 321 | else: 322 | f_mu_or_t = f_mu 323 | if is_diag_xcov: 324 | x_cov_chol = tf.sqrt(tf.square(S_chols[t]) + f_var * tf.square(As[t])) # (n_samples x latent_dim) 325 | elif As.shape.ndims == 2: 326 | x_cov_chol = tf.matmul(S_chols[t], S_chols[t], transpose_b=True) 327 | x_cov_chol += tf.matrix_diag(f_var * tf.square(As[t])) 328 | x_cov_chol = tf.cholesky(x_cov_chol) # (n_samples x latent_dim x latent_dim) 329 | elif S_chols.shape.ndims == 2: 330 | x_cov_chol = tf.diag(tf.square(S_chols[t])) 331 | x_cov_chol += tf.tensordot(f_var[:, None, :] * As[t], As[t], axes=[[2], [1]]) 332 | x_cov_chol = tf.cholesky(x_cov_chol) # (n_samples x latent_dim x latent_dim) 333 | else: 334 | x_cov_chol = tf.matmul(S_chols[t], S_chols[t], transpose_b=True) 335 | x_cov_chol += tf.tensordot(f_var[:, None, :] * As[t], As[t], axes=[[2], [1]]) 336 | x_cov_chol = tf.cholesky(x_cov_chol) # (n_samples x latent_dim x latent_dim) 337 | 338 | x_tplus1 = bs[t] + ((As[t] * f_mu_or_t) if As.shape.ndims == 2 339 | else tf.matmul(f_mu_or_t, As[t], transpose_b=True)) # n_samples x latent_dim 340 | x_tplus1 += (white_samples_X[t + 1] * x_cov_chol) if is_diag_xcov \ 341 | else tf.reduce_sum(x_cov_chol * white_samples_X[t + 1][:, None, :], -1) 342 | X = X.write(t + 1, x_tplus1) 343 | 344 | if return_f_moments: 345 | f_mus, f_vars = f_mus.write(t, f_mu), f_vars.write(t, f_var) 346 | if return_x_cov_chols: 347 | x_cov_chols = x_cov_chols.write(t, x_cov_chol) 348 | 349 | ret_values = [t + 1, X] 350 | if sample_f: ret_values += [F] 351 | if return_f_moments: ret_values += [f_mus, f_vars] 352 | if return_x_cov_chols: ret_values += [x_cov_chols] 353 | return ret_values 354 | 355 | _loop_vars = [0, X_samples] 356 | if sample_f: _loop_vars += [F_samples] 357 | if return_f_moments: _loop_vars += [f_mus, f_vars] 358 | if return_x_cov_chols: _loop_vars += [x_cov_chols] 359 | 360 | result = tf.while_loop( 361 | cond=lambda t, *args: t < (T - 1), 362 | body=_loop_body, 363 | loop_vars=_loop_vars, 364 | back_prop=differentiate, 365 | parallel_iterations=self.parallel_iterations) 366 | 367 | ret_values = tuple(r.stack() for r in result[1:]) 368 | if sample_u and return_u: ret_values += (U_samples,) 369 | return ret_values 370 | 371 | @params_as_tensors 372 | def _build_cubic_time_q_sample(self, return_f_moments=False, return_f=True, 373 | 
sample_u=False, return_u=False, add_jitter=True, inverse_chol=False, 374 | T=None, inputs=None, qx1_mu=None, qx1_cov_chol=None, x1_samples=None, 375 | As=None, bs=None, S_chols=None, Lm=None): 376 | T = self.T if T is None else T 377 | inputs = self.inputs if inputs is None else inputs 378 | if inputs is not None: 379 | inputs = tf.concat([inputs, tf.zeros((1, tf.shape(inputs)[-1]), dtype=gps.float_type)], 0) 380 | qx1_mu = self.qx1_mu if qx1_mu is None else qx1_mu 381 | qx1_cov_chol = self.qx1_cov_chol if qx1_cov_chol is None else qx1_cov_chol 382 | As = self.As if As is None else As 383 | bs = self.bs if bs is None else bs 384 | S_chols = self.S_chols if S_chols is None else S_chols 385 | n_samples = self.n_samples if x1_samples is None else int(x1_samples.shape[0]) 386 | n_mean_inputs = self.mean_fn.input_dim if hasattr(self.mean_fn, "input_dim") else self.latent_dim 387 | differentiate = x1_samples is None 388 | 389 | Lm = tf.cholesky(self.Kzz) if Lm is None else Lm 390 | shared_kern = isinstance(self.kern, mk.SharedIndependentMok) 391 | shared_kern_and_Z = Lm.shape.ndims == 2 or (shared_kern and isinstance(self.Z, mf.SharedIndependentMof)) 392 | 393 | if sample_u: 394 | U_samples = tfd.MultivariateNormalDiag(loc=tf.zeros( 395 | (self.latent_dim, self.n_ind_pts, n_samples), dtype=gps.float_type)) 396 | U_samples = U_samples.sample(seed=self.seed) 397 | U_samples = self.Umu[:, :, None] + tf.matmul(self.Ucov_chol, U_samples) 398 | 399 | white_samples_X = tfd.MultivariateNormalDiag( 400 | loc=tf.zeros((n_samples, T, self.latent_dim), dtype=gps.float_type)).sample(seed=self.seed) 401 | 402 | white_samples_F = tfd.MultivariateNormalDiag( 403 | loc=tf.zeros((n_samples, T - 1, self.latent_dim), dtype=gps.float_type)).sample(seed=self.seed) 404 | white_samples_F = tf.transpose(white_samples_F, [0, 2, 1]) 405 | 406 | if x1_samples is None: 407 | x1_samples = qx1_mu + ((white_samples_X[:, 0] * qx1_cov_chol) if qx1_cov_chol.shape.ndims == 1 408 | else tf.matmul(white_samples_X[:, 0], qx1_cov_chol, transpose_b=True)) 409 | 410 | if inputs is not None: 411 | x1_samples = tf.concat([x1_samples, tf.tile(inputs[:1], [n_samples, 1])], -1) 412 | 413 | if sample_u: 414 | f1_mu, f1_var = conditional(x1_samples, self.Z, self.kern, U_samples, 415 | q_sqrt=None, white=True, Lm=Lm) 416 | else: 417 | f1_mu, f1_var = conditional(x1_samples, self.Z, self.kern, self.Umu, 418 | q_sqrt=self.Ucov_chol, white=True, Lm=Lm) 419 | 420 | f1_mu += self.mean_fn(x1_samples[:, :n_mean_inputs]) 421 | f1_var = tf.abs(f1_var) 422 | f1_samples = f1_mu + tf.sqrt(f1_var) * white_samples_F[:, :, 0] 423 | 424 | if sample_u: U_samples = tf.transpose(U_samples, [2, 0, 1]) # n_samples x latent_dim x M 425 | 426 | def single_trajectory(args): 427 | if sample_u: 428 | U_samples_n = args[-1] 429 | args = args[:-1] 430 | x1_samples_n, f1_samples_n, white_samples_X_n, white_samples_F_n, f1_mu_n, f1_var_n = args 431 | 432 | x2_samples_n = bs[0] + ((As[0] * f1_samples_n) if As.shape.ndims == 2 433 | else tf.reduce_sum(As[0] * f1_samples_n, -1)) # latent_dim 434 | x2_samples_n += (white_samples_X_n[1] * S_chols[0]) if S_chols.shape.ndims == 2 \ 435 | else tf.reduce_sum(S_chols[0] * white_samples_X_n[1], -1) 436 | if inputs is not None: 437 | x2_samples_n = tf.concat([x2_samples_n, inputs[1]], 0) 438 | X_samples_n = tf.stack([x1_samples_n, x2_samples_n], 0) # 2 x latent_dim 439 | 440 | F_samples_n = tf.TensorArray(size=T - 1, dtype=gps.float_type, clear_after_read=False, 441 | infer_shape=False, element_shape=(self.latent_dim,)) 442 | 
F_samples_n = F_samples_n.write(0, f1_samples_n) 443 | 444 | if return_f_moments: 445 | f_mus = tf.TensorArray(size=T - 1, dtype=gps.float_type, clear_after_read=False, 446 | infer_shape=False, element_shape=(self.latent_dim,)) 447 | f_vars = tf.TensorArray(size=T - 1, dtype=gps.float_type, clear_after_read=False, 448 | infer_shape=False, element_shape=(self.latent_dim,)) 449 | f_mus = f_mus.write(0, f1_mu_n) 450 | f_vars = f_vars.write(0, f1_var_n) 451 | 452 | Kzx = Kuf(self.Z, self.kern, X_samples_n[:1]) # (latent_dim x) M x 1 453 | Lm_inv_Kzx = tf.matrix_triangular_solve(Lm, Kzx, lower=True) # (latent_dim x) M x 1 454 | 455 | F_cov_chol = f1_var_n[0] if (shared_kern_and_Z and sample_u) else f1_var_n # () or latent_dim 456 | F_cov_chol = tf.sqrt((F_cov_chol + self.jitter) if add_jitter else F_cov_chol) 457 | if inverse_chol: F_cov_chol = 1. / F_cov_chol 458 | F_cov_chol = F_cov_chol[..., None, None] # (latent_dim x) 1 x 1 459 | 460 | def _loop_body(t, X, F, Lm_inv_Kzx, F_cov_chol, f_mus=None, f_vars=None): 461 | if shared_kern: 462 | Kx1_to_tp1_xtp1 = self.kern.kern.K(X, X[-1:])[..., 0] # t+1 463 | else: 464 | Kx1_to_tp1_xtp1 = self.kern.K(X, X[-1:], full_output_cov=False)[..., 0] # latent_dim x (t+1) 465 | 466 | Kzxtp1 = Kuf(self.Z, self.kern, X[-1][None, :]) # (latent_dim x) M x 1 467 | Lm_inv_Kzxtp1 = tf.matrix_triangular_solve(Lm, Kzxtp1, lower=True)[..., 0] # (latent_dim x) M 468 | 469 | f_tp1_marg_mu = self.mean_fn(X[-1:, :n_mean_inputs])[0] # () or latent_dim 470 | if sample_u: 471 | f_tp1_marg_mu += tf.reduce_sum(Lm_inv_Kzxtp1 * U_samples_n, -1) # latent_dim 472 | else: 473 | f_tp1_marg_mu += tf.reduce_sum(Lm_inv_Kzxtp1 * self.Umu, -1) # latent_dim 474 | 475 | F_cov_tp1 = Kx1_to_tp1_xtp1[..., -1] - tf.reduce_sum(tf.square(Lm_inv_Kzxtp1), -1) # () or latent_dim 476 | 477 | F_cov_1_to_t_tp1 = Kx1_to_tp1_xtp1[..., :-1] # (latent_dim x) t 478 | F_cov_1_to_t_tp1 -= tf.reduce_sum(Lm_inv_Kzx * Lm_inv_Kzxtp1[..., None], -2) # (latent_dim x) t 479 | 480 | if not sample_u: 481 | Uchol_Lm_inv_Kzxtp1 = tf.reduce_sum(self.Ucov_chol * Lm_inv_Kzxtp1[..., None], -2) # latent_dim x M 482 | F_cov_tp1 += tf.reduce_sum(tf.square(Uchol_Lm_inv_Kzxtp1), -1) # latent_dim 483 | # latent_dim x M: 484 | Ucov_Lm_inv_Kzxtp1 = tf.reduce_sum(self.Ucov_chol * Uchol_Lm_inv_Kzxtp1[:, None, :], -1) 485 | F_cov_1_to_t_tp1 += tf.reduce_sum(Lm_inv_Kzx * Ucov_Lm_inv_Kzxtp1[:, :, None], -2) # latent_dim x t 486 | 487 | if inverse_chol: 488 | F_chol_inv_F_1_to_t_tp1 = tf.reduce_sum( 489 | F_cov_chol * F_cov_1_to_t_tp1[..., None, :], -1) # (latent_dim x) t 490 | else: 491 | F_chol_inv_F_1_to_t_tp1 = tf.matrix_triangular_solve( 492 | F_cov_chol, F_cov_1_to_t_tp1[..., None], lower=True)[..., 0] # (latent_dim x) t 493 | # latent_dim: 494 | f_tp1_mu = f_tp1_marg_mu + tf.reduce_sum(F_chol_inv_F_1_to_t_tp1 * white_samples_F_n[:, :t], -1) 495 | 496 | f_tp1_var = F_cov_tp1 - tf.reduce_sum(tf.square(F_chol_inv_F_1_to_t_tp1), -1) # () or latent_dim 497 | f_tp1_var = tf.abs(f_tp1_var) 498 | 499 | f_tp1 = f_tp1_mu + tf.sqrt(f_tp1_var) * white_samples_F_n[:, t] # latent_dim 500 | 501 | x_tplus2 = bs[t] + ((As[t] * f_tp1) if As.shape.ndims == 2 502 | else tf.reduce_sum(As[t] * f_tp1, -1)) # latent_dim 503 | x_tplus2 += (S_chols[t] * white_samples_X_n[t + 1]) if S_chols.shape.ndims == 2 \ 504 | else tf.reduce_sum(S_chols[t] * white_samples_X_n[t + 1], -1) # latent_dim 505 | 506 | if inputs is not None: 507 | x_tplus2 = tf.concat([x_tplus2, inputs[t + 1]], 0) 508 | 509 | X = tf.concat([X, x_tplus2[None, :]], 0) # (t+2) x latent_dim 510 
| F = F.write(t, f_tp1) 511 | 512 | Lm_inv_Kzx = tf.concat([Lm_inv_Kzx, Lm_inv_Kzxtp1[..., None]], -1) # (latent_dim x) M x (t+1) 513 | 514 | F_cov_chol_diag = tf.sqrt((f_tp1_var + self.jitter) if add_jitter else f_tp1_var) # () or latent_dim 515 | if inverse_chol: 516 | F_cov_chol_bottom_offdiag = - tf.reduce_sum(F_chol_inv_F_1_to_t_tp1[..., None] * F_cov_chol, -2) 517 | F_cov_chol_bottom_offdiag /= F_cov_chol_diag[..., None] 518 | F_cov_chol_diag = 1. / F_cov_chol_diag 519 | F_cov_chol_bottom_row = tf.concat([F_cov_chol_bottom_offdiag, F_cov_chol_diag[..., None]], -1) 520 | else: 521 | F_cov_chol_bottom_row = tf.concat([F_chol_inv_F_1_to_t_tp1, F_cov_chol_diag[..., None]], -1) 522 | 523 | padding = [[0, 0], [0, 1]] if (shared_kern_and_Z and sample_u) else [[0, 0], [0, 0], [0, 1]] 524 | F_cov_chol = tf.pad(F_cov_chol, paddings=padding) # (latent_dim x) t x (t+1) 525 | # (latent_dim x) (t+1) x (t+1): 526 | F_cov_chol = tf.concat([F_cov_chol, F_cov_chol_bottom_row[..., None, :]], -2) 527 | 528 | ret_values = (t + 1, X, F, Lm_inv_Kzx, F_cov_chol) 529 | if return_f_moments: 530 | if shared_kern_and_Z and sample_u: 531 | f_tp1_var = tf.tile(f_tp1_var[None], [self.latent_dim]) # latent_dim 532 | f_mus, f_vars = f_mus.write(t, f_tp1_mu), f_vars.write(t, f_tp1_var) 533 | ret_values += (f_mus, f_vars) 534 | return ret_values 535 | 536 | _loop_vars = [1, X_samples_n, F_samples_n, Lm_inv_Kzx, F_cov_chol] 537 | 538 | shape_invar_Lm_inv_Kzx = tf.TensorShape([self.n_ind_pts, None]) if shared_kern_and_Z \ 539 | else tf.TensorShape([self.latent_dim, self.n_ind_pts, None]) 540 | shape_invar_F_cov_chol = tf.TensorShape([None, None]) if (shared_kern_and_Z and sample_u) \ 541 | else tf.TensorShape([self.latent_dim, None, None]) 542 | _shape_invariants = [tf.TensorShape([]), 543 | tf.TensorShape([None, self.latent_dim + self.input_dim]), 544 | tf.TensorShape(None), 545 | shape_invar_Lm_inv_Kzx, 546 | shape_invar_F_cov_chol] 547 | 548 | if return_f_moments: 549 | _loop_vars += [f_mus, f_vars] 550 | _shape_invariants += [tf.TensorShape(None), tf.TensorShape(None)] 551 | 552 | loop_result = tf.while_loop( 553 | cond=lambda t, *args: t < (T - 1), 554 | body=_loop_body, 555 | loop_vars=_loop_vars, 556 | shape_invariants=_shape_invariants, 557 | back_prop=differentiate, 558 | parallel_iterations=self.parallel_iterations) 559 | 560 | loop_result = loop_result[1:] 561 | X_traj = loop_result[0] 562 | if inputs is not None: X_traj = X_traj[:, :self.latent_dim] 563 | F_traj = loop_result[1].stack() 564 | if return_f_moments: 565 | return X_traj, F_traj, loop_result[-2].stack(), loop_result[-1].stack() 566 | return X_traj, F_traj 567 | 568 | iterables = (x1_samples, f1_samples, white_samples_X, white_samples_F, f1_mu, f1_var) 569 | if sample_u: iterables += (U_samples,) 570 | map_fn_result = tf.map_fn(single_trajectory, 571 | iterables, 572 | (gps.float_type,) * (4 if return_f_moments else 2), 573 | back_prop=differentiate, 574 | parallel_iterations=self.parallel_iterations) 575 | 576 | X_samples = tf.transpose(map_fn_result[0], [1, 0, 2]) # T x n_samples x latent_dim 577 | ret_values = (X_samples,) 578 | if return_f: 579 | F_samples = tf.transpose(map_fn_result[1], [1, 0, 2]) # (T-1) x n_samples x latent_dim 580 | ret_values += (F_samples,) 581 | if return_f_moments: 582 | f_mus = tf.transpose(map_fn_result[2], [1, 0, 2]) # (T-1) x n_samples x latent_dim 583 | f_vars = tf.transpose(map_fn_result[3], [1, 0, 2]) # (T-1) x n_samples x latent_dim 584 | ret_values += (f_mus, f_vars) 585 | if sample_u and return_u: ret_values 
+= (tf.transpose(U_samples, [1, 2, 0]),) # latent_dim x M x n_samples 586 | return ret_values 587 | 588 | @params_as_tensors 589 | def _build_predict_f(self, X): 590 | f_mu, f_var = conditional(X, self.Z, self.kern, self.Umu, q_sqrt=self.Ucov_chol, white=True) 591 | n_mean_inputs = self.mean_fn.input_dim if hasattr(self.mean_fn, "input_dim") else self.latent_dim 592 | f_mu += self.mean_fn(X[:, :n_mean_inputs]) 593 | return f_mu, f_var 594 | 595 | def sample(self, T, N=1, x0_samples=None, inputs=None, cubic=True, 596 | sample_u=False, sample_f=False, return_op=False): 597 | if x0_samples is None: 598 | assert len(self.px1_mu.shape) == 1 599 | noise = tf.random_normal((N, self.latent_dim), dtype=gps.float_type, seed=self.seed) 600 | if self.px1_cov_chol is not None: 601 | if len(self.px1_cov_chol.shape) == 1: 602 | noise = noise * self.px1_cov_chol.constrained_tensor 603 | else: 604 | noise = tf.matmul(noise, self.px1_cov_chol.constrained_tensor, transpose_b=True) 605 | x0_samples = self.px1_mu.constrained_tensor + noise 606 | if inputs is not None: 607 | inputs = tf.constant(inputs) 608 | elif self.inputs is not None: 609 | inputs = self.inputs.constrained_tensor 610 | else: 611 | x0_samples = tf.constant(x0_samples) 612 | inputs = None if inputs is None else tf.constant(inputs) 613 | T += 1 614 | 615 | sample_fn = self._build_cubic_time_q_sample if cubic else \ 616 | lambda **kwargs: self._build_linear_time_q_sample(sample_f=sample_f, **kwargs) 617 | 618 | X_samples, *fs = sample_fn(T=T, sample_u=sample_u, 619 | inputs=inputs, 620 | x1_samples=x0_samples, 621 | As=tf.ones((T - 1, self.latent_dim), dtype=gps.float_type), 622 | bs=tf.zeros((T - 1, self.latent_dim), dtype=gps.float_type), 623 | S_chols=self.Q_sqrt.constrained_tensor * 624 | tf.ones((T - 1, self.latent_dim), dtype=gps.float_type)) 625 | 626 | if return_op: 627 | return X_samples 628 | else: 629 | session = self.enquire_session() 630 | X_samples = session.run(X_samples) 631 | Y_samples = self.emissions.sample_conditional(X_samples) 632 | return X_samples, Y_samples 633 | 634 | def assign(self, dct, **kwargs): 635 | if isinstance(dct, pd.Series): 636 | dct = dct.to_dict() 637 | for k in list(dct.keys()): 638 | new_key = '/'.join([self.name] + k.split('/')[1:]) 639 | dct[new_key] = dct.pop(k) 640 | super().assign(dct, **kwargs) 641 | --------------------------------------------------------------------------------
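As a rough end-to-end illustration of the sampling-only GPSSM base class defined in GPt/gpssm.py above, the following is a minimal sketch, not taken from the repository: the data, dimensions and keyword values are placeholders, and it assumes a working TensorFlow 1.x installation together with the GPflow branch that GPt targets.

import numpy as np
from GPt.gpssm import GPSSM

# Placeholder observations: T = 50 time steps, obs_dim = 2.
Y = np.random.randn(50, 2)

# A one-dimensional latent state with 20 inducing points and the default
# Matern-3/2 kernel, Gaussian emissions and identity mean function.
model = GPSSM(latent_dim=1, Y=Y, n_ind_pts=20, n_samples=10, seed=0)

# Draw 5 forward rollouts of length 30 with the linear-time sampling scheme
# (cubic=True would use the cubic-time joint sampling scheme instead).
X_samples, Y_samples = model.sample(T=30, N=5, cubic=False)
# X_samples: latent trajectories of shape (T, N, latent_dim);
# Y_samples: the corresponding observations drawn from the emission model.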