├── fig
│   └── figure1.png
├── data_generation
│   ├── toy_extrapolation.R
│   └── circles_and_pinwheel.R
├── README.md
├── examples
│   ├── 1_plot_example_decomposition.R
│   └── 1_example_decomposition.py
├── cGPLVM
│   ├── helpers.py
│   ├── helpers_survival.py
│   ├── kernels.py
│   ├── cGPLVM_survival.py
│   ├── cGPLVM.py
│   └── GP_mappings.py
└── LICENSE.txt

/fig/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kasparmartens/c-GPLVM/HEAD/fig/figure1.png
--------------------------------------------------------------------------------
/data_generation/toy_extrapolation.R:
--------------------------------------------------------------------------------
1 | library(tidyverse)
2 | 
3 | generate_toy_data <- function(N = 500){
4 |   z <- runif(N, -3, 3)
5 |   c <- sample(c(-1, 0, 1), N, TRUE)
6 | 
7 |   y <- sin(z) + c + 0.1*rnorm(N)
8 | 
9 |   data.frame(y, z, c) %>%
10 |     filter(!(c == 1 & z > -1.5), !(c == 0 & z > -0.5))
11 | }
12 | 
13 | df <- generate_toy_data(250)
14 | 
15 | write_csv(df, "data/extrapolation.csv")
16 | 
17 | df %>%
18 |   ggplot(aes(z, y, col=c)) +
19 |   geom_point() +
20 |   scale_color_viridis_c() +
21 |   theme_classic()
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # c-GPLVM
2 | 
3 | Implementation of Covariate-GPLVM, accompanying our ICML 2019 paper [Decomposing feature-level variation with Covariate Gaussian Process Latent Variable Models](http://proceedings.mlr.press/v97/martens19a.html).
4 | 
5 | ![](fig/figure1.png)
6 | 
7 | See [here](https://github.com/kasparmartens/c-GPLVM/tree/master/examples) for a code example.
8 | 
9 | ```
10 | @InProceedings{martens2019,
11 |   title = {Decomposing feature-level variation with {C}ovariate {G}aussian {P}rocess {L}atent {V}ariable {M}odels},
12 |   author = {M{\"a}rtens, Kaspar and Campbell, Kieran and Yau, Christopher},
13 |   booktitle = {Proceedings of the 36th International Conference on Machine Learning},
14 |   pages = {4372--4381},
15 |   year = {2019},
16 |   volume = {97},
17 |   series = {Proceedings of Machine Learning Research},
18 |   publisher = {PMLR},
19 | }
20 | ```
21 | 
--------------------------------------------------------------------------------
/examples/1_plot_example_decomposition.R:
--------------------------------------------------------------------------------
1 | library(tidyverse)
2 | 
3 | df <- read_csv("output/toy_decomposition.csv")
4 | 
5 | # plot overall prediction
6 | df %>%
7 |   ggplot(aes(z, f, col=x, group=x)) +
8 |   geom_path() +
9 |   scale_color_viridis_c() +
10 |   theme_classic() +
11 |   labs(title = "Aggregate c-GPLVM mapping")
12 | 
13 | # now plot decomposition
14 | 
15 | # f_z
16 | df %>%
17 |   ggplot(aes(z, f_z, group=x)) +
18 |   geom_path() +
19 |   theme_classic() +
20 |   labs(title = "Decomposition: f(z)")
21 | 
22 | # f_x
23 | df %>%
24 |   ggplot(aes(z, f_x, col=x, group=x)) +
25 |   geom_path() +
26 |   scale_color_viridis_c() +
27 |   theme_classic() +
28 |   labs(title = "Decomposition: f(x)")
29 | 
30 | # interaction f_{zx}
31 | df %>%
32 |   ggplot(aes(z, f_int, col=x, group=x)) +
33 |   geom_path() +
34 |   scale_color_viridis_c() +
35 |   theme_classic() +
36 |   labs(title = "Decomposition: f(z, x)")
37 | 
--------------------------------------------------------------------------------
/cGPLVM/helpers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.distributions.normal import Normal
3 | 
4 | # define KL between normal and standard normal distribution
5 | def KL_standard_normal(mu, sigma):
6 |     p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
7 |     q = Normal(mu, sigma)
8 |     return torch.sum(torch.distributions.kl_divergence(q, p))
9 | 
10 | 
11 | # define softplus function
12 | def my_softplus(x):
13 |     return torch.log(1.0 + torch.exp(x))
14 | 
15 | def create_2D_grid(grid_x1, grid_x2, device="cpu"):
16 |     x1_s, x2_s = torch.meshgrid([grid_x1.to(device), grid_x2.to(device)])
17 |     x1_star, x2_star = x1_s.reshape([-1, 1]), x2_s.reshape([-1, 1])
18 |     return x1_star, x2_star
19 | 
20 | # takes in [N1, 1] and [N2, p2] matrices, then performs expand grid
21 | def grid_helper(a, b):
22 |     nrow_a = a.size()[0]
23 |     nrow_b = b.size()[0]
24 |     ncol_b = b.size()[1]
25 |     x = a.repeat(nrow_b, 1)
26 |     y = b.repeat(1, nrow_a).view(-1, ncol_b)
27 |     return x, y
28 | 
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) [year] [fullname]
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /examples/1_example_decomposition.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | 4 | from cGPLVM.cGPLVM import cGPLVM 5 | from cGPLVM.GP_mappings import GP_2D_AddInt 6 | from cGPLVM.helpers import grid_helper 7 | 8 | # generate toy data 9 | def generate_toy_data(N, class_p=[1., 1., 1., 1., 1.], noise_sd=0.1): 10 | from torch.distributions.uniform import Uniform 11 | from torch.distributions.categorical import Categorical 12 | 13 | z = Uniform(-2.0, 2.0).rsample(sample_shape=(N, 1)) 14 | x = Categorical(torch.Tensor(class_p)).sample(sample_shape=(N, 1)).float()-2 15 | z_positive = 0.5 + 0.5*(z > 0).float() 16 | y1 = torch.sin(z) + 0.2*x + 0.2*torch.sin(z)*x*z_positive + noise_sd*torch.randn(N, 1) 17 | Y = torch.cat([y1], dim=1) 18 | return z, x, Y 19 | 20 | z, x, Y = generate_toy_data(N=1000) 21 | 22 | # set up inducing points in the (z, x) space 23 | z_inducing = torch.linspace(-2.0, 2.0, steps=10).reshape(-1, 1) 24 | x_inducing = torch.linspace(-2.0, 2.0, steps=10).reshape(-1, 1) 25 | 26 | # initialise covariate-GPLVM model 27 | m = cGPLVM(x, Y, z, GP_mapping=GP_2D_AddInt, mean_zero=True, z_inducing=z_inducing, x_inducing=x_inducing, fixed_z=True) 28 | 29 | # train the model using Adam 30 | m.train(n_iter=1000, verbose=100) 31 | 32 | 33 | ### predictions 34 | 35 | def helper_predict_decomposition(model,z_star, x_star): 36 | f_mean, _ = model.predict(z_star, x_star) 37 | f_z_mean, _ = model.predict_decomposition(z_star, x_star, which_kernels=[1.0, 0.0, 0.0]) 38 | f_x_mean, _ = model.predict_decomposition(z_star, x_star, which_kernels=[0.0, 1.0, 0.0]) 39 | f_int_mean, _ = model.predict_decomposition(z_star, x_star, which_kernels=[0.0, 0.0, 1.0]) 40 | F_pred = torch.cat([z_star, x_star, f_mean, f_z_mean, f_x_mean, f_int_mean], dim=1) 41 | return F_pred 42 | 43 | # create grid for predictions 44 | grid = torch.linspace(-2.0, 2.0, steps=50).reshape(-1, 1) 45 | z_grid, x_grid = grid_helper(grid, grid) 46 | 47 | # predict the posterior mean 48 | F_pred = helper_predict_decomposition(m, z_grid, x_grid).detach().numpy() 49 | 50 | # write the predictions into csv file 51 | col_names = ["z", "x", "f", "f_z", "f_x", "f_int"] 52 | pd.DataFrame(F_pred, columns=col_names).to_csv("output/toy_decomposition.csv", index=False) 53 | 54 | # plotting is done in R (see folder "plotting") 55 | -------------------------------------------------------------------------------- /data_generation/circles_and_pinwheel.R: -------------------------------------------------------------------------------- 1 | library(tidyverse) 2 | library(patchwork) 3 | 4 | ## circles 5 | generate_circles <- function(path, N = 1000){ 6 | set.seed(0) 7 | z <- runif(N, 0, 1.8*pi) 8 | x <- sample(seq(-pi, pi, length=6)[-1], N, replace=TRUE) 9 | y1 <- cos(z) + 1.4*cos(x) + 0.05*rnorm(N) 10 | y2 <- sin(z) + 1.4*sin(x) + 0.05*rnorm(N) 11 | df_data <- data.frame(y1, y2, z_true = as.numeric(scale(z)), x = as.numeric(scale(x))) 12 | write_csv(df_data, path) 13 | } 14 | 15 | generate_distorted_circles <- function(path, N = 1000){ 16 | N <- 1000 17 | set.seed(0) 18 | z <- runif(N, 0, 1.8*pi) 19 | x <- sample(seq(-pi, pi, length=6)[-1], N, replace=TRUE) 20 | y1 <- cos(z) + 1.4*cos(x) + 0.05*rnorm(N) 21 | y2 <- sin(z) + 1.4*sin(x) + sin(z)*sin(0.25*x) + 0.05*rnorm(N) 22 | df_data <- data.frame(y1, y2, z_true = as.numeric(scale(z)), x = as.numeric(scale(x))) 23 | write_csv(df_data, path) 
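  # note: relative to generate_circles() above, the extra sin(z)*sin(0.25*x) term in y2
  # introduces an interaction between the latent z and the covariate x -- the kind of
  # structure the c-GPLVM interaction component f(z, x) is designed to capture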
24 | } 25 | 26 | ## pinwheel 27 | generate_pinwheel <- function(path, N = 1000){ 28 | set.seed(0) 29 | n_classes <- 5 30 | rate <- 0.25 31 | rads <- 0.0 + seq(0, 2*pi, length=n_classes+1)[-1] 32 | # features <- cbind(rnorm(N, 1, 0.3), rnorm(N, 0, 0.05)) 33 | cluster <- sample(c(1, 2, 3), N, replace=TRUE) 34 | features_x <- runif(N, 0.1, 1.7) 35 | # features_x <- runif(N, c(0.1, 0.8, 1.4)[cluster], c(0.5, 1.2, 1.7)[cluster]) 36 | features_y <- runif(N, -0.1, 0.1) 37 | features <- cbind(features_x, features_y) 38 | angles0 <- sort(sample(rads, N, replace = TRUE)) 39 | angles <- angles0 + rate * exp(features[, 1]) 40 | Y <- matrix(0, N, 2) 41 | for(i in 1:N){ 42 | rotations <- rbind(c(cos(angles[i]), -sin(angles[i])), c(sin(angles[i]), cos(angles[i]))) 43 | Y[i, ] <- features[i, ] %*% rotations 44 | } 45 | covariate <- as.numeric(scale(angles0)) 46 | 47 | df_data <- data.frame(y1 = Y[, 1], y2 = Y[, 2], x = covariate, z_true = as.numeric(scale(features_x))) 48 | write_csv(df_data, path) 49 | } 50 | 51 | 52 | generate_circles("data/circles.csv") 53 | generate_pinwheel("data/pinwheel.csv") 54 | 55 | 56 | df_data <- read_csv("data/pinwheel.csv") 57 | # df_data <- read_csv("data/circles.csv") 58 | 59 | p1 <- df_data %>% 60 | ggplot(aes(z, y1, col=x)) + 61 | geom_point() + 62 | scale_color_viridis_c() + 63 | labs(title = "(z, y1)") 64 | p2 <- df_data %>% 65 | ggplot(aes(z, y2, col=x)) + 66 | geom_point() + 67 | scale_color_viridis_c() + 68 | labs(title = "(z, y2)") 69 | p0 <- df_data %>% 70 | ggplot(aes(y1, y2, col=x)) + 71 | geom_point() + 72 | scale_color_viridis_c() + 73 | labs(title = "(y1, y2)") 74 | 75 | p1 + p2 + p0 + plot_layout(widths = c(1, 1, 1.5)) 76 | -------------------------------------------------------------------------------- /cGPLVM/helpers_survival.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from torch.distributions.weibull import Weibull 5 | from torch.distributions.uniform import Uniform 6 | from torch.distributions.normal import Normal 7 | 8 | def Phi(x): 9 | return 0.5 * (1 + torch.erf(x / math.sqrt(2))) 10 | 11 | 12 | def Phi_inverse(value): 13 | return torch.erfinv(2 * value - 1) * math.sqrt(2) 14 | 15 | 16 | def rsample_TruncatedNormal(loc, scale, lower, upper, sample_shape=(1,)): 17 | F_a = Phi((lower - loc) / scale) 18 | F_b = Phi((upper - loc) / scale) 19 | u0 = Uniform(torch.zeros_like(loc), torch.ones_like(scale)).rsample(sample_shape=sample_shape) 20 | u = u0 * (F_b - F_a) + F_a 21 | out = Phi_inverse(u) * scale + loc 22 | if torch.isinf(out).any(): 23 | print("rsample_TruncatedNormal has a problem") 24 | return Phi_inverse(u) * scale + loc 25 | 26 | 27 | def log_prob_Normal(value, loc, scale): 28 | return Normal(loc=loc, scale=scale).log_prob(value) 29 | 30 | 31 | def log_prob_TruncatedNormal(value, loc, scale, lower, upper): 32 | F_a = Phi(lower) 33 | F_b = Phi(upper) 34 | return Normal(loc=loc, scale=scale).log_prob(value) - torch.log(F_b - F_a + 1e-8) 35 | 36 | 37 | def log_prob_Weibull(x, shape, scale): 38 | x_over_lambda = x / scale 39 | return torch.log(shape / scale) + (shape - 1) * torch.log(x_over_lambda) - torch.pow(x_over_lambda, shape) 40 | 41 | 42 | def log_prob_TruncatedWeibull(x, shape, scale, lower, upper): 43 | p = Weibull(scale=scale, concentration=shape) 44 | F_a = p.cdf(lower) 45 | F_b = p.cdf(upper) 46 | return log_prob_Weibull(x, shape, scale) - torch.log(F_b - F_a + 1e-8) 47 | 48 | 49 | def inv_cdf_Weibull(u, shape, scale): 50 | # inverse_cdf(u) = - lambda * 
log(1-u)^{1/k} 51 | return scale * torch.pow(-torch.log(1.0 - u), 1.0 / shape) 52 | 53 | 54 | # Truncated Weibull(lambda, k), [lower, upper] 55 | def rsample_TruncatedWeibull(shape, scale, lower, upper, sample_shape=(1,)): 56 | p = Weibull(scale=scale, concentration=shape) 57 | F_a = p.cdf(lower) 58 | F_b = p.cdf(upper) 59 | # u0 = Uniform(0.0, 1.0).rsample(sample_shape=sample_shape) 60 | u0 = Uniform(torch.zeros_like(scale), torch.ones_like(scale)).rsample(sample_shape=sample_shape) 61 | u = u0 * (F_b - F_a + 1e-12) + F_a 62 | return inv_cdf_Weibull(u, shape, scale) 63 | 64 | def calculate_KLqp_TruncatedWeibullNormal(p_shape, p_scale, q_loc, q_scale, lower, upper, n_samples=20): 65 | # sample from q 66 | sample_q = rsample_TruncatedNormal(q_loc, q_scale, lower, upper, sample_shape=(n_samples,)) 67 | # evaluate log_probs 68 | logp = torch.zeros_like(q_loc) # log_prob_TruncatedWeibull(sample_q, p_shape, p_scale, lower, upper) 69 | logq = log_prob_TruncatedNormal(sample_q, q_loc, q_scale, lower, upper) 70 | # set NaNs to zero 71 | logp[torch.isnan(logp)] = 0.0 72 | # print((sample_q.min().data, sample_q.max().data, logp.sum().data, logq.sum().data)) 73 | # average across replicates, sum over data points 74 | KL_avg = torch.mean(logq - logp, dim=0) 75 | KL = torch.sum(KL_avg) 76 | return KL 77 | -------------------------------------------------------------------------------- /cGPLVM/kernels.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | def RBF(x, y, lengthscale, variance, jitter=None): 5 | N = x.size()[0] 6 | x = x / lengthscale 7 | y = y / lengthscale 8 | s_x = torch.sum(torch.pow(x, 2), dim=1).reshape([-1, 1]) 9 | s_y = torch.sum(torch.pow(y, 2), dim=1).reshape([1, -1]) 10 | K = variance * torch.exp(- 0.5 * (s_x + s_y - 2 * torch.mm(x, y.t()))) 11 | if jitter is not None: 12 | K += jitter * torch.eye(N) 13 | return K 14 | 15 | 16 | def meanzeroRBF(x, y, lengthscale, variance, a, b, jitter=None): 17 | N = x.size()[0] 18 | if x.size()[1] != 1: 19 | raise ValueError("Not implemented for input dim > 1") 20 | 21 | sqrt2 = math.sqrt(2) 22 | sqrtpi = math.sqrt(math.pi) 23 | sqrt2lengthscale = sqrt2 * lengthscale 24 | K = RBF(x, y, lengthscale, variance) 25 | 26 | const = 0.5 * sqrtpi * sqrt2lengthscale * variance 27 | K12 = (torch.erf((b - x) / sqrt2lengthscale) - torch.erf((a - x) / sqrt2lengthscale)) 28 | K21T = (torch.erf((b - y) / sqrt2lengthscale) - torch.erf((a - y) / sqrt2lengthscale)) 29 | 30 | temp = (b - a) / (sqrt2lengthscale) 31 | K22 = 2 * ((a - b) * torch.erf((a - b) / sqrt2lengthscale * torch.ones(1)) + sqrt2 / sqrtpi * lengthscale * ( 32 | torch.exp(-temp ** 2) - 1)) 33 | out = K - const * torch.mm(K12, K21T.t()) / K22 34 | if jitter is not None: 35 | out += jitter * torch.eye(N) 36 | return out 37 | 38 | # calculates the diagonal of mean-zero kernel matrix 39 | def diag_meanzeroRBF(x, lengthscale, variance, a, b, jitter=None): 40 | N = x.size()[0] 41 | x = x.reshape(-1) 42 | 43 | sqrt2 = math.sqrt(2) 44 | sqrtpi = math.sqrt(math.pi) 45 | sqrt2lengthscale = sqrt2 * lengthscale 46 | Kdiag = variance * torch.ones(N) 47 | 48 | const = 0.5 * sqrtpi * sqrt2lengthscale * variance 49 | K12 = (torch.erf((b - x) / sqrt2lengthscale) - torch.erf((a - x) / sqrt2lengthscale)) 50 | 51 | temp = (b - a) / (sqrt2lengthscale) 52 | K22 = 2 * ((a - b) * torch.erf((a - b) / sqrt2lengthscale * torch.ones(1)) + sqrt2 / sqrtpi * lengthscale * ( 53 | torch.exp(-temp ** 2) - 1)) 54 | out = Kdiag - const * K12 * K12 / K22 55 | 
return out 56 | 57 | 58 | def addint_2D_kernel_decomposition(z, z2, x, x2, ls, var, a_z=-2, b_z=2, a_x=-2, b_x=2, mean_zero=True, jitter=None): 59 | if mean_zero: 60 | 61 | K_zz = meanzeroRBF(z, z2, ls[0], var[0], a_z, b_z, jitter) 62 | 63 | K_xx = meanzeroRBF(x, x2, ls[1], var[1], a_x, b_x, jitter) 64 | 65 | K_intz = meanzeroRBF(z, z2, ls[2], var[2], a_z, b_z, jitter) 66 | K_intx = meanzeroRBF(x, x2, ls[3], 1.0, a_x, b_x, jitter) 67 | 68 | K_int = K_intz * K_intx 69 | else: 70 | K_zz = RBF(z, z2, ls[0], var[0], jitter) 71 | 72 | K_xx = RBF(x, x2, ls[1], var[1], jitter) 73 | 74 | K_intz = RBF(z, z2, ls[2], var[2], jitter) 75 | K_intx = RBF(x, x2, ls[3], 1.0, jitter) 76 | K_int = K_intz * K_intx 77 | 78 | return K_zz, K_xx, K_int 79 | 80 | 81 | def addint_kernel_diag(z, x, ls, var, a_z=-2, b_z=2, a_x=-2, b_x=2, mean_zero=True, jitter=None): 82 | P = x.shape[1] 83 | if mean_zero: 84 | # f(z) 85 | K_zz = diag_meanzeroRBF(z, ls[0], var[0], a_z, b_z, jitter) 86 | # f(x) 87 | K_xx = sum([diag_meanzeroRBF(x[:, j:(j + 1)], ls[1 + j], var[1 + j], a_x, b_x, jitter) for j in range(P)]) 88 | 89 | # f(x, z) 90 | def productkernel(j, z, x, ls, var, jitter): 91 | K_intz = diag_meanzeroRBF(z, ls[1 + j + P], var[1 + j + P], a_z, b_z, jitter) 92 | K_intx = diag_meanzeroRBF(x[:, j:(j + 1)], ls[1 + j + 2 * P], 1.0, a_x, b_x, jitter) 93 | return K_intz * K_intx 94 | 95 | K_int = sum([productkernel(j, z, x, ls, var, jitter) for j in range(P)]) 96 | else: 97 | N = x.shape[0] 98 | K_zz = var[0] * torch.ones(N) 99 | K_xx = var[1] * torch.ones(N) 100 | K_int = var[2] * torch.ones(N) 101 | 102 | return K_zz, K_xx, K_int 103 | -------------------------------------------------------------------------------- /cGPLVM/cGPLVM_survival.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from .helpers import KL_standard_normal, my_softplus 6 | from .helpers_survival import calculate_KLqp_TruncatedWeibullNormal, rsample_TruncatedNormal 7 | 8 | class survival_cGPLVM(nn.Module): 9 | 10 | def __init__(self, x, Y, z_init, is_censored, lower, upper, GP_mapping, fixed_z=False, lr=1e-2, **kwargs): 11 | """ 12 | Initialise survival c-GPLVM (i.e. assuming some covariates c are censored) 13 | :param x: [N, 1] matrix of survival times (for censored data points, the corresponding entries are ignored) 14 | :param Y: Observed [N, P] data matrix 15 | :param z_init: [N, 1] matrix of initialisations for z 16 | :param is_censored: boolean [N, 1] indicating whether survival was observed or censored 17 | :param lower: lower bounds for survival (entries for non-censored data points are ignored), shape [N, 1] 18 | :param upper: upper bounds for survival (entries for non-censored data points are ignored), shape [N, 1] 19 | :param GP_mapping: which GP_mapping to use, e.g. 
GP_2D_AddInt 20 | :param fixed_z: whether z should be fixed or not 21 | :param lr: learning rate for Adam 22 | """ 23 | super(survival_cGPLVM, self).__init__() 24 | 25 | N = Y.size()[0] 26 | self.Y = Y 27 | self.x_obs = x 28 | self.output_dim = Y.size()[1] 29 | 30 | if fixed_z: 31 | self.z_mu = z_init.clone() 32 | self.z_logsigma = -10.0 * torch.ones_like(z_init) 33 | else: 34 | self.z_mu = nn.Parameter(z_init.clone(), requires_grad=True) 35 | self.z_logsigma = nn.Parameter(-1.0 * torch.ones_like(z_init), requires_grad=True) 36 | 37 | Y_colmeans = Y.mean(axis=0) 38 | 39 | # for every output dimension, create a separate GP object 40 | self.GP_mappings = nn.ModuleList([GP_mapping(intercept_init=Y_colmeans[j], **kwargs) for j in range(self.output_dim)]) 41 | 42 | # censoring related quantities 43 | self.is_censored = is_censored 44 | self.prior_shape = 2.0 * torch.ones(N, 1) 45 | self.prior_scale = 2.0 * torch.ones(N, 1) 46 | self.lower = lower 47 | self.upper = upper 48 | self.q_shape = nn.Parameter(-1.0 + torch.zeros(N, 1)) 49 | self.q_logscale = nn.Parameter(-1.0 + torch.zeros(N, 1)) 50 | 51 | self.optimizer = torch.optim.Adam(self.parameters(), lr=lr) 52 | 53 | def get_kernel_vars(self): 54 | return np.array([mapping.get_kernel_var().detach().numpy() for mapping in self.GP_mappings]) 55 | 56 | def get_noise_sd(self): 57 | return np.array([mapping.get_noise_var().sqrt().detach().numpy() for mapping in self.GP_mappings]) 58 | 59 | def get_lengthscales(self): 60 | return np.array([mapping.get_ls().detach().numpy() for mapping in self.GP_mappings]) 61 | 62 | def sample_z(self): 63 | eps = torch.randn_like(self.z_mu) 64 | z = self.z_mu + my_softplus(self.z_logsigma) * eps 65 | return z 66 | 67 | def KL_z(self): 68 | return KL_standard_normal(self.z_mu, my_softplus(self.z_logsigma)) 69 | 70 | def loglik(self): 71 | 72 | # sample z 73 | z = self.sample_z() 74 | 75 | # for censored observations, sample x from q(surv) 76 | x_sample = self.sample_x() 77 | 78 | # for data points where "is_censored == True", set x <- x_sample, otherwise set x <- x_obs 79 | x = torch.zeros_like(self.x_obs) 80 | x[self.is_censored, :] = x_sample[self.is_censored, :] 81 | x[~self.is_censored, :] = self.x_obs[~self.is_censored, :] 82 | 83 | loss = 0.0 84 | for j in range(self.output_dim): 85 | loss += self.GP_mappings[j].total_loss(z, x, self.Y[:, j:(j + 1)]) 86 | 87 | return loss 88 | 89 | def optimizer_step(self): 90 | 91 | loss = self.loglik() + self.KL_z() + self.KLqp_survival_helper() 92 | 93 | self.optimizer.zero_grad() 94 | loss.backward() 95 | self.optimizer.step() 96 | 97 | return loss.item() 98 | 99 | def predict(self, z_star, x_star, add_likelihood_variance=False, to_numpy=False): 100 | N_star = z_star.size()[0] 101 | f_mean = torch.zeros(N_star, self.output_dim) 102 | f_var = torch.zeros(N_star, self.output_dim) 103 | for j in range(self.output_dim): 104 | x_sample = self.sample_x() 105 | f_mean[:, j], f_var[:, j] = self.GP_mappings[j].predict(self.z_mu, x_sample, self.Y[:, j:(j + 1)], z_star, 106 | x_star, add_likelihood_variance) 107 | 108 | f_sd = torch.sqrt(1e-6 + f_var) 109 | 110 | if to_numpy: 111 | f_mean, f_sd = f_mean.detach().numpy(), f_sd.detach().numpy() 112 | 113 | return f_mean, f_sd 114 | 115 | def predict_decomposition(self, z_star, x_star, which_kernels=None, to_numpy=False): 116 | N_star = z_star.size()[0] 117 | f_mean = torch.zeros(N_star, self.output_dim) 118 | f_var = torch.zeros(N_star, self.output_dim) 119 | for j in range(self.output_dim): 120 | x_sample = self.sample_x() 121 | 
f_mean[:, j], f_var[:, j] = self.GP_mappings[j].predict_decomposition(self.z_mu, x_sample, 122 | self.Y[:, j:(j + 1)], z_star, x_star, 123 | which_kernels) 124 | f_sd = torch.sqrt(1e-6 + f_var) 125 | 126 | if to_numpy: 127 | f_mean, f_sd = f_mean.detach().numpy(), f_sd.detach().numpy() 128 | 129 | return f_mean, f_sd 130 | 131 | def get_q_shape(self): 132 | return self.lower + (self.upper - self.lower) * torch.sigmoid(self.q_shape) 133 | 134 | def get_q_scale(self): 135 | upper_bound = 2.0 136 | return 1e-2 + upper_bound * torch.sigmoid(self.q_logscale) 137 | 138 | def KLqp_survival_helper(self): 139 | return calculate_KLqp_TruncatedWeibullNormal(self.prior_shape, self.prior_scale, self.get_q_shape(), 140 | self.get_q_scale(), self.lower, self.upper, n_samples=20) 141 | 142 | # sample x from q(surv) (makes sense for censored observations only) 143 | def sample_x(self): 144 | return rsample_TruncatedNormal(self.get_q_shape(), self.get_q_scale(), self.lower, self.upper)[0, :] 145 | 146 | def train(self, n_iter, verbose=200): 147 | 148 | for t in range(n_iter): 149 | 150 | loss = self.optimizer_step() 151 | 152 | if t % verbose == 0: 153 | print("Iter {0}. Loss {1}".format(t, loss)) 154 | 155 | -------------------------------------------------------------------------------- /cGPLVM/cGPLVM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from .helpers import KL_standard_normal, my_softplus 6 | 7 | class GPLVM(nn.Module): 8 | 9 | def __init__(self, Y, z_init, GP_mapping, lr=1e-2, fixed_z=False, **kwargs): 10 | super(GPLVM, self).__init__() 11 | 12 | self.Y = Y 13 | self.output_dim = Y.size()[1] 14 | 15 | if fixed_z: 16 | self.z_mu = z_init.clone() 17 | self.z_logsigma = -10.0 * torch.ones_like(z_init) 18 | else: 19 | self.z_mu = nn.Parameter(z_init.clone(), requires_grad=True) 20 | self.z_logsigma = nn.Parameter(-1.0 * torch.ones_like(z_init), requires_grad=True) 21 | 22 | Y_colmeans = Y.mean(axis=0) 23 | 24 | # for every output dimension, create a separate GP object 25 | self.GP_mappings = nn.ModuleList([GP_mapping(intercept_init=Y_colmeans[j], **kwargs) for j in range(self.output_dim)]) 26 | 27 | self.optimizer = torch.optim.Adam(self.parameters(), lr=lr) 28 | 29 | def get_kernel_vars(self): 30 | return np.array([mapping.get_kernel_var().detach().numpy() for mapping in self.GP_mappings]) 31 | 32 | def get_lengthscales(self): 33 | return np.array([mapping.get_ls().detach().numpy() for mapping in self.GP_mappings]) 34 | 35 | def sample_z(self): 36 | eps = torch.randn_like(self.z_mu) 37 | z = self.z_mu + my_softplus(self.z_logsigma) * eps 38 | return z 39 | 40 | def get_z_inferred(self): 41 | return self.z_mu.detach().numpy() 42 | 43 | def optimizer_step(self): 44 | 45 | loss = 0.0 46 | # sample z 47 | z = self.sample_z() 48 | for j in range(self.output_dim): 49 | loss += -self.GP_mappings[j].log_prob(z, self.Y[:, j:(j + 1)]) 50 | 51 | # KL 52 | loss += KL_standard_normal(self.z_mu, my_softplus(self.z_logsigma)) 53 | 54 | self.optimizer.zero_grad() 55 | loss.backward() 56 | self.optimizer.step() 57 | 58 | return loss.item() 59 | 60 | def predict(self, z_star, which_kernels=None, add_likelihood_variance=False, to_numpy=False): 61 | N_star = z_star.size()[0] 62 | f_mean = torch.zeros(N_star, self.output_dim) 63 | f_var = torch.zeros(N_star, self.output_dim) 64 | for j in range(self.output_dim): 65 | f_mean[:, j], f_var[:, j] = self.GP_mappings[j].predict(self.z_mu, self.Y[:, j:(j + 1)], 
z_star, 66 | which_kernels, add_likelihood_variance) 67 | 68 | f_sd = torch.sqrt(1e-4 + f_var) 69 | 70 | if to_numpy: 71 | f_mean, f_sd = f_mean.detach().numpy(), f_sd.detach().numpy() 72 | 73 | return f_mean, f_sd 74 | 75 | def train(self, n_iter, verbose=200): 76 | 77 | for t in range(n_iter): 78 | 79 | loss = self.optimizer_step() 80 | 81 | if t % verbose == 0: 82 | print(loss) 83 | 84 | 85 | class cGPLVM(nn.Module): 86 | 87 | def __init__(self, x, Y, z_init, GP_mapping, lr=1e-2, fixed_z=False, **kwargs): 88 | super(cGPLVM, self).__init__() 89 | 90 | self.Y = Y 91 | self.x = x 92 | self.output_dim = Y.size()[1] 93 | 94 | if fixed_z: 95 | self.z_mu = z_init.clone() 96 | self.z_logsigma = -10.0 * torch.ones_like(z_init) 97 | else: 98 | self.z_mu = nn.Parameter(z_init.clone(), requires_grad=True) 99 | self.z_logsigma = nn.Parameter(-1.0 * torch.ones_like(z_init), requires_grad=True) 100 | 101 | Y_colmeans = Y.mean(axis=0) 102 | 103 | # for every output dimension, create a separate GP object 104 | self.GP_mappings = nn.ModuleList([GP_mapping(intercept_init=Y_colmeans[j], **kwargs) for j in range(self.output_dim)]) 105 | 106 | self.optimizer = torch.optim.Adam(self.parameters(), lr=lr) 107 | 108 | def get_kernel_vars(self): 109 | return np.array([mapping.get_kernel_var().detach().numpy() for mapping in self.GP_mappings]) 110 | 111 | def get_noise_sd(self): 112 | return np.array([mapping.get_noise_var().sqrt().detach().numpy() for mapping in self.GP_mappings]) 113 | 114 | def get_lengthscales(self): 115 | return np.array([mapping.get_ls().detach().numpy() for mapping in self.GP_mappings]) 116 | 117 | def get_z_inferred(self): 118 | return self.z_mu.detach().numpy() 119 | 120 | def sample_z(self): 121 | eps = torch.randn_like(self.z_mu) 122 | z = self.z_mu + my_softplus(self.z_logsigma) * eps 123 | return z 124 | 125 | def optimizer_step(self): 126 | 127 | loss = 0.0 128 | # sample z 129 | z = self.sample_z() 130 | for j in range(self.output_dim): 131 | loss += self.GP_mappings[j].total_loss(z, self.x, self.Y[:, j:(j + 1)]) 132 | 133 | # KL 134 | loss += KL_standard_normal(self.z_mu, my_softplus(self.z_logsigma)) 135 | 136 | self.optimizer.zero_grad() 137 | loss.backward() 138 | self.optimizer.step() 139 | 140 | return loss.item() 141 | 142 | def predict(self, z_star, x_star, add_likelihood_variance=False, to_numpy=False): 143 | N_star = z_star.size()[0] 144 | f_mean = torch.zeros(N_star, self.output_dim) 145 | f_var = torch.zeros(N_star, self.output_dim) 146 | for j in range(self.output_dim): 147 | f_mean[:, j], f_var[:, j] = self.GP_mappings[j].predict(self.z_mu, self.x, self.Y[:, j:(j + 1)], z_star, 148 | x_star, add_likelihood_variance) 149 | 150 | f_sd = torch.sqrt(1e-6 + f_var) 151 | 152 | if to_numpy: 153 | f_mean, f_sd = f_mean.detach().numpy(), f_sd.detach().numpy() 154 | 155 | return f_mean, f_sd 156 | 157 | def predict_decomposition(self, z_star, x_star, which_kernels=None, to_numpy=False): 158 | N_star = z_star.size()[0] 159 | f_mean = torch.zeros(N_star, self.output_dim) 160 | f_var = torch.zeros(N_star, self.output_dim) 161 | for j in range(self.output_dim): 162 | f_mean[:, j], f_var[:, j] = self.GP_mappings[j].predict_decomposition(self.z_mu, self.x, 163 | self.Y[:, j:(j + 1)], z_star, x_star, 164 | which_kernels) 165 | f_sd = torch.sqrt(1e-6 + f_var) 166 | 167 | if to_numpy: 168 | f_mean, f_sd = f_mean.detach().numpy(), f_sd.detach().numpy() 169 | 170 | return f_mean, f_sd 171 | 172 | def train(self, n_iter, verbose=200): 173 | 174 | for t in range(n_iter): 175 | 176 | loss = 
self.optimizer_step() 177 | 178 | if t % verbose == 0: 179 | print("Iter {0}. Loss {1}".format(t, loss)) 180 | -------------------------------------------------------------------------------- /cGPLVM/GP_mappings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from torch.distributions.gamma import Gamma 6 | from torch.distributions.normal import Normal 7 | 8 | from .helpers import my_softplus, grid_helper 9 | from .kernels import RBF, meanzeroRBF, addint_2D_kernel_decomposition, addint_kernel_diag 10 | 11 | 12 | # class for GP-regression with inputs (z, x) and 1D output y 13 | class GP_2D_AddInt(nn.Module): 14 | 15 | def __init__(self, z_inducing, x_inducing, mean_zero=True, covariate_dim=1, intercept_init=None, kernel_var_init=None, ls_init=None, 16 | a_z=-2.0, b_z=2.0, a_x=-2.0, b_x=2.0): 17 | 18 | super().__init__() 19 | 20 | self.mean_zero = mean_zero 21 | self.a_z = a_z 22 | self.b_z = b_z 23 | self.a_x = a_x 24 | self.b_x = b_x 25 | self.jitter = 1e-4 26 | 27 | self.covariate_dim = covariate_dim 28 | if covariate_dim > 1: 29 | raise ValueError("This implementation only supports univariate covariate") 30 | 31 | # kernel hyperparameters 32 | n_kernel_vars = 1 + 2 * covariate_dim 33 | n_lengthscales = 1 + 3 * covariate_dim 34 | 35 | # lengthscales 36 | if ls_init is None: 37 | self.ls = nn.Parameter(3.0 * torch.ones(n_lengthscales), requires_grad=True) 38 | else: 39 | self.ls = nn.Parameter(ls_init.clone(), requires_grad=True) 40 | 41 | # kernel variances 42 | if kernel_var_init is None: 43 | var_init = torch.cat([torch.ones(1 + covariate_dim), -1.0 + torch.zeros(covariate_dim)]) 44 | self.var = nn.Parameter(var_init, requires_grad=True) 45 | else: 46 | self.var = nn.Parameter(kernel_var_init.clone(), requires_grad=True) 47 | 48 | self.noise_var = nn.Parameter(-1.0 * torch.ones(1), requires_grad=True) 49 | 50 | if intercept_init is None: 51 | self.intercept = nn.Parameter(torch.zeros(1), requires_grad=True) 52 | else: 53 | self.intercept = nn.Parameter(intercept_init, requires_grad=True) 54 | 55 | # inducing points 56 | self.z_u, self.x_u = grid_helper(z_inducing, x_inducing) 57 | self.M = self.z_u.size()[0] 58 | 59 | def get_kernel_var(self): 60 | return 1e-4 + my_softplus(self.var) 61 | 62 | def get_ls(self): 63 | return 1e-4 + my_softplus(self.ls) 64 | 65 | def get_noise_var(self): 66 | return 1e-4 + my_softplus(self.noise_var) 67 | 68 | def kernel_decomposition_single_covariate(self, z, z2, x, x2, jitter=None): 69 | K1, K2, K3 = addint_2D_kernel_decomposition(z, z2, x, x2, self.get_ls(), self.get_kernel_var(), 70 | a_z=self.a_z, b_z=self.b_z, a_x=self.a_x, b_x=self.b_x, 71 | mean_zero=self.mean_zero, jitter=jitter) 72 | return K1, K2, K3 73 | 74 | def kernel_decomposition(self, z, z2, x, x2, jitter=None): 75 | if self.covariate_dim == 1: 76 | return self.kernel_decomposition_single_covariate(z, z2, x, x2, jitter) 77 | else: 78 | return self.kernel_decomposition_multiple_covariates(z, z2, x, x2, jitter) 79 | 80 | def get_K_without_noise(self, z, z2, x, x2, jitter=None, which_kernels=None): 81 | K1, K2, K3 = self.kernel_decomposition(z, z2, x, x2, jitter) 82 | if which_kernels is None: 83 | return K1 + K2 + K3 84 | else: 85 | return which_kernels[0] * K1 + which_kernels[1] * K2 + which_kernels[2] * K3 86 | 87 | def get_K_diag(self, z, x): 88 | K1diag, K2diag, K3diag = addint_kernel_diag(z, x, self.get_ls(), self.get_kernel_var(), 89 | a_z=self.a_z, b_z=self.b_z, a_x=self.a_x, 
b_x=self.b_x, 90 | mean_zero=self.mean_zero) 91 | return K1diag + K2diag + K3diag 92 | 93 | def log_prob(self, z, x, y): 94 | subset = ~torch.isnan(y).reshape(-1) 95 | if subset.sum() > 0: 96 | y = y[subset, :] 97 | z = z[subset, :] 98 | x = x[subset, :] 99 | # inducing points log_prob 100 | N = y.size()[0] 101 | y = (y - self.intercept) 102 | sigma2 = self.get_noise_var() 103 | sigma = torch.sqrt(sigma2) 104 | K_uu = self.get_K_without_noise(self.z_u, self.z_u, self.x_u, self.x_u, jitter=self.jitter) 105 | K_uf = self.get_K_without_noise(self.z_u, z, self.x_u, x) 106 | L = torch.cholesky(K_uu) 107 | A = torch.triangular_solve(K_uf, L, upper=False)[0] 108 | AAT = torch.mm(A, A.t()) 109 | B = AAT + torch.eye(self.M) 110 | LB = torch.cholesky(B) 111 | Aerr = torch.mm(A, y) 112 | c = torch.triangular_solve(Aerr, LB, upper=False)[0] / sigma 113 | 114 | Kdiag = self.get_K_diag(z, x) 115 | 116 | bound = -0.5 * N * np.log(2 * np.pi) - torch.sum(torch.log(torch.diag(LB))) - 0.5 * N * torch.log( 117 | sigma2) - 0.5 * torch.sum(torch.pow(y, 2)) / sigma2 118 | bound += 0.5 * torch.sum(torch.pow(c, 2)) - 0.5 * torch.sum(Kdiag) / sigma2 + 0.5 * torch.sum(torch.diag(AAT)) 119 | return bound 120 | 121 | def predict(self, z, x, y, z_star, x_star, add_likelihood_variance=False): 122 | subset = ~torch.isnan(y).reshape(-1) 123 | if subset.sum() > 0: 124 | y = y[subset, :] 125 | z = z[subset, :] 126 | x = x[subset, :] 127 | Nstar = z_star.size()[0] 128 | y = (y - self.intercept) 129 | sigma2 = self.get_noise_var() 130 | 131 | sigma = torch.sqrt(sigma2) 132 | 133 | K_uu = self.get_K_without_noise(self.z_u, self.z_u, self.x_u, self.x_u, jitter=self.jitter) 134 | K_uf = self.get_K_without_noise(self.z_u, z, self.x_u, x) 135 | K_us = self.get_K_without_noise(self.z_u, z_star, self.x_u, x_star) 136 | 137 | L = torch.cholesky(K_uu) 138 | A = torch.triangular_solve(K_uf, L, upper=False)[0] / sigma 139 | AAT = torch.mm(A, A.t()) 140 | B = AAT + torch.eye(self.M) 141 | LB = torch.cholesky(B) 142 | Aerr = torch.mm(A, y) 143 | c = torch.triangular_solve(Aerr, LB, upper=False)[0] / sigma 144 | 145 | tmp1 = torch.triangular_solve(K_us, L, upper=False)[0] 146 | tmp2 = torch.triangular_solve(tmp1, LB, upper=False)[0] 147 | mean = self.intercept + torch.mm(tmp2.t(), c) 148 | 149 | Kdiag = self.get_K_diag(z_star, x_star) 150 | var = Kdiag + torch.pow(tmp2, 2).sum(dim=0) - torch.pow(tmp1, 2).sum(dim=0) 151 | 152 | if add_likelihood_variance: 153 | var += self.get_noise_var() 154 | 155 | return mean.reshape(-1), var 156 | 157 | def predict_decomposition(self, z, x, y, z_star, x_star, which_kernels): 158 | subset = ~torch.isnan(y).reshape(-1) 159 | if subset.sum() > 0: 160 | y = y[subset, :] 161 | z = z[subset, :] 162 | x = x[subset, :] 163 | Nstar = z_star.size()[0] 164 | y = (y - self.intercept) 165 | sigma2 = self.get_noise_var() 166 | 167 | K_all = self.get_K_without_noise(z, z, x, x, jitter=self.jitter) + sigma2 * torch.eye(y.size()[0]) 168 | L_all = torch.cholesky(K_all) 169 | K_all_inv = torch.cholesky_inverse(L_all) 170 | 171 | K_sf = self.get_K_without_noise(z_star, z, x_star, x, which_kernels=which_kernels) 172 | K_ss = self.get_K_without_noise(z_star, z_star, x_star, x_star, which_kernels=which_kernels) 173 | 174 | tmp = torch.mm(K_sf, K_all_inv) 175 | mean = self.intercept + torch.mm(tmp, y) 176 | var = torch.diag(K_ss - torch.mm(tmp, K_sf.t())) 177 | 178 | return mean.reshape(-1), var 179 | 180 | def prior_loss(self): 181 | p_ls = Gamma(50.0, 10.0).log_prob(self.get_ls()).sum() 182 | p_var = Gamma(1.0, 
1.0).log_prob(self.get_kernel_var()).sum() 183 | return -1.0 * (p_ls + p_var) 184 | 185 | def total_loss(self, z, x, y): 186 | return -self.log_prob(z, x, y) + self.prior_loss() 187 | 188 | 189 | # GP with 2D ARD kernel on (z, x) 190 | class GP_2D_INT(nn.Module): 191 | 192 | def __init__(self, z_inducing=None, x_inducing=None, intercept_init=None): 193 | 194 | super().__init__() 195 | 196 | self.jitter = 1e-3 197 | 198 | # kernel hyperparameters 199 | self.ls = nn.Parameter(2.0*torch.ones(2), requires_grad=True) 200 | self.var = nn.Parameter(torch.zeros(1), requires_grad=True) 201 | self.noise_var = nn.Parameter(-1.0 * torch.ones(1), requires_grad=True) 202 | 203 | if intercept_init is None: 204 | self.intercept = nn.Parameter(torch.zeros(1), requires_grad=True) 205 | else: 206 | self.intercept = nn.Parameter(intercept_init, requires_grad=True) 207 | 208 | # create a grid of inducing points (assuming one-dimensional z and x) 209 | self.z_u, self.x_u = grid_helper(z_inducing, x_inducing) 210 | self.M = self.z_u.size()[0] 211 | 212 | def get_kernel_var(self): 213 | return 1e-4 + my_softplus(self.var) 214 | 215 | def get_ls(self): 216 | return 1e-4 + my_softplus(self.ls) 217 | 218 | def get_noise_var(self): 219 | return 1e-4 + my_softplus(self.noise_var) 220 | 221 | def get_K_without_noise(self, z, z2, x, x2, jitter=None): 222 | zx = torch.cat([z, x], dim=1) 223 | zx2 = torch.cat([z2, x2], dim=1) 224 | K = RBF(zx, zx2, self.get_ls(), self.get_kernel_var(), jitter) 225 | return K 226 | 227 | def log_prob(self, z, x, y): 228 | # inducing points log_prob 229 | subset = ~torch.isnan(y).reshape(-1) 230 | if subset.sum() > 0: 231 | y = y[subset, :] 232 | z = z[subset, :] 233 | x = x[subset, :] 234 | N = y.size()[0] 235 | y = (y - self.intercept) 236 | sigma2 = self.get_noise_var() 237 | sigma = torch.sqrt(sigma2) 238 | K_uu = self.get_K_without_noise(self.z_u, self.z_u, self.x_u, self.x_u, jitter=self.jitter) 239 | K_uf = self.get_K_without_noise(self.z_u, z, self.x_u, x) 240 | L = torch.cholesky(K_uu) 241 | A = torch.triangular_solve(K_uf, L, upper=False)[0] / sigma 242 | AAT = torch.mm(A, A.t()) 243 | B = AAT + torch.eye(self.M) 244 | LB = torch.cholesky(B) 245 | Aerr = torch.mm(A, y) 246 | c = torch.triangular_solve(Aerr, LB, upper=False)[0] / sigma 247 | 248 | Kdiag = self.get_kernel_var().repeat(N) 249 | 250 | bound = -0.5 * N * np.log(2 * np.pi) - torch.sum(torch.log(torch.diag(LB))) - 0.5 * N * torch.log( 251 | sigma2) - 0.5 * torch.sum(torch.pow(y, 2)) / sigma2 252 | bound += 0.5 * torch.sum(torch.pow(c, 2)) - 0.5 * torch.sum(Kdiag) / sigma2 + 0.5 * torch.sum(torch.diag(AAT)) 253 | return bound 254 | 255 | def predict(self, z, x, y, z_star, x_star, add_likelihood_variance=False): 256 | subset = ~torch.isnan(y).reshape(-1) 257 | if subset.sum() > 0: 258 | y = y[subset, :] 259 | z = z[subset, :] 260 | x = x[subset, :] 261 | Nstar = z_star.size()[0] 262 | y = (y - self.intercept) 263 | sigma2 = self.get_noise_var() 264 | sigma = torch.sqrt(sigma2) 265 | 266 | K_uu = self.get_K_without_noise(self.z_u, self.z_u, self.x_u, self.x_u, jitter=self.jitter) 267 | K_uf = self.get_K_without_noise(self.z_u, z, self.x_u, x) 268 | K_us = self.get_K_without_noise(self.z_u, z_star, self.x_u, x_star) 269 | 270 | L = torch.cholesky(K_uu) 271 | A = torch.triangular_solve(K_uf, L, upper=False)[0] / sigma 272 | AAT = torch.mm(A, A.t()) 273 | B = AAT + torch.eye(self.M) 274 | LB = torch.cholesky(B) 275 | Aerr = torch.mm(A, y) 276 | c = torch.triangular_solve(Aerr, LB, upper=False)[0] / sigma 277 | 278 | tmp1 = 
torch.triangular_solve(K_us, L, upper=False)[0] 279 | tmp2 = torch.triangular_solve(tmp1, LB, upper=False)[0] 280 | mean = self.intercept + torch.mm(tmp2.t(), c) 281 | 282 | K_ss = self.get_K_without_noise(z_star, z_star, x_star, x_star) 283 | Kdiag = K_ss.diag() 284 | var = Kdiag + torch.pow(tmp2, 2).sum(dim=0) - torch.pow(tmp1, 2).sum(dim=0) 285 | 286 | if add_likelihood_variance: 287 | var += self.get_noise_var() 288 | 289 | return mean.reshape(-1), var 290 | 291 | def log_prob_fullrank(self, z, x, y): 292 | N = z.size()[0] 293 | y = (y - self.intercept).reshape(-1) 294 | K = self.get_K_without_noise(z, z, x, x, jitter=self.jitter, which_kernels=None) 295 | K_noise = self.get_noise_var() * torch.eye(N) 296 | return MultivariateNormal(torch.zeros_like(y), K + K_noise).log_prob(y) 297 | 298 | def prior_loss(self): 299 | prior_var = -Gamma(1.0, 1.0).log_prob(self.get_kernel_var()).sum() 300 | prior_ls = -Gamma(10.0, 1.0).log_prob(self.get_ls()).sum() 301 | return prior_var + prior_ls 302 | 303 | def total_loss(self, z, x, y): 304 | return -self.log_prob(z, x, y) + self.prior_loss() 305 | 306 | 307 | # additive GP 308 | class GP_2D_ADD(nn.Module): 309 | 310 | def __init__(self, z_inducing, x_inducing, intercept_init=None): 311 | 312 | super().__init__() 313 | 314 | self.jitter = 1e-4 315 | 316 | # kernel hyperparameters 317 | self.ls = nn.Parameter(torch.ones(2), requires_grad=True) 318 | self.var = nn.Parameter(torch.zeros(2), requires_grad=True) 319 | self.noise_var = nn.Parameter(-1.0 * torch.ones(1), requires_grad=True) 320 | 321 | if intercept_init is None: 322 | self.intercept = nn.Parameter(torch.zeros(1), requires_grad=True) 323 | else: 324 | self.intercept = nn.Parameter(intercept_init, requires_grad=True) 325 | 326 | # inducing points 327 | grid = torch.linspace(-3, 3, steps=10).reshape(-1, 1) 328 | self.z_u, self.x_u = grid_helper(z_inducing, x_inducing) 329 | self.M = self.z_u.size()[0] 330 | 331 | def get_kernel_var(self): 332 | return 1e-4 + my_softplus(self.var) 333 | 334 | def get_ls(self): 335 | return 1e-4 + my_softplus(self.ls) 336 | 337 | def get_noise_var(self): 338 | return 1e-4 + my_softplus(self.noise_var) 339 | 340 | def kernel_decomposition(self, z, z2, x, x2, jitter=None): 341 | ls = self.get_ls() 342 | var = self.get_kernel_var() 343 | K1 = RBF(z, z2, ls[0], var[0], jitter) 344 | K2 = RBF(x, x2, ls[1], var[1], jitter) 345 | return K1, K2 346 | 347 | def get_K_without_noise(self, z, z2, x, x2, jitter=None, which_kernels=None): 348 | K1, K2 = self.kernel_decomposition(z, z2, x, x2, jitter) 349 | if which_kernels is None: 350 | return K1 + K2 351 | else: 352 | return which_kernels[0] * K1 + which_kernels[1] * K2 353 | 354 | def log_prob(self, z, x, y): 355 | subset = ~torch.isnan(y).reshape(-1) 356 | if subset.sum() > 0: 357 | y = y[subset, :] 358 | z = z[subset, :] 359 | x = x[subset, :] 360 | # inducing points log_prob 361 | 362 | N = y.size()[0] 363 | y = (y - self.intercept) 364 | sigma2 = self.get_noise_var() 365 | sigma = torch.sqrt(sigma2) 366 | K_uu = self.get_K_without_noise(self.z_u, self.z_u, self.x_u, self.x_u, jitter=self.jitter) 367 | K_uf = self.get_K_without_noise(self.z_u, z, self.x_u, x) 368 | L = torch.cholesky(K_uu) 369 | A = torch.triangular_solve(K_uf, L, upper=False)[0] / sigma 370 | AAT = torch.mm(A, A.t()) 371 | B = AAT + torch.eye(self.M) 372 | LB = torch.cholesky(B) 373 | Aerr = torch.mm(A, y) 374 | c = torch.triangular_solve(Aerr, LB, upper=False)[0] / sigma 375 | 376 | K_ff = self.get_K_without_noise(z, z, x, x) 377 | Kdiag = K_ff.diag() 
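        # The expression below is the collapsed sparse-GP evidence lower bound (Titsias-style):
        # a Gaussian data-fit term computed through the inducing points, plus the trace
        # correction -0.5*sum(Kdiag)/sigma2 + 0.5*trace(A A^T)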
378 | 379 | bound = -0.5 * N * np.log(2 * np.pi) - torch.sum(torch.log(torch.diag(LB))) - 0.5 * N * torch.log( 380 | sigma2) - 0.5 * torch.sum(torch.pow(y, 2)) / sigma2 381 | bound += 0.5 * torch.sum(torch.pow(c, 2)) - 0.5 * torch.sum(Kdiag) / sigma2 + 0.5 * torch.sum(torch.diag(AAT)) 382 | return bound 383 | 384 | def predict(self, z, x, y, z_star, x_star, add_likelihood_variance=False): 385 | subset = ~torch.isnan(y).reshape(-1) 386 | if subset.sum() > 0: 387 | y = y[subset, :] 388 | z = z[subset, :] 389 | x = x[subset, :] 390 | Nstar = z_star.size()[0] 391 | y = (y - self.intercept) 392 | sigma2 = self.get_noise_var() 393 | sigma = torch.sqrt(sigma2) 394 | 395 | K_all_inv = torch.inverse( 396 | self.get_K_without_noise(z, z, x, x, jitter=self.jitter) + sigma2 * torch.eye(y.size()[0])) 397 | 398 | K_uu = self.get_K_without_noise(self.z_u, self.z_u, self.x_u, self.x_u, jitter=self.jitter) 399 | K_uf = self.get_K_without_noise(self.z_u, z, self.x_u, x) 400 | K_us = self.get_K_without_noise(self.z_u, z_star, self.x_u, x_star) 401 | 402 | K_sf = self.get_K_without_noise(z_star, z, x_star, x) 403 | K_ss = self.get_K_without_noise(z_star, z_star, x_star, x_star) 404 | tmp = torch.mm(K_sf, K_all_inv) 405 | mean = self.intercept + torch.mm(tmp, y) 406 | var = torch.diag(K_ss - torch.mm(tmp, K_sf.t())) 407 | 408 | if add_likelihood_variance: 409 | var += self.get_noise_var() 410 | 411 | return mean.reshape(-1), var 412 | 413 | def log_prob_fullrank(self, z, x, y): 414 | N = z.size()[0] 415 | y = (y - self.intercept).reshape(-1) 416 | K = self.get_K_without_noise(z, z, x, x, jitter=self.jitter, which_kernels=None) 417 | K_noise = self.get_noise_var() * torch.eye(N) 418 | return MultivariateNormal(torch.zeros_like(y), K + K_noise).log_prob(y) 419 | 420 | def prior_loss(self): 421 | return -Gamma(1.0, 1.0).log_prob(self.get_kernel_var()).sum() 422 | 423 | def total_loss(self, z, x, y): 424 | return -self.log_prob(z, x, y) + self.prior_loss() 425 | 426 | 427 | # GP with 1D inputs (useful for additive) 428 | class GP_1D(nn.Module): 429 | 430 | def __init__(self, z_inducing, intercept_init=None): 431 | 432 | super().__init__() 433 | 434 | self.jitter = 1e-4 435 | 436 | # kernel hyperparameters 437 | self.ls = nn.Parameter(3.0*torch.ones(1), requires_grad=True) 438 | self.var = nn.Parameter(torch.zeros(1), requires_grad=True) 439 | self.noise_var = nn.Parameter(-1.0 * torch.ones(1), requires_grad=True) 440 | 441 | if intercept_init is None: 442 | self.intercept = nn.Parameter(torch.zeros(1), requires_grad=True) 443 | else: 444 | self.intercept = nn.Parameter(intercept_init, requires_grad=True) 445 | 446 | # inducing points 447 | self.z_u = z_inducing 448 | self.M = self.z_u.size()[0] 449 | 450 | def get_kernel_var(self): 451 | return 1e-4 + my_softplus(self.var) 452 | 453 | def get_ls(self): 454 | return 1e-4 + my_softplus(self.ls) 455 | 456 | def get_noise_var(self): 457 | return 1e-4 + my_softplus(self.noise_var) 458 | 459 | def kernel_decomposition(self, z, z2, jitter=None): 460 | ls = self.get_ls() 461 | var = self.get_kernel_var() 462 | K1 = RBF(z, z2, ls[0], var[0], jitter) 463 | return K1 464 | 465 | def get_K_without_noise(self, z, z2, jitter=None, which_kernels=None): 466 | K1 = self.kernel_decomposition(z, z2, jitter) 467 | return K1 468 | 469 | def log_prob(self, z, y): 470 | # inducing points log_prob 471 | subset = ~torch.isnan(y).reshape(-1) 472 | if subset.sum() > 0: 473 | y = y[subset, :] 474 | z = z[subset, :] 475 | N = y.size()[0] 476 | y = (y - self.intercept) 477 | sigma2 = 
self.get_noise_var() 478 | sigma = torch.sqrt(sigma2) 479 | K_uu = self.get_K_without_noise(self.z_u, self.z_u, jitter=self.jitter) 480 | K_uf = self.get_K_without_noise(self.z_u, z) 481 | L = torch.cholesky(K_uu) 482 | A = torch.triangular_solve(K_uf, L, upper=False)[0] / sigma 483 | AAT = torch.mm(A, A.t()) 484 | B = AAT + torch.eye(self.M) 485 | LB = torch.cholesky(B) 486 | Aerr = torch.mm(A, y) 487 | c = torch.triangular_solve(Aerr, LB, upper=False)[0] / sigma 488 | 489 | K_ff = self.get_K_without_noise(z, z) 490 | Kdiag = K_ff.diag() 491 | 492 | bound = -0.5 * N * np.log(2 * np.pi) - torch.sum(torch.log(torch.diag(LB))) - 0.5 * N * torch.log( 493 | sigma2) - 0.5 * torch.sum(torch.pow(y, 2)) / sigma2 494 | bound += 0.5 * torch.sum(torch.pow(c, 2)) - 0.5 * torch.sum(Kdiag) / sigma2 + 0.5 * torch.sum(torch.diag(AAT)) 495 | return bound 496 | 497 | def predict(self, z, y, z_star, which_kernels, add_likelihood_variance=False): 498 | subset = ~torch.isnan(y).reshape(-1) 499 | if subset.sum() > 0: 500 | y = y[subset, :] 501 | z = z[subset, :] 502 | Nstar = z_star.size()[0] 503 | y = (y - self.intercept) 504 | sigma2 = self.get_noise_var() 505 | sigma = torch.sqrt(sigma2) 506 | K_uu = self.get_K_without_noise(self.z_u, self.z_u, jitter=self.jitter, which_kernels=which_kernels) 507 | K_uf = self.get_K_without_noise(self.z_u, z, which_kernels=which_kernels) 508 | K_us = self.get_K_without_noise(self.z_u, z_star, which_kernels=which_kernels) 509 | 510 | L = torch.cholesky(K_uu) 511 | A = torch.triangular_solve(K_uf, L, upper=False)[0] / sigma 512 | AAT = torch.mm(A, A.t()) 513 | B = AAT + torch.eye(self.M) 514 | LB = torch.cholesky(B) 515 | Aerr = torch.mm(A, y) 516 | c = torch.triangular_solve(Aerr, LB, upper=False)[0] / sigma 517 | 518 | tmp1 = torch.triangular_solve(K_us, L, upper=False)[0] 519 | tmp2 = torch.triangular_solve(tmp1, LB, upper=False)[0] 520 | mean = self.intercept + torch.mm(tmp2.t(), c) 521 | 522 | K_ss = self.get_K_without_noise(z_star, z_star) 523 | Kdiag = K_ss.diag() 524 | var = Kdiag + torch.pow(tmp2, 2).sum(dim=0) - torch.pow(tmp1, 2).sum(dim=0) 525 | 526 | if add_likelihood_variance: 527 | var += self.get_noise_var() 528 | 529 | return mean.reshape(-1), var 530 | 531 | def log_prob_fullrank(self, z, y): 532 | N = z.size()[0] 533 | y = (y - self.intercept).reshape(-1) 534 | K = self.get_K_without_noise(z, z, jitter=self.jitter, which_kernels=None) 535 | K_noise = self.get_noise_var() * torch.eye(N) 536 | return MultivariateNormal(torch.zeros_like(y), K + K_noise).log_prob(y) 537 | 538 | def prior_loss(self): 539 | return -Gamma(1.0, 1.0).log_prob(self.get_kernel_var()).sum() 540 | 541 | def total_loss(self, z, x, y): 542 | return -self.log_prob(z, x, y) + self.prior_loss() 543 | 544 | 545 | class linear_mapping(nn.Module): 546 | 547 | def __init__(self, intercept_init=None): 548 | super().__init__() 549 | 550 | self.jitter = 1e-3 551 | 552 | # parameters 553 | self.beta = nn.Parameter(torch.zeros(3), requires_grad=True) 554 | self.noise_var = nn.Parameter(-1.0 * torch.ones(1), requires_grad=True) 555 | 556 | if intercept_init is None: 557 | self.intercept = nn.Parameter(torch.zeros(1), requires_grad=True) 558 | else: 559 | self.intercept = nn.Parameter(intercept_init, requires_grad=True) 560 | 561 | def get_noise_var(self): 562 | return 1e-4 + my_softplus(self.noise_var) 563 | 564 | def get_kernel_var(self): 565 | return self.beta 566 | 567 | def log_prob(self, z, x, y): 568 | # inducing points log_prob 569 | N = y.size()[0] 570 | y_pred = self.intercept + self.beta[0] * 
z + self.beta[1] * x + self.beta[2] * z * x
571 |         log_prob = Normal(y_pred, torch.sqrt(self.get_noise_var())).log_prob(y).sum()
572 |         return log_prob
573 | 
574 |     def predict(self, z, x, y, z_star, x_star):
575 |         mean = self.intercept + self.beta[0] * z_star + self.beta[1] * x_star + self.beta[2] * z_star * x_star
576 |         return mean.reshape(-1), torch.zeros_like(mean).reshape(-1)
577 | 
578 |     def prior_loss(self):
579 |         return -Gamma(1.0, 1.0).log_prob(self.beta).sum()
580 | 
581 |     def total_loss(self, z, x, y):
582 |         return -self.log_prob(z, x, y)  # + self.prior_loss()
583 | 
--------------------------------------------------------------------------------
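A note on the mean-zero kernels: `meanzeroRBF` in `cGPLVM/kernels.py` can be read as an RBF kernel conditioned on the zero-integral constraint ∫_a^b f(s) ds = 0, i.e.

    k0(x, x') = k(x, x') - (∫_a^b k(x, s) ds)(∫_a^b k(s, x') ds) / (∫_a^b ∫_a^b k(s, t) ds dt),

with the erf terms giving the RBF integrals in closed form. This constraint is what lets the additive parts f(z), f(x) and the interaction f(z, x) be separated without the components absorbing each other's offsets. A minimal sketch for checking the property numerically (grid size and hyperparameter values are illustrative, not taken from the repository):

```
import torch
from cGPLVM.kernels import meanzeroRBF

# dense grid over the constraint interval [a, b]
a, b = -2.0, 2.0
x = torch.linspace(a, b, 200).reshape(-1, 1)

# mean-zero RBF kernel matrix (no jitter, so the constraint holds up to float error)
K = meanzeroRBF(x, x, lengthscale=torch.tensor(1.0), variance=torch.tensor(1.0), a=a, b=b)

# under the constraint, cov(f(x), integral of f over [a, b]) = 0 for every x,
# so each row of K should integrate to approximately zero
row_integrals = torch.trapz(K, x.reshape(-1), dim=1)
print(row_integrals.abs().max())  # small (discretisation error only)
```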