├── GP_helpers.R ├── NP_architecture1.R ├── NP_architecture2.R ├── NP_core.R ├── README.md ├── experiments ├── 0_draws_from_GP_prior.R ├── 0_draws_from_prior.R ├── 1_experiment.R ├── 2_experiment.R └── 3_experiment.R ├── fig ├── GP_draws.png ├── NP_banner.gif ├── draws_from_prior.png ├── draws_from_prior_relu.png ├── experiment1.gif ├── experiment1.png ├── experiment2.gif ├── experiment2_latent_space.png ├── experiment2_misspecification.png ├── experiment2_pred.png ├── experiment3.gif ├── experiment3_1.gif ├── experiment3_2.gif ├── experiment3_pred1.png ├── experiment3_pred2.png ├── experiment4.png ├── observed_data.png ├── schema1.png ├── schema2.png ├── schema3.png └── two_scenarios.png ├── helpers_for_plotting.R └── previous_implementation ├── _NP_core.R ├── _NP_helpers.R └── _helpers_for_plotting.R /GP_helpers.R: -------------------------------------------------------------------------------- 1 | # RBF (squared exponential) kernel: k(x, x') = var * exp(-0.5 * ||x - x'||^2 / ls^2) 2 | rbf_kernel <- function(X, Y, ls, var){ 3 | X <- X / ls 4 | Y <- Y / ls 5 | N1 <- nrow(X) 6 | N2 <- nrow(Y) 7 | s1 <- rowSums(X^2) 8 | s2 <- rowSums(Y^2) 9 | square <- matrix(s1, N1, N2) + matrix(s2, N1, N2, byrow=TRUE) - 2 * X %*% t(Y) 10 | var * exp(-0.5 * square) 11 | } 12 | 13 | # helper function for fitting a GP using gpflow 14 | 15 | fit_GP <- function(x, y, x_star, n_draws = 10L){ 16 | k <- gpflow$kernels$RBF(1L, lengthscales = 1.0) 17 | k$variance$prior <- gpflow$priors$Gamma(1.0, 1.0) 18 | k$lengthscales$prior <- gpflow$priors$Gamma(1.0, 1.0) 19 | m <- gpflow$gpr$GPR(cbind(x), cbind(y), k) 20 | m$likelihood$variance <- 0.05 21 | m$optimize() 22 | # print(k$lengthscales) 23 | pred_mat <- m$predict_f_samples(cbind(x_star), n_draws) 24 | pred_mat %>% 25 | reshape2::melt() %>% 26 | rename(index = Var2, draw = Var1, y = value) %>% 27 | mutate(x = x_star[index]) 28 | } 29 | 30 | fit_and_plot_GP <- function(x0, y0, x_star){ 31 | df_pred <- fit_GP(x0, y0, x_star, n_draws = 50L) 32 | df_obs <- data.frame(x = x0, y = y0) 33 | df_pred %>% 34 | ggplot(aes(x, y)) + 35 | geom_line(aes(group = draw), alpha = 0.25) + 36 | geom_point(data = df_obs, col = "#b2182b", size = 3) + 37 | theme_classic() 38 | } 39 | -------------------------------------------------------------------------------- /NP_architecture1.R: -------------------------------------------------------------------------------- 1 | # Architecture for the Neural Process 2 | 3 | # encoder h -- map inputs (x_i, y_i) to r_i 4 | h <- function(input){ 5 | input %>% 6 | tf$layers$dense(dim_h_hidden, tf$nn$sigmoid, name = "encoder_layer1", reuse = tf$AUTO_REUSE) %>% 7 | tf$layers$dense(dim_r, name = "encoder_layer2", reuse = tf$AUTO_REUSE) 8 | } 9 | 10 | # aggregate the output of h (i.e. 
values of r_i) to a single vector r 11 | aggregate_r <- function(input){ 12 | input %>% 13 | tf$reduce_mean(axis=0L) %>% 14 | tf$reshape(shape(1L, -1L)) 15 | } 16 | 17 | # map aggregated r to (mu_z, sigma_z) 18 | get_z_params <- function(input_r){ 19 | 20 | mu <- input_r %>% 21 | tf$layers$dense(dim_z, name = "z_params_mu", reuse = tf$AUTO_REUSE) 22 | 23 | sigma <- input_r %>% 24 | tf$layers$dense(dim_z, name = "z_params_sigma", reuse = tf$AUTO_REUSE) %>% 25 | tf$nn$softplus() 26 | 27 | list(mu = mu, sigma = sigma) 28 | } 29 | 30 | 31 | # decoder g -- map (z, x*) -> hidden -> y* 32 | g <- function(z_sample, x_star, noise_sd = 0.05){ 33 | # inputs dimensions 34 | # z_sample has dim [n_draws, dim_z] 35 | # x_star has dim [N_star, dim_x] 36 | 37 | n_draws <- z_sample$get_shape()$as_list()[1] 38 | N_star <- tf$shape(x_star)[1] 39 | 40 | # z_sample_rep will have dim [n_draws, N_star, dim_z] 41 | z_sample_rep <- z_sample %>% 42 | tf$expand_dims(axis = 1L) %>% 43 | tf$tile(c(1L, N_star, 1L)) 44 | 45 | # x_star_rep will have dim [n_draws, N_star, dim_x] 46 | x_star_rep <- x_star %>% 47 | tf$expand_dims(axis = 0L) %>% 48 | tf$tile(shape(n_draws, 1L, 1L)) 49 | 50 | # concatenate x* and z 51 | input <- list(x_star_rep, z_sample_rep) %>% 52 | tf$concat(axis = 2L) 53 | 54 | # hidden layer 55 | hidden <- input %>% 56 | tf$layers$dense(dim_g_hidden, tf$nn$sigmoid, name = "decoder_layer1", reuse = tf$AUTO_REUSE) 57 | 58 | # mu will be of the shape [N_star, n_draws] 59 | mu_star <- hidden %>% 60 | tf$layers$dense(1L, name = "decoder_layer2", reuse = tf$AUTO_REUSE) %>% 61 | tf$squeeze(axis = 2L) %>% 62 | tf$transpose() 63 | 64 | # for the toy example, assume y* ~ N(mu, sigma) with fixed sigma 65 | sigma_star <- tf$constant(noise_sd, dtype = tf$float32) 66 | 67 | list(mu = mu_star, sigma = sigma_star) 68 | } 69 | 70 | -------------------------------------------------------------------------------- /NP_architecture2.R: -------------------------------------------------------------------------------- 1 | # Architecture for the Neural Process 2 | 3 | # encoder h -- map inputs (x_i, y_i) to r_i 4 | h <- function(input){ 5 | input %>% 6 | tf$layers$dense(dim_h_hidden, tf$nn$relu, name = "encoder_layer1", reuse = tf$AUTO_REUSE) %>% 7 | tf$layers$dense(dim_r, name = "encoder_layer2", reuse = tf$AUTO_REUSE) 8 | } 9 | 10 | # aggregate the output of h (i.e. 
values of r_i) to a single vector r 11 | aggregate_r <- function(input){ 12 | input %>% 13 | tf$reduce_mean(axis=0L) %>% 14 | tf$reshape(shape(1L, -1L)) 15 | } 16 | 17 | # map aggregated r to (mu_z, sigma_z) 18 | get_z_params <- function(input_r){ 19 | 20 | hidden <- input_r 21 | # hidden <- input_r %>% 22 | # tf$layers$dense(dim_r, tf$nn$sigmoid, name = "z_params_layer1", reuse = tf$AUTO_REUSE) 23 | 24 | mu <- hidden %>% 25 | tf$layers$dense(dim_z, name = "z_params_mu", reuse = tf$AUTO_REUSE) 26 | 27 | sigma <- hidden %>% 28 | tf$layers$dense(dim_z, name = "z_params_sigma", reuse = tf$AUTO_REUSE) %>% 29 | tf$nn$softplus() 30 | 31 | list(mu = mu, sigma = sigma) 32 | } 33 | 34 | 35 | # decoder g -- map (z, x*) -> hidden -> y* 36 | g <- function(z_sample, x_star, noise_sd = 0.05){ 37 | # inputs dimensions 38 | # z_sample has dim [n_draws, dim_z] 39 | # x_star has dim [N_star, dim_x] 40 | 41 | n_draws <- z_sample$get_shape()$as_list()[1] 42 | N_star <- tf$shape(x_star)[1] 43 | 44 | # z_sample_rep will have dim [n_draws, N_star, dim_z] 45 | z_sample_rep <- z_sample %>% 46 | tf$expand_dims(axis = 1L) %>% 47 | tf$tile(c(1L, N_star, 1L)) 48 | 49 | # x_star_rep will have dim [n_draws, N_star, dim_x] 50 | x_star_rep <- x_star %>% 51 | tf$expand_dims(axis = 0L) %>% 52 | tf$tile(shape(n_draws, 1L, 1L)) 53 | 54 | # concatenate x* and z 55 | input <- list(x_star_rep, z_sample_rep) %>% 56 | tf$concat(axis = 2L) 57 | 58 | # hidden layer 59 | hidden <- input %>% 60 | tf$layers$dense(dim_g_hidden, tf$nn$relu, name = "decoder_layer1", reuse = tf$AUTO_REUSE) 61 | 62 | # mu will be of the shape [N_star, n_draws] 63 | mu_star <- hidden %>% 64 | tf$layers$dense(1L, name = "decoder_layer2", reuse = tf$AUTO_REUSE) %>% 65 | tf$squeeze(axis = 2L) %>% 66 | tf$transpose() 67 | 68 | # for the toy example, assume y* ~ N(mu, sigma) with fixed sigma 69 | sigma_star <- tf$constant(noise_sd, dtype = tf$float32) 70 | 71 | list(mu = mu_star, sigma = sigma_star) 72 | } 73 | 74 | -------------------------------------------------------------------------------- /NP_core.R: -------------------------------------------------------------------------------- 1 | 2 | # helper function to map (x, y) -> z directly without intermediate steps 3 | map_xy_to_z_params <- function(x, y){ 4 | list(x, y) %>% 5 | tf$concat(axis = 1L) %>% 6 | h() %>% 7 | aggregate_r() %>% 8 | get_z_params() 9 | } 10 | 11 | # set up the NN architecture with train_op and loss 12 | init_NP <- function(x_context, y_context, x_target, y_target, learning_rate = 0.001){ 13 | 14 | # concatenate context and target 15 | x_all <- tf$concat(list(x_context, x_target), axis = 0L) 16 | y_all <- tf$concat(list(y_context, y_target), axis = 0L) 17 | 18 | # map input to z 19 | z_context <- map_xy_to_z_params(x_context, y_context) 20 | z_all <- map_xy_to_z_params(x_all, y_all) 21 | 22 | # sample z using reparametrisation, z = mu + sigma*eps 23 | epsilon <- tf$random_normal(shape(7L, dim_z)) 24 | z_sample <- epsilon %>% 25 | tf$multiply(z_all$sigma) %>% 26 | tf$add(z_all$mu) 27 | 28 | # map (z, x*) to y* 29 | y_pred_params <- g(z_sample, x_target) 30 | 31 | # ELBO 32 | loglik <- loglikelihood(y_target, y_pred_params) 33 | KL_loss <- KLqp_gaussian(z_all$mu, z_all$sigma, z_context$mu, z_context$sigma) 34 | loss <- tf$negative(loglik) + KL_loss 35 | 36 | # optimisation 37 | optimizer <- tf$train$AdamOptimizer(learning_rate) 38 | train_op <- optimizer$minimize(loss) 39 | 40 | # return train_op and loss 41 | list(train_op, loss) 42 | } 43 | 44 | prior_predict <- function(x_star_value, 
epsilon = NULL, n_draws = 1L){ 45 | N_star <- nrow(x_star_value) 46 | x_star <- tf$constant(x_star_value, dtype = tf$float32) 47 | 48 | # the source of randomness can be optionally passed as an argument 49 | if(is.null(epsilon)){ 50 | epsilon <- tf$random_normal(shape(n_draws, dim_z)) 51 | } 52 | # draw z ~ N(0, 1) 53 | z_sample <- epsilon 54 | 55 | # y ~ g(z, x*) 56 | y_star <- g(z_sample, x_star) 57 | 58 | y_star 59 | } 60 | 61 | 62 | posterior_predict <- function(x, y, x_star_value, epsilon = NULL, n_draws = 1L){ 63 | # inputs for prediction time 64 | x_obs <- tf$constant(x, dtype = tf$float32) 65 | y_obs <- tf$constant(y, dtype = tf$float32) 66 | x_star <- tf$constant(x_star_value, dtype = tf$float32) 67 | 68 | # map the observed context points to the parameters of q(z) 69 | z_params <- map_xy_to_z_params(x_obs, y_obs) 70 | 71 | # the source of randomness can be optionally passed as an argument 72 | if(is.null(epsilon)){ 73 | epsilon <- tf$random_normal(shape(n_draws, dim_z)) 74 | } 75 | # sample z using reparametrisation 76 | z_sample <- epsilon %>% 77 | tf$multiply(z_params$sigma) %>% 78 | tf$add(z_params$mu) 79 | 80 | # predictions 81 | y_star <- g(z_sample, x_star) 82 | 83 | y_star 84 | } 85 | 86 | # KL divergence between two diagonal Gaussians: KL(q || p) = 0.5 * sum( sigma_q^2/sigma_p^2 + (mu_q - mu_p)^2/sigma_p^2 - 1 + log(sigma_p^2/sigma_q^2) ) 87 | KLqp_gaussian <- function(mu_q, sigma_q, mu_p, sigma_p){ 88 | sigma2_q <- tf$square(sigma_q) + 1e-16 89 | sigma2_p <- tf$square(sigma_p) + 1e-16 90 | temp <- sigma2_q / sigma2_p + tf$square(mu_q - mu_p) / sigma2_p - 1.0 + tf$log(sigma2_p / sigma2_q + 1e-16) 91 | 0.5 * tf$reduce_sum(temp) 92 | } 93 | 94 | # for ELBO 95 | loglikelihood <- function(y_star, y_pred_params){ 96 | 97 | p_normal <- tf$distributions$Normal(loc = y_pred_params$mu, scale = y_pred_params$sigma) 98 | 99 | loglik <- y_star %>% 100 | p_normal$log_prob() %>% 101 | # sum over data points 102 | tf$reduce_sum(axis=0L) %>% 103 | # average over n_draws 104 | tf$reduce_mean() 105 | 106 | loglik 107 | } 108 | 109 | 110 | # for training: randomly split (x, y) into a context set and a target set 111 | helper_context_and_target <- function(x, y, N_context, x_context, y_context, x_target, y_target){ 112 | N <- length(y) 113 | context_set <- sample(1:N, N_context) 114 | dict( 115 | x_context = cbind(x[context_set]), 116 | y_context = cbind(y[context_set]), 117 | x_target = cbind(x[-context_set]), 118 | y_target = cbind(y[-context_set]) 119 | ) 120 | } 121 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Processes 2 | 3 | ![](fig/NP_banner.gif) 4 | 5 | This is an implementation of Neural Processes for 1D regression, accompanying [my blog post](https://kasparmartens.rbind.io/post/np/). 6 | 7 | ### Structure of the repo 8 | 9 | The implementation uses TensorFlow in R: 10 | 11 | * The files [NP_architecture*.R](https://github.com/kasparmartens/NeuralProcesses/blob/master/NP_architecture1.R) specify the NN architectures for the encoder *h* and decoder *g*, as well as the aggregator and the mapping from *r* to *z*. 12 | * The file [NP_core.R](https://github.com/kasparmartens/NeuralProcesses/blob/master/NP_core.R) contains functions to define the loss function and carry out posterior prediction. 13 | 14 | Note: when changing the network architecture, e.g. before fitting a new model, you need to run `tf$reset_default_graph()` or restart your R session.
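For example, a minimal reset might look like this (just a sketch, reusing the setup shown in the example code below):

```R
# after editing the architecture, e.g. switching to NP_architecture2.R:
tf$reset_default_graph()
sess <- tf$Session()
# ... then re-create the placeholders, re-run init_NP() and re-initialise the variables
```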
15 | 16 | All experiments can be found in the "experiments" folder (where they appear in the same order as in the blog post): 17 | 18 | * The [first experiment](https://github.com/kasparmartens/NeuralProcesses/blob/master/experiments/1_experiment.R) involves training an NP on a single small data set. 19 | * The [second experiment](https://github.com/kasparmartens/NeuralProcesses/blob/master/experiments/2_experiment.R) involves training an NP on a small class of functions of the form `a * sin(x)`. 20 | * The [third experiment](https://github.com/kasparmartens/NeuralProcesses/blob/master/experiments/3_experiment.R) involves training an NP on repeated draws from a GP. 21 | 22 | ### Example code 23 | 24 | Loading all the libraries and helper functions: 25 | 26 | ```R 27 | library(tidyverse) 28 | library(tensorflow) 29 | library(patchwork) 30 | 31 | source("NP_architecture1.R") 32 | source("NP_core.R") 33 | source("GP_helpers.R") 34 | source("helpers_for_plotting.R") 35 | ``` 36 | 37 | Setting up the NP model: 38 | 39 | ```R 40 | sess <- tf$Session() 41 | 42 | # specify global variables for the dimensionality of r and z, and for the hidden layers of g and h 43 | dim_r <- 2L 44 | dim_z <- 2L 45 | dim_h_hidden <- 32L 46 | dim_g_hidden <- 32L 47 | 48 | # placeholders for training inputs 49 | x_context <- tf$placeholder(tf$float32, shape(NULL, 1)) 50 | y_context <- tf$placeholder(tf$float32, shape(NULL, 1)) 51 | x_target <- tf$placeholder(tf$float32, shape(NULL, 1)) 52 | y_target <- tf$placeholder(tf$float32, shape(NULL, 1)) 53 | 54 | # set up NN 55 | train_op_and_loss <- init_NP(x_context, y_context, x_target, y_target, learning_rate = 0.001) 56 | 57 | # initialise 58 | init <- tf$global_variables_initializer() 59 | sess$run(init) 60 | ``` 61 | 62 | Now, sampling data according to the function y = a * sin(x), we can fit the model as follows: 63 | 64 | ```R 65 | n_iter <- 10000 66 | 67 | for(iter in 1:n_iter){ 68 | # sample data (x_obs, y_obs) 69 | N <- 20 70 | x_obs <- runif(N, -3, 3) 71 | a <- runif(1, -2, 2) 72 | y_obs <- a * sin(x_obs) 73 | 74 | # sample N_context for training 75 | N_context <- sample(1:10, 1) 76 | 77 | # use helper function to pick a random context set 78 | feed_dict <- helper_context_and_target(x_obs, y_obs, N_context, x_context, y_context, x_target, y_target) 79 | 80 | # optimisation step (res also holds the loss, so the amplitude a is not overwritten) 81 | res <- sess$run(train_op_and_loss, feed_dict = feed_dict) 82 | 83 | if(iter %% 1e3 == 0){ 84 | cat(sprintf("loss = %1.3f\n", res[[2]])) 85 | } 86 | } 87 | ``` 88 | 89 | Prediction using the trained model: 90 | 91 | ```R 92 | # context set at prediction time 93 | x0 <- c(0, 1) 94 | y0 <- 1*sin(x0) 95 | 96 | # prediction grid 97 | x_star <- seq(-4, 4, length=100) 98 | 99 | # plot posterior draws 100 | plot_posterior_draws(x0, y0, x_star, n_draws = 50) 101 | 102 | ``` 103 |
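### Training objective

For reference, a brief sketch of the loss implemented in `init_NP()` (my notation): writing $C$ for the context set and $T$ for the target set, training minimises the negative ELBO

$$\mathcal{L} = -\,\mathbb{E}_{q(z \mid C \cup T)}\big[\log p(y_T \mid z, x_T)\big] + \mathrm{KL}\big(q(z \mid C \cup T)\,\|\,q(z \mid C)\big),$$

where the expectation is approximated with Monte Carlo draws of $z$. In the code, the first term corresponds to `loglikelihood()` (a Gaussian likelihood with fixed noise sd) and the second term to `KLqp_gaussian()`.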
104 | ### Other resources 105 | 106 | *Update (February 2019)*: The authors of the Neural Process papers have now made their implementation available here: [https://github.com/deepmind/neural-processes](https://github.com/deepmind/neural-processes) -------------------------------------------------------------------------------- /experiments/0_draws_from_GP_prior.R: -------------------------------------------------------------------------------- 1 | library(tidyverse) 2 | source("GP_helpers.R") 3 | 4 | GP_prior_draws <- function(ls, n_draws, N = 200){ 5 | x <- seq(-3, 3, length=N) 6 | Y <- matrix(0, n_draws, N) 7 | K <- rbf_kernel(cbind(x), cbind(x), ls = ls, var = 1.0) + 1e-6*diag(N) # the kernel depends only on ls, so compute it once outside the loop 8 | for(i in 1:n_draws){ 9 | Y[i, ] <- as.numeric(mvtnorm::rmvnorm(1, sigma = K)) 10 | } 11 | Y %>% 12 | reshape2::melt() %>% 13 | rename(draw = Var1, index = Var2, f = value) %>% 14 | mutate(x = x[index], ls = ls) 15 | } 16 | 17 | df1 <- GP_prior_draws(ls = 1, n_draws = 20) 18 | df2 <- GP_prior_draws(ls = 2, n_draws = 20) 19 | df3 <- GP_prior_draws(ls = 3, n_draws = 20) 20 | 21 | bind_rows(df1, df2, df3) %>% 22 | mutate(label = sprintf("lengthscale = %d", ls)) %>% 23 | ggplot(aes(x, f, group=draw)) + 24 | geom_line(alpha = 0.2) + 25 | # scale_color_viridis_c() + 26 | theme_classic() + 27 | facet_wrap(~ label) + 28 | theme(legend.position = "none") + 29 | labs(title = "Draws from the GP prior", y = "Function value") 30 | 31 | ggsave("fig/GP_draws.png", width = 7.5, height = 2.5) -------------------------------------------------------------------------------- /experiments/0_draws_from_prior.R: -------------------------------------------------------------------------------- 1 | library(tidyverse) 2 | library(tensorflow) 3 | library(patchwork) 4 | 5 | source("NP_core.R") 6 | source("GP_helpers.R") 7 | source("helpers_for_plotting.R") 8 | source("NP_architecture1.R") 9 | 10 | 11 | dim_r_values <- c(1L, 2L, 4L, 8L) # integers, since these are used as layer sizes 12 | plot_list <- list() 13 | for(i in seq_along(dim_r_values)){ 14 | dim_r <- dim_r_values[i] 15 | dim_z <- dim_r 16 | dim_h_hidden <- 8L 17 | dim_g_hidden <- 8L 18 | 19 | sess <- tf$Session() 20 | 21 | # placeholders for training inputs 22 | x_context <- tf$placeholder(tf$float32, shape(NULL, 1)) 23 | y_context <- tf$placeholder(tf$float32, shape(NULL, 1)) 24 | x_target <- tf$placeholder(tf$float32, shape(NULL, 1)) 25 | y_target <- tf$placeholder(tf$float32, shape(NULL, 1)) 26 | 27 | # set up NN 28 | train_op_and_loss <- init_NP(x_context, y_context, x_target, y_target, learning_rate = 0.001) 29 | 30 | # initialise 31 | init <- tf$global_variables_initializer() 32 | sess$run(init) 33 | 34 | x_star <- seq(-4, 4, length=100) 35 | prior_predict_op <- prior_predict(cbind(x_star), n_draws = 50L) 36 | y_star_mat <- sess$run(prior_predict_op$mu) 37 | df_pred <- reshape_predictions(y_star_mat, x_star) 38 | 39 | plot_list[[i]] <- df_pred %>% 40 | ggplot(aes(x, y, group=rep_index)) + 41 | geom_line(alpha = 0.2) + 42 | theme_classic() + 43 | labs(title = sprintf("dim(z) = %d", dim_r), y = "Function value") 44 | sess$close() 45 | tf$reset_default_graph() 46 | } 47 | 48 | p <- wrap_plots(plot_list, nrow = 1) + plot_annotation(title = "Function draws from the NP prior") 49 | p 50 | ggsave("fig/draws_from_prior.png", p, width = 10, height = 3) -------------------------------------------------------------------------------- /experiments/1_experiment.R: -------------------------------------------------------------------------------- 1 | library(tidyverse) 2 | library(tensorflow) 3 | library(patchwork) 4 | 5 | source("NP_core.R") 6 | source("GP_helpers.R") 7 | source("helpers_for_plotting.R") 8 | source("NP_architecture1.R") 9 | 10 | # generate data 11 | N <- 5 12 | x <- c(-2, -1, 0, 1, 2) 13 | y <- sin(x) 14 | df_obs <- data.frame(x, y) 15 | 16 | # plot data 17 | p0 <- df_obs %>% 18 | ggplot(aes(x, y)) + 19 | geom_point(col = "#377EB8", size=3) + 20 | theme_classic() + 21 | coord_cartesian(xlim = c(-3, 3)) + 22 | labs(title = "Observed data") 23 | p <- plot_spacer() + p0 + plot_spacer() 24 | # ggsave("fig/observed_data.png", p, width=10, height=2.5) 25 | 26 | 27 | ### Now fitting the NP 28 | 29 | # global variables for training the model 30 | dim_r <- 2L 31 | dim_z <- 2L 32 | dim_h_hidden <- 8L 33 | dim_g_hidden <- 8L 34 | 35 | sess <- tf$Session()
36 | 37 | # placeholders for training inputs 38 | x_context <- tf$placeholder(tf$float32, shape(NULL, 1)) 39 | y_context <- tf$placeholder(tf$float32, shape(NULL, 1)) 40 | x_target <- tf$placeholder(tf$float32, shape(NULL, 1)) 41 | y_target <- tf$placeholder(tf$float32, shape(NULL, 1)) 42 | 43 | # set up NN 44 | train_op_and_loss <- init_NP(x_context, y_context, x_target, y_target, learning_rate = 0.001) 45 | 46 | # initialise 47 | init <- tf$global_variables_initializer() 48 | sess$run(init) 49 | 50 | n_iter <- 5000 51 | plot_freq <- 200 52 | 53 | # Plotting functionality (to record training) 54 | # (using fixed source of randomness over all iters) 55 | n_draws <- 50L 56 | x_star <- seq(-4, 4, length=100) 57 | eps_value <- matrix(rnorm(n_draws*dim_z), n_draws, dim_z) 58 | epsilon <- tf$constant(eps_value, dtype = tf$float32) 59 | predict_op <- posterior_predict(cbind(x), cbind(y), cbind(x_star), epsilon) 60 | 61 | 62 | 63 | df_pred_list <- list() 64 | for(iter in 1:n_iter){ 65 | N_context <- sample(1:4, 1) 66 | # create feed_dict containing context and target sets 67 | feed_dict <- helper_context_and_target(x, y, N_context, x_context, y_context, x_target, y_target) 68 | # optimisation step 69 | a <- sess$run(train_op_and_loss, feed_dict = feed_dict) 70 | 71 | # plotting 72 | if(iter %% plot_freq == 0){ 73 | y_star_mat <- sess$run(predict_op$mu) 74 | df_pred <- y_star_mat %>% 75 | reshape_predictions(x_star) %>% 76 | mutate(iter = iter) 77 | df_pred_list[[iter]] <- df_pred 78 | } 79 | } 80 | df_pred <- bind_rows(df_pred_list) 81 | 82 | 83 | # gif 84 | 85 | library(gganimate) 86 | 87 | df_obs_rep <- crossing(df_obs, 88 | iter = df_pred$iter, 89 | rep_index = df_pred$rep_index) 90 | 91 | p <- df_pred %>% 92 | ggplot(aes(x, y)) + 93 | geom_line(aes(group=rep_index), alpha = 0.2) + 94 | geom_point(data = df_obs_rep, col = "#377EB8") + 95 | transition_time(iter) + 96 | labs(title = "Training a Neural Process", subtitle = "Iteration: {frame_time}", y = "Function value") + 97 | coord_cartesian(ylim = c(-2, 2)) + 98 | theme_classic() 99 | 100 | animate(p, nframes = 50, width=400, height=250) 101 | 102 | anim_save("fig/experiment1.gif") 103 | 104 | 105 | # prediction for a different set of context points 106 | y2 <- 1 + sin(x) 107 | df_obs2 <- data.frame(x, y = y2) 108 | predict_op2 <- posterior_predict(cbind(x), cbind(y2), cbind(x_star), epsilon) 109 | 110 | y_star_mat <- sess$run(predict_op2$mu) 111 | df_pred2 <- y_star_mat %>% 112 | reshape_predictions(x_star) 113 | 114 | df_pred2 %>% 115 | ggplot(aes(x, y)) + 116 | geom_line(aes(group=rep_index), alpha = 0.2) + 117 | geom_point(aes(col = "Context points at training time"), data = df_obs, size=3) + 118 | geom_point(aes(col = "Context points at prediction time"), data = df_obs2, size=3) + 119 | scale_color_brewer("", palette = "Set1") + 120 | theme_classic() 121 | 122 | ggsave("fig/experiment1.png", width = 8, height = 3) -------------------------------------------------------------------------------- /experiments/2_experiment.R: -------------------------------------------------------------------------------- 1 | library(tidyverse) 2 | library(tensorflow) 3 | library(patchwork) 4 | 5 | source("NP_core.R") 6 | source("GP_helpers.R") 7 | source("helpers_for_plotting.R") 8 | source("NP_architecture1.R") 9 | 10 | # global variables for training the model 11 | dim_r <- 2L 12 | dim_z <- 2L 13 | dim_h_hidden <- 32L 14 | dim_g_hidden <- 32L 15 | 16 | sess <- tf$Session() 17 | 18 | # placeholders for training inputs 19 | 
x_context <- tf$placeholder(tf$float32, shape(NULL, 1)) 20 | y_context <- tf$placeholder(tf$float32, shape(NULL, 1)) 21 | x_target <- tf$placeholder(tf$float32, shape(NULL, 1)) 22 | y_target <- tf$placeholder(tf$float32, shape(NULL, 1)) 23 | 24 | # set up NN 25 | train_op_and_loss <- init_NP(x_context, y_context, x_target, y_target, learning_rate = 0.001) 26 | 27 | # initialise 28 | init <- tf$global_variables_initializer() 29 | sess$run(init) 30 | 31 | n_iter <- 50000 32 | 33 | for(iter in 1:n_iter){ 34 | N <- 20 35 | x_obs <- runif(N, -3, 3) 36 | a <- runif(1, -2, 2) 37 | y_obs <- a * sin(x_obs) 38 | 39 | # sample N_context for training 40 | N_context <- sample(1:10, 1) 41 | feed_dict <- helper_context_and_target(x_obs, y_obs, N_context, x_context, y_context, x_target, y_target) 42 | a <- sess$run(train_op_and_loss, feed_dict = feed_dict) 43 | if(iter %% 1e3 == 0){ 44 | cat(sprintf("loss = %1.3f\n", a[[2]])) 45 | } 46 | } 47 | 48 | 49 | # create a grid (z1, z2) and plot predictions 50 | x_star <- seq(-4, 4, length=100) 51 | z1 <- seq(-4, 4, length=9) 52 | z2 <- seq(-4, 4, length=9) 53 | eps_value <- as.matrix(expand.grid(z1, z2)) 54 | eps <- tf$constant(eps_value, dtype = tf$float32) 55 | prior_predict_op <- prior_predict(cbind(x_star), epsilon = eps) 56 | y_star_mat <- sess$run(prior_predict_op$mu) 57 | reshape_predictions(y_star_mat, x_star) %>% 58 | mutate(z1 = eps_value[, 1][rep_index], 59 | z2 = eps_value[, 2][rep_index]) %>% 60 | ggplot(aes(x, y, group=rep_index, col = z1+z2)) + 61 | geom_line() + 62 | facet_grid(z1 ~ z2) + 63 | theme_classic() + 64 | theme(legend.position = "none") + 65 | scale_y_continuous(breaks = c(-2, 0, 2)) 66 | ggsave("fig/experiment2_latent_space.png", width=8, height=5) 67 | 68 | 69 | # create a gif 70 | 71 | library(gganimate) 72 | 73 | x_star <- seq(-4, 4, length=100) 74 | z1 <- seq(-4, 4, length=41) 75 | z2 <- seq(-4, 4, length=41) 76 | eps_value <- as.matrix(expand.grid(z1, z2)) 77 | eps <- tf$constant(eps_value, dtype = tf$float32) 78 | prior_predict_op <- prior_predict(cbind(x_star), epsilon = eps) 79 | y_star_mat <- sess$run(prior_predict_op$mu) 80 | 81 | df_pred <- y_star_mat %>% 82 | reshape_predictions(x_star) %>% 83 | mutate(z1 = eps_value[, 1][rep_index], 84 | z2 = eps_value[, 2][rep_index]) 85 | 86 | p1 <- df_pred %>% 87 | ggplot(aes(x, y, group=z2, col = z2)) + 88 | geom_line() + 89 | transition_time(z1) + 90 | labs(title = "NP for a*sin(x)", subtitle = 'z1 = {frame_time}') + 91 | scale_color_viridis_c() + 92 | theme_classic() 93 | animate(p1, nframes = 50, width=300, height=250) 94 | # anim_save("fig/sin_z1.gif") 95 | 96 | p2 <- df_pred %>% 97 | ggplot(aes(x, y, group=z1, col = z1)) + 98 | geom_line() + 99 | transition_time(z2) + 100 | labs(title = " ", subtitle = 'z2 = {frame_time}') + 101 | scale_color_viridis_c() + 102 | theme_classic() 103 | animate(p2, nframes = 50, width=300, height=250) 104 | # anim_save("fig/sin_z2.gif") 105 | 106 | 107 | # Static plots for predictions 108 | x_star <- seq(-4, 4, length=100) 109 | 110 | # start with point (0, 0) and then expand context set 111 | x0 <- c(0) 112 | y0 <- 1*sin(x0) 113 | p1 <- plot_posterior_draws(x0, y0, x_star, n_draws = 50L) 114 | 115 | x0 <- c(0, 1) 116 | y0 <- 1*sin(x0) 117 | p2 <- plot_posterior_draws(x0, y0, x_star, n_draws = 50L) 118 | 119 | x0 <- c(0, 1, -1, 2, -2) 120 | y0 <- 1*sin(x0) 121 | p3 <- plot_posterior_draws(x0, y0, x_star, n_draws = 50L) 122 | 123 | p <- (p1 + p2 + p3) * coord_cartesian(ylim = c(-1, 1)) + plot_layout(nrow = 1) 124 | p 125 | 
ggsave("fig/experiment2_pred.png", p, width = 9, height = 3) 126 | 127 | 128 | # generalisation / model misspecification 129 | x0 <- seq(-2, 2, length=10) 130 | y0 <- 2.5*sin(x0) 131 | p1 <- plot_posterior_draws(x0, y0, x_star, n_draws = 50L) + 132 | labs(subtitle = "2.5 sin(x)") 133 | 134 | y0 <- abs(sin(x0)) 135 | p2 <- plot_posterior_draws(x0, y0, x_star, n_draws = 50L) + 136 | coord_cartesian(ylim = c(-1, 1)) + 137 | labs(subtitle = "abs(sin(x))") 138 | 139 | 140 | p <- (p1 + p2) + plot_layout(nrow = 1) 141 | ggsave("fig/experiment2_misspecification.png", p, width = 7, height = 2.5) 142 | -------------------------------------------------------------------------------- /experiments/3_experiment.R: -------------------------------------------------------------------------------- 1 | library(tidyverse) 2 | library(tensorflow) 3 | library(patchwork) 4 | 5 | source("NP_core.R") 6 | source("GP_helpers.R") 7 | source("helpers_for_plotting.R") 8 | source("NP_architecture1.R") 9 | 10 | # global variables for training the model 11 | dim_r <- 2L 12 | dim_z <- 2L 13 | dim_h_hidden <- 32L 14 | dim_g_hidden <- 32L 15 | 16 | sess <- tf$Session() 17 | 18 | # placeholders for training inputs 19 | x_context <- tf$placeholder(tf$float32, shape(NULL, 1)) 20 | y_context <- tf$placeholder(tf$float32, shape(NULL, 1)) 21 | x_target <- tf$placeholder(tf$float32, shape(NULL, 1)) 22 | y_target <- tf$placeholder(tf$float32, shape(NULL, 1)) 23 | 24 | # set up NN 25 | train_op_and_loss <- init_NP(x_context, y_context, x_target, y_target, learning_rate = 0.001) 26 | 27 | # initialise 28 | init <- tf$global_variables_initializer() 29 | sess$run(init) 30 | 31 | n_iter <- 300000 32 | 33 | for(iter in 1:n_iter){ 34 | N <- 20 35 | x_obs <- runif(N, -3, 3) 36 | ls <- sample(c(1, 2, 3), 1) 37 | K <- rbf_kernel(cbind(x_obs), cbind(x_obs), ls, 1.0) + 1e-5 * diag(N) 38 | y_obs <- as.numeric(mvtnorm::rmvnorm(1, sigma = K)) 39 | 40 | # sample N_context for training 41 | N_context <- sample(1:15, 1) 42 | feed_dict <- helper_context_and_target(x_obs, y_obs, N_context, x_context, y_context, x_target, y_target) 43 | a <- sess$run(train_op_and_loss, feed_dict = feed_dict) 44 | if(iter %% 1e3 == 0){ 45 | cat(sprintf("loss = %1.3f\n", a[[2]])) 46 | } 47 | } 48 | 49 | 50 | # create a gif 51 | 52 | library(gganimate) 53 | 54 | x_star <- seq(-4, 4, length=100) 55 | z1 <- seq(-15, 15, length=31) 56 | z2 <- seq(-15, 15, length=31) 57 | eps_value <- as.matrix(expand.grid(z1, z2)) 58 | eps <- tf$constant(eps_value, dtype = tf$float32) 59 | prior_predict_op <- prior_predict(weights, cbind(x_star), epsilon = eps) 60 | y_star_mat <- sess$run(prior_predict_op$mu) 61 | 62 | df_pred <- y_star_mat %>% 63 | reshape_predictions(x_star) %>% 64 | mutate(z1 = eps_value[, 1][rep_index], 65 | z2 = eps_value[, 2][rep_index]) 66 | 67 | p1 <- df_pred %>% 68 | ggplot(aes(x, y, group=z2, col = z2)) + 69 | geom_line() + 70 | transition_time(as.integer(z1)) + 71 | labs(title = "NP trained on GP draws", subtitle = 'z1 = {frame_time}') + 72 | scale_color_viridis_c() + 73 | theme_classic() 74 | animate(p1, nframes = 31, fps = 7, width=300, height=175) 75 | # anim_save("fig/experiment3_1.gif") 76 | 77 | p2 <- df_pred %>% 78 | ggplot(aes(x, y, group=z1, col = z1)) + 79 | geom_line() + 80 | transition_time(as.integer(z2)) + 81 | labs(title = " ", subtitle = 'z2 = {frame_time}') + 82 | scale_color_viridis_c() + 83 | theme_classic() 84 | animate(p2, nframes = 31, fps = 7, width=300, height=175) 85 | # anim_save("fig/experiment3_2.gif") 86 | 87 | 88 | # Predictions for 
context points 89 | x_star <- seq(-4, 4, length=100) 90 | true_f <- function(x) -1.0 * sin(0.5*x) 91 | 92 | 93 | x0 <- seq(-3, 3, length=3) 94 | y0 <- true_f(x0) 95 | p1 <- plot_posterior_draws(x0, y0, x_star, n_draws = 50L) + 96 | labs(title = "NP predictions") 97 | 98 | x0 <- seq(-3, 3, length=5) 99 | y0 <- true_f(x0) 100 | p2 <- plot_posterior_draws(x0, y0, x_star, n_draws = 50L) 101 | 102 | x0 <- seq(-3, 3, length=11) 103 | y0 <- true_f(x0) 104 | p3 <- plot_posterior_draws(x0, y0, x_star, n_draws = 50L) 105 | 106 | p1 + p2 + p3 107 | 108 | # GP predictions for the same set of points 109 | library(gpflowr) 110 | 111 | x0 <- seq(-3, 3, length=3) 112 | y0 <- true_f(x0) 113 | gp1 <- fit_and_plot_GP(x0, y0, x_star) + 114 | labs(title = "GP predictions") 115 | 116 | x0 <- seq(-3, 3, length=5) 117 | y0 <- true_f(x0) 118 | gp2 <- fit_and_plot_GP(x0, y0, x_star) 119 | 120 | x0 <- seq(-3, 3, length=11) 121 | y0 <- true_f(x0) 122 | gp3 <- fit_and_plot_GP(x0, y0, x_star) 123 | 124 | p <- ((p1 | p2 | p3) * coord_cartesian(ylim = c(-2, 2)) / 125 | ((gp1 | gp2 | gp3) * coord_cartesian(ylim = c(-2, 2)))) 126 | 127 | ggsave("fig/experiment3_pred1.png", p, width = 8, height = 4.5) 128 | # ggsave("fig/experiment3_pred2.png", p, width = 8, height = 4.5) 129 | 130 | -------------------------------------------------------------------------------- /fig/GP_draws.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/GP_draws.png -------------------------------------------------------------------------------- /fig/NP_banner.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/NP_banner.gif -------------------------------------------------------------------------------- /fig/draws_from_prior.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/draws_from_prior.png -------------------------------------------------------------------------------- /fig/draws_from_prior_relu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/draws_from_prior_relu.png -------------------------------------------------------------------------------- /fig/experiment1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/experiment1.gif -------------------------------------------------------------------------------- /fig/experiment1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/experiment1.png -------------------------------------------------------------------------------- /fig/experiment2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/experiment2.gif -------------------------------------------------------------------------------- 
/fig/experiment2_latent_space.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/experiment2_latent_space.png -------------------------------------------------------------------------------- /fig/experiment2_misspecification.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/experiment2_misspecification.png -------------------------------------------------------------------------------- /fig/experiment2_pred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/experiment2_pred.png -------------------------------------------------------------------------------- /fig/experiment3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/experiment3.gif -------------------------------------------------------------------------------- /fig/experiment3_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/experiment3_1.gif -------------------------------------------------------------------------------- /fig/experiment3_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/experiment3_2.gif -------------------------------------------------------------------------------- /fig/experiment3_pred1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/experiment3_pred1.png -------------------------------------------------------------------------------- /fig/experiment3_pred2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/experiment3_pred2.png -------------------------------------------------------------------------------- /fig/experiment4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/experiment4.png -------------------------------------------------------------------------------- /fig/observed_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/observed_data.png -------------------------------------------------------------------------------- /fig/schema1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/schema1.png -------------------------------------------------------------------------------- /fig/schema2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/schema2.png -------------------------------------------------------------------------------- /fig/schema3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/schema3.png -------------------------------------------------------------------------------- /fig/two_scenarios.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kasparmartens/NeuralProcesses/0792c3487f743f85c3c5e912382c11aa3f594619/fig/two_scenarios.png -------------------------------------------------------------------------------- /helpers_for_plotting.R: -------------------------------------------------------------------------------- 1 | reshape_predictions <- function(y_star_mat, x_star){ 2 | y_star_mat %>% 3 | reshape2::melt() %>% 4 | rename(index = Var1, rep_index = Var2, y = value) %>% 5 | mutate(x = x_star[index]) %>% 6 | select(-index) 7 | } 8 | 9 | plot_posterior_draws <- function(x, y, x_star, n_draws = 50L){ 10 | df_obs <- data.frame(x = x, y = y) 11 | predict_op <- posterior_predict(cbind(x), cbind(y), cbind(x_star), n_draws = n_draws) 12 | y_star_mat <- sess$run(predict_op$mu) 13 | df_pred <- reshape_predictions(y_star_mat, x_star) 14 | 15 | df_pred %>% 16 | ggplot(aes(x, y)) + 17 | geom_line(aes(group=rep_index), alpha = 0.2) + 18 | geom_point(data = df_obs, col = "#b2182b", size = 3) + 19 | theme_classic() 20 | } 21 | -------------------------------------------------------------------------------- /previous_implementation/_NP_core.R: -------------------------------------------------------------------------------- 1 | # create and return all weight tensors used in the NP model 2 | init_weights <- function(dim_r = 8L, dim_z = dim_r, dim_h_hidden = 16L, dim_g_hidden = 16L){ 3 | # weights for the encoder h, mapping (x, y) -> hidden -> r 4 | W_h1 <- tf$Variable(tf$random_normal(shape(2L, dim_h_hidden))) 5 | b_h1 <- tf$Variable(tf$random_normal(shape(dim_h_hidden))) 6 | W_h2 <- tf$Variable(tf$random_normal(shape(dim_h_hidden, dim_r))) 7 | b_h2 <- tf$Variable(tf$random_normal(shape(dim_r))) 8 | 9 | # weights for mapping r to (mu_z, sigma_z) 10 | W_z_mu <- tf$Variable(tf$random_normal(shape(dim_r, dim_z))) 11 | W_z_sigma <- tf$Variable(tf$random_normal(shape(dim_r, dim_z))) 12 | 13 | # weights for the decoder g, mapping (z, x) -> hidden -> y 14 | W_g1 <- tf$Variable(tf$random_normal(shape(dim_z + 1L, dim_g_hidden))) 15 | b_g1 <- tf$Variable(tf$random_normal(shape(dim_g_hidden))) 16 | W_g2 <- tf$Variable(tf$random_normal(shape(dim_g_hidden, 1L))) 17 | b_g2 <- tf$Variable(tf$random_normal(shape(1L))) 18 | 19 | # return all weights 20 | list(W_h1 = W_h1, 21 | b_h1 = b_h1, 22 | W_h2 = W_h2, 23 | b_h2 = b_h2, 24 | W_z_mu = W_z_mu, 25 | W_z_sigma = W_z_sigma, 26 | W_g1 = W_g1, 27 | b_g1 = b_g1, 28 | W_g2 = W_g2, 29 | b_g2 = b_g2, 30 | dim_z = dim_z) 31 | } 32 | 33 | # helper function to map (x, y) -> z directly without intermediate steps 34 | map_xy_to_z_params <- function(x, y, weights){ 35 | list(x, y) %>% 36 | tf$concat(axis = 1L) %>% 37 | h(weights$W_h1, weights$b_h1, weights$W_h2, weights$b_h2) %>% 38 | aggregate_r() %>% 39 | get_z_params(weights$W_z_mu, weights$W_z_sigma) 40 | } 41 | 42 | # set up the NN architecture with train_op 
and loss 43 | init_NP <- function(weights, x_context, y_context, x_target, y_target, learning_rate = 0.001){ 44 | 45 | # concatenate context and target 46 | x_all <- tf$concat(list(x_context, x_target), axis = 0L) 47 | y_all <- tf$concat(list(y_context, y_target), axis = 0L) 48 | 49 | # map input to z 50 | z_context <- map_xy_to_z_params(x_context, y_context, weights) 51 | z_all <- map_xy_to_z_params(x_all, y_all, weights) 52 | 53 | # sample z using reparametrisation, z = mu + sigma*eps 54 | epsilon <- tf$random_normal(shape(7L, weights$dim_z)) 55 | z_sample <- epsilon %>% 56 | tf$multiply(z_all$sigma) %>% 57 | tf$add(z_all$mu) 58 | 59 | # map (z, x*) to y* 60 | y_pred_params <- g(z_sample, x_target, weights$W_g1, weights$b_g1, weights$W_g2, weights$b_g2) 61 | 62 | # ELBO 63 | loglik <- loglikelihood(y_target, y_pred_params) 64 | KL_loss <- KLqp_gaussian(z_all$mu, z_all$sigma, z_context$mu, z_context$sigma) 65 | loss <- tf$negative(loglik) + KL_loss 66 | 67 | # optimisation 68 | optimizer <- tf$train$AdamOptimizer(learning_rate) 69 | train_op <- optimizer$minimize(loss) 70 | 71 | # return train_op and loss 72 | list(train_op, loss) 73 | } 74 | 75 | prior_predict <- function(weights, x_star_value, epsilon = NULL, n_draws = 1L){ 76 | N_star <- nrow(x_star_value) 77 | x_star <- tf$constant(x_star_value, dtype = tf$float32) 78 | 79 | # the source of randomness can be optionally passed as an argument 80 | if(is.null(epsilon)){ 81 | epsilon <- tf$random_normal(shape(n_draws, weights$dim_z)) 82 | } 83 | # draw z ~ N(0, 1) 84 | z_sample <- epsilon 85 | 86 | # y ~ g(z, x*) 87 | y_star <- g(z_sample, x_star, weights$W_g1, weights$b_g1, weights$W_g2, weights$b_g2) 88 | 89 | y_star 90 | } 91 | 92 | 93 | posterior_predict <- function(weights, x, y, x_star_value, epsilon = NULL, n_draws = 1L){ 94 | # inputs for prediction time 95 | x_obs <- tf$constant(x, dtype = tf$float32) 96 | y_obs <- tf$constant(y, dtype = tf$float32) 97 | x_star <- tf$constant(x_star_value, dtype = tf$float32) 98 | 99 | # for out-of-sample new points 100 | z_params <- map_xy_to_z_params(x_obs, y_obs, weights) 101 | 102 | # the source of randomness can be optionally passed as an argument 103 | if(is.null(epsilon)){ 104 | epsilon <- tf$random_normal(shape(n_draws, weights$dim_z)) 105 | } 106 | # sample z using reparametrisation 107 | z_sample <- epsilon %>% 108 | tf$multiply(z_params$sigma) %>% 109 | tf$add(z_params$mu) 110 | 111 | # predictions 112 | y_star <- g(z_sample, x_star, weights$W_g1, weights$b_g1, weights$W_g2, weights$b_g2) 113 | 114 | y_star 115 | } 116 | -------------------------------------------------------------------------------- /previous_implementation/_NP_helpers.R: -------------------------------------------------------------------------------- 1 | # Helper functions for Neural Processes 2 | 3 | # encoder h -- map inputs (x_i, y_i) to r_i 4 | h <- function(input, W1, b1, W2, b2){ 5 | input %>% 6 | tf$matmul(W1) %>% 7 | tf$add(b1) %>% 8 | tf$nn$sigmoid() %>% 9 | tf$matmul(W2) %>% 10 | tf$add(b2) 11 | } 12 | 13 | # aggregate the output of h (i.e. 
values of r_i) to a single vector r 14 | aggregate_r <- function(input){ 15 | input %>% 16 | tf$reduce_mean(axis=0L) %>% 17 | tf$reshape(shape(1L, -1L)) 18 | } 19 | 20 | # map aggregated r to (mu_z, sigma_z) 21 | get_z_params <- function(input_r, W_mu, W_sigma){ 22 | mu <- input_r %>% 23 | tf$matmul(W_mu) 24 | 25 | sigma <- input_r %>% 26 | tf$matmul(W_sigma) %>% 27 | tf$nn$softplus() 28 | 29 | list(mu = mu, sigma = sigma) 30 | } 31 | 32 | 33 | # decoder g -- map (z, x*) -> hidden -> y* 34 | g <- function(z_sample, x_star, W1, b1, W2, b2, noise_sd = 0.05){ 35 | # inputs dimensions 36 | # z_sample has dim [n_draws, dim_z] 37 | # x_star has dim [N_star, dim_x] 38 | 39 | n_draws <- z_sample$get_shape()$as_list()[1] 40 | N_star <- tf$shape(x_star)[1] 41 | 42 | # z_sample_rep will have dim [n_draws, N_star, dim_z] 43 | z_sample_rep <- z_sample %>% 44 | tf$expand_dims(axis = 1L) %>% 45 | tf$tile(c(1L, N_star, 1L)) 46 | 47 | # x_star_rep will have dim [n_draws, N_star, dim_x] 48 | x_star_rep <- x_star %>% 49 | tf$expand_dims(axis = 0L) %>% 50 | tf$tile(shape(n_draws, 1L, 1L)) 51 | 52 | # concatenate x* and z 53 | input <- list(x_star_rep, z_sample_rep) %>% 54 | tf$concat(axis = 2L) 55 | 56 | # batch matmul 57 | W1_rep <- W1 %>% 58 | tf$expand_dims(axis=0L) %>% 59 | tf$tile(shape(n_draws, 1L, 1L)) 60 | W2_rep <- W2 %>% 61 | tf$expand_dims(axis=0L) %>% 62 | tf$tile(shape(n_draws, 1L, 1L)) 63 | 64 | # hidden layer 65 | hidden <- input %>% 66 | tf$matmul(W1_rep) %>% 67 | tf$add(b1) %>% 68 | tf$nn$sigmoid() 69 | 70 | # mu will be of the shape [N_star, n_draws] 71 | mu_star <- hidden %>% 72 | tf$matmul(W2_rep) %>% 73 | tf$add(b2) %>% 74 | tf$squeeze(axis = 2L) %>% 75 | tf$transpose() 76 | 77 | # for the toy example, assume y* ~ N(mu, sigma) with fixed sigma 78 | sigma_star <- tf$constant(noise_sd, dtype = tf$float32) 79 | 80 | list(mu = mu_star, sigma = sigma_star) 81 | } 82 | 83 | # KLqp helper 84 | KLqp_gaussian <- function(mu_q, sigma_q, mu_p, sigma_p){ 85 | sigma2_q <- tf$square(sigma_q) + 1e-16 86 | sigma2_p <- tf$square(sigma_p) + 1e-16 87 | temp <- sigma2_q / sigma2_p + tf$square(mu_q - mu_p) / sigma2_p - 1.0 + tf$log(sigma2_p / sigma2_q + 1e-16) 88 | 0.5 * tf$reduce_sum(temp) 89 | } 90 | 91 | # for ELBO 92 | loglikelihood <- function(y_star, y_pred_params){ 93 | 94 | p_normal <- tf$distributions$Normal(loc = y_pred_params$mu, scale = y_pred_params$sigma) 95 | 96 | loglik <- y_star %>% 97 | p_normal$log_prob() %>% 98 | # sum over data points 99 | tf$reduce_sum(axis=0L) %>% 100 | # average over n_draws 101 | tf$reduce_mean() 102 | 103 | loglik 104 | } 105 | 106 | # for training 107 | helper_context_and_target <- function(x, y, N_context, x_context, y_context, x_target, y_target){ 108 | N <- length(y) 109 | context_set <- sample(1:N, N_context) 110 | dict( 111 | x_context = cbind(x[context_set]), 112 | y_context = cbind(y[context_set]), 113 | x_target = cbind(x[-context_set]), 114 | y_target = cbind(y[-context_set]) 115 | ) 116 | } 117 | -------------------------------------------------------------------------------- /previous_implementation/_helpers_for_plotting.R: -------------------------------------------------------------------------------- 1 | reshape_predictions <- function(y_star_mat, x_star){ 2 | y_star_mat %>% 3 | reshape2::melt() %>% 4 | rename(index = Var1, rep_index = Var2, y = value) %>% 5 | mutate(x = x_star[index]) %>% 6 | select(-index) 7 | } 8 | 9 | plot_posterior_draws <- function(x, y, x_star, n_draws = 50L){ 10 | df_obs <- data.frame(x = x, y = y) 11 | predict_op <- 
posterior_predict(weights, cbind(x), cbind(y), cbind(x_star), n_draws = n_draws) 12 | y_star_mat <- sess$run(predict_op$mu) 13 | df_pred <- reshape_predictions(y_star_mat, x_star) 14 | 15 | df_pred %>% 16 | ggplot(aes(x, y)) + 17 | geom_line(aes(group=rep_index), alpha = 0.2) + 18 | geom_point(data = df_obs, col = "#b2182b", size = 3) + 19 | theme_classic() 20 | } --------------------------------------------------------------------------------