├── .gitignore ├── README.md ├── README_old.md ├── config └── nn_config.yaml ├── graphs ├── adam0.01_relu.png ├── adg0.01_relu.png ├── adg0.01_relu2.png ├── adg0.1_relu.png ├── adg0.1_relu_2.png ├── fullpass_dim2.png ├── fullpass_dim20.png ├── optimize_tests.png └── optimize_tests2.png ├── likelihoods.py ├── nets.py ├── tests.py ├── tf_helpers.py ├── vae.py └── vis.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pyo 3 | *.swp 4 | MNIST_data/* 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VINNy 2 | 3 | An implementation of VAEs in (an old version of) TensorFlow. 4 | 5 | This repo is also a project submission for Tamara Broderick's "Bayesian Inference" (6.882) course at MIT. 6 | 7 | # References 8 | 9 | \[1\] [_Auto-encoding Variational Bayes_](http://arxiv.org/abs/1312.6114); Kingma, Welling; NIPS 2014 10 | 11 | \[2\] [_Variational Inference Lecture Notes_](https://www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/variational-inference-i.pdf); Blei; Princeton 2011 12 | -------------------------------------------------------------------------------- /README_old.md: -------------------------------------------------------------------------------- 1 | # VINNy 2 | 3 | Fun with Variational Autoencoder neural nets. 4 | 5 | Ping me with any questions at shraman (at) mit (dot) edu! This repo will be continually updated with some cool (experimental) uses for variational inference performed by neural networks. 6 | 7 | This repo is also a project submission for Tamara Broderick's "Bayesian Inference" (6.882) course at MIT. 8 | 9 | ## Introduction 10 | 11 | ### What is a "Variational Autoencoder"? 12 | 13 | A variational autoencoder is an encoder-decoder pair of neural networks that are designed to perform approximate posterior inference on latent variables. The training objective is (an estimate of) the evidence lower bound (ELBO); maximizing it decreases the KL-divergence (or "relative entropy") between the approximating and true posteriors. A quick introduction to VI can be found in \[2\] and the auto-encoding framework for VI in \[1\]. 14 | 15 | ## Optimizations 16 | 17 | ### Statistical Optimizations 18 | #### Reparameterization I: Reducing MCE Variance 19 | 20 | In my opinion, this is the coolest aspect of [1], despite being so simple: it is what makes AEVB achieve good results in practice and, at a larger scale, it broaches a new methodology for approximate posterior inference in which Variational Inference relies on Monte Carlo estimation, projecting VI onto a much broader class of models. The reparameterization trick is definitely worth understanding in depth -- motivations, derivations, and all. 21 | 22 | When performing Variational Inference, there are two approaches to computing the gradient of the lower bound (which is an intractable expectation over q(z)): (1) find an analytic (and computationally attractive) closed form of the gradient w.r.t. the variational parameters, or (2) Monte-Carlo Estimation (MCE): sampling 'z ~ q(z)' and using these samples z^(l) to estimate the quantity of interest (in this case, the gradient of an expectation). 
Unfortunately, (1) is often infeasible once we drop the requirement of conjugacy between the prior and the likelihood, and (2) is troublesome when done naively: taking the expectation over the gradients at the random samples yields an estimator with very high variance (see the discussion in [1]). 23 | 24 | As a side note, high variance is especially troublesome in a neural network setting, where there are two sources of variance in the gradient estimate: one from using a random minibatch of data points (rather than the entire dataset) to compute the gradient (i.e. SGD), and another from estimating that minibatch gradient with Monte-Carlo methods (which we only need when the *inputs* to the loss function are themselves stochastic, e.g. when the loss function is an expectation like the variational lower bound). This "doubly stochastic" setup -- where both the inputs to the function and the function evaluation itself are stochastic -- can lead to problems when either source exhibits high variance. 25 | 26 | To ameliorate this from a purely statistical point of view, we reformulate (2) not as computing an expectation over gradients at sampled points, but as the gradient of an expectation (of a function, g) over sampled points. The latter is only possible if we can (analytically) reparameterize 'z' into a random component (ep) that is independent of phi (i.e. the variable with respect to which we are taking the gradient) and a deterministic component (g) that contains phi (e.g. for a Gaussian, mu/sigma), so that we can analytically compute the gradient (w.r.t. phi) of g(ep, phi) after averaging MCE samples of g(ep, phi). In other words, we must find a way to outsource the uncertainty in 'z' to an auxiliary variable that does not depend on phi (i.e. g(ep, phi) with ep ~ p(ep) should have the same distribution as q(z; phi)), so that we can still leverage Monte Carlo sampling on the auxiliary variable while taking the gradient of an MCE rather than relying on an MCE of a gradient (thereby greatly reducing the variance). A concrete comparison of the two estimators is sketched at the end of this subsection. 27 | 28 | Having approached this from a purely statistical point of view, we can expect better performance for reasons *grounded in theory*, whereas many of the heuristics surrounding plain SGD still lack comparable theoretical justification. This reparameterization also motivates a cool way to incorporate MCMC into Variational Inference to produce even better approximations, where longer MCMC runs can be traded off against better results [7]. 
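To make the reparameterization concrete, here is a minimal NumPy sketch (illustrative only, not code from this repo; the toy integrand f(z) = z^2, the parameter values, and the sample counts are all assumptions) contrasting the naive score-function estimator (an MCE of a gradient) with the reparameterized estimator (a gradient of an MCE) for d/dmu E_{z ~ N(mu, sigma^2)}[f(z)], whose true value is 2*mu:

```python
import numpy as np

rng = np.random.RandomState(0)
mu, sigma, L = 1.5, 2.0, 100      # illustrative variational parameters and MC sample count
f = lambda z: z ** 2              # toy integrand; true gradient d/dmu E[f(z)] = 2*mu

def score_function_estimate():
    # MCE of a gradient: average f(z) * d/dmu log q(z; mu, sigma) over samples z ~ q
    z = rng.normal(mu, sigma, size=L)
    return np.mean(f(z) * (z - mu) / sigma ** 2)

def reparameterized_estimate():
    # Gradient of an MCE: write z = g(ep, phi) = mu + sigma * ep with ep ~ N(0, 1),
    # so d/dmu f(g(ep, phi)) = f'(z) = 2*z can be averaged directly
    ep = rng.normal(0.0, 1.0, size=L)
    z = mu + sigma * ep
    return np.mean(2 * z)

naive = [score_function_estimate() for _ in range(1000)]
reparam = [reparameterized_estimate() for _ in range(1000)]
print("true gradient:   %.3f" % (2 * mu))
print("score-function:  mean %.3f, std %.3f" % (np.mean(naive), np.std(naive)))
print("reparameterized: mean %.3f, std %.3f" % (np.mean(reparam), np.std(reparam)))
```

Both estimators are unbiased, but the reparameterized one exhibits far lower variance; this is exactly the construction used in vae.py, where the latent sample is formed as mu + stddev * ep with ep drawn from a standard normal.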
29 | 30 | #### Reparameterization II: Reducing Variance at Scale 31 | 32 | #### Variational Dropout 33 | 34 | #### Natural Gradients 35 | 36 | #### Xavier Initialization 37 | 38 | #### Streaming VB 39 | 40 | #### Posterior Predictive Checking 41 | 42 | ### Numerical Optimizations 43 | #### ADAM Optimizer 44 | 45 | #### Batch Normalization 46 | 47 | #### Activation Functions 48 | 49 | 50 | ## References 51 | \[1\] [_Auto-encoding Variational Bayes_](http://arxiv.org/abs/1312.6114); Kingma, Welling; NIPS 2014 52 | 53 | \[2\] [_Variational Inference Lecture Notes_](https://www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/variational-inference-i.pdf); Blei; Princeton 2011 54 | 55 | \[3\] [_Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift_](http://arxiv.org/abs/1502.03167); Ioffe, Szegedy; JMLR 2015 56 | 57 | \[4\] [_Streaming Variational Bayes_](http://papers.nips.cc/paper/4980-streaming-variational-bayes.pdf); Broderick, Boyd, Wibisono, Wilson, Jordan; NIPS 2013 58 | 59 | \[5\] [_Variational Dropout and Local Reparameterization Trick_](http://arxiv.org/pdf/1506.02557v2.pdf); Kingma, Salimans, Welling; NIPS 2015 60 | 61 | \[6\] [_Stochastic Variational Inference_](http://arxiv.org/pdf/1206.7051.pdf); Hoffman, Blei, Wang, Paisley; JMLR 2013 62 | 63 | \[7\] [_MCMC and Variational Inference: Bridging the Gap_](http://arxiv.org/pdf/1410.6460v4.pdf); Salimans, Kingma, Welling; JMLR 2015 64 | -------------------------------------------------------------------------------- /config/nn_config.yaml: -------------------------------------------------------------------------------- 1 | architecture: { 2 | encoder: { 3 | n_layers: 1, 4 | n_units: [500], 5 | distribution: Gaussian 6 | }, 7 | decoder: { 8 | n_layers: 1, 9 | n_units: [500], 10 | distribution: Bernoulli 11 | }, 12 | activation: "ReLU" 13 | } 14 | 15 | dims: { 16 | data: 784, # MNIST, flattened 17 | latent: 20 18 | } 19 | 20 | optimization: { 21 | type: Adam, 22 | Adagrad_rate: 0.01, 23 | Adam_rate: 0.01, 24 | max_grad: 1000000 25 | } 26 | 27 | training: { 28 | batch_size: 100, 29 | n_iters: 1000 30 | } 31 | 32 | AEVB: { # Params specified in [1] 33 | L: 1 34 | } 35 | -------------------------------------------------------------------------------- /graphs/adam0.01_relu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shraman-rc/vae/4142113f5f50b82281c0850ff369dd8648fc28cf/graphs/adam0.01_relu.png -------------------------------------------------------------------------------- /graphs/adg0.01_relu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shraman-rc/vae/4142113f5f50b82281c0850ff369dd8648fc28cf/graphs/adg0.01_relu.png -------------------------------------------------------------------------------- /graphs/adg0.01_relu2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shraman-rc/vae/4142113f5f50b82281c0850ff369dd8648fc28cf/graphs/adg0.01_relu2.png -------------------------------------------------------------------------------- /graphs/adg0.1_relu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shraman-rc/vae/4142113f5f50b82281c0850ff369dd8648fc28cf/graphs/adg0.1_relu.png -------------------------------------------------------------------------------- /graphs/adg0.1_relu_2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/shraman-rc/vae/4142113f5f50b82281c0850ff369dd8648fc28cf/graphs/adg0.1_relu_2.png -------------------------------------------------------------------------------- /graphs/fullpass_dim2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shraman-rc/vae/4142113f5f50b82281c0850ff369dd8648fc28cf/graphs/fullpass_dim2.png -------------------------------------------------------------------------------- /graphs/fullpass_dim20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shraman-rc/vae/4142113f5f50b82281c0850ff369dd8648fc28cf/graphs/fullpass_dim20.png -------------------------------------------------------------------------------- /graphs/optimize_tests.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shraman-rc/vae/4142113f5f50b82281c0850ff369dd8648fc28cf/graphs/optimize_tests.png -------------------------------------------------------------------------------- /graphs/optimize_tests2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shraman-rc/vae/4142113f5f50b82281c0850ff369dd8648fc28cf/graphs/optimize_tests2.png -------------------------------------------------------------------------------- /likelihoods.py: -------------------------------------------------------------------------------- 1 | """ likelihoods.py: TF implementations of closed-form likelihoods given data/parameters 2 | """ 3 | 4 | __author__ = "shraman-rc" 5 | 6 | import tensorflow as tf 7 | from collections import namedtuple 8 | 9 | ''' 10 | Notation: 11 | - 'll' stands for 'log likelihood' 12 | ''' 13 | 14 | BARRIER = 1e-9 # Prevents evaluation of log(0) = nan 15 | # TODO: Bug: have 1e-10 creates nan's for likelihood in this implementation 16 | # but not in Metzen's 17 | # TODO: Effect is exacerbated by higher learning rates 18 | 19 | def ll_bernoulli(data, rho): 20 | ''' 21 | Params: 22 | - data: The data point(s) (possibly a batch) for which we are computing 23 | the likelihood value 24 | - rho: The probabilities that parameterize a multivariate Bernoulli 25 | ''' 26 | # Breakdown of the Bernoulli log-likelihood equation: 27 | # 28 | # The data vectors (should) come as binary: {0,1}^n 29 | # The probability of producing element 'i' in that vector is: 30 | # rho_i if x_i = 1, and (1-rho_i) if x_i = 0 31 | # A concise way to write this probability is: 32 | # p(x_i) = x_i * rho_i + (1-x_i)*(1-rho_i) 33 | # Therefore the likelihood of the entire vector is: 34 | # \prod_i {x_i * rho_i + (1-x_i)*(1-rho_i)} 35 | # Taking the log of this (only the probabilities rho_i, not the 36 | # 0-1 coefficients x_i) we get our log-likelihood: 37 | # \sum_i {x_i*log(rho_i) + (1-x_i)*log(1-rho_i)} 38 | return tf.reduce_sum(data*tf.log(BARRIER + rho) + (1-data)*tf.log(BARRIER + 1-rho), 1) 39 | #return tf.reduce_sum(data*tf.log(BARRIER) + (1-data)*tf.log(BARRIER), 1) 40 | 41 | def ll_gaussian(data, mu, log_var): 42 | ''' 43 | Params: 44 | - data: The data point(s) (possibly a batch) for which we are computing 45 | the likelihood value 46 | - mu: Mean of the Gaussian 47 | - log_var: Log(variance of the Gaussian) 48 | ''' 49 | # Simply implements the log Normal equation in TF 50 | return None #TODO 51 | 
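For the `ll_gaussian` TODO above, here is a hedged sketch (not the repo's code; the function name is mine) of the standard diagonal-Gaussian log-likelihood, written in the same old-TF style as `ll_bernoulli`:

```python
import numpy as np
import tensorflow as tf

LOG_2PI = float(np.log(2.0 * np.pi))

def ll_gaussian_sketch(data, mu, log_var):
    # log N(x; mu, var) = -0.5 * [log(2*pi) + log_var + (x - mu)^2 / var]
    # with var = exp(log_var); summed over dimension 1, mirroring ll_bernoulli
    return -0.5 * tf.reduce_sum(
        LOG_2PI + log_var + tf.square(data - mu) / tf.exp(log_var), 1)
```

This would only matter if the decoder were Gaussian (e.g. for real-valued data); the configuration in config/nn_config.yaml uses a Bernoulli decoder, so only `ll_bernoulli` is exercised.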
-------------------------------------------------------------------------------- /nets.py: -------------------------------------------------------------------------------- 1 | """ nets.py: Defines and recontructs basic neural net architectures for the auto-encoder 2 | 3 | Construction closely follows [1]: Appendix C 4 | """ 5 | 6 | __author__ = "shraman-rc" 7 | 8 | import tensorflow as tf 9 | import tf_helpers as tfh 10 | from collections import namedtuple 11 | 12 | 13 | DEFAULT_LAYER_SIZES = [100] 14 | DEFAULT_LATENT_DIM = 10 15 | 16 | BernoulliParam = namedtuple('BernoulliParam', ['p']) 17 | GaussianParam = namedtuple('GaussianParam', ['mu', 'log_var']) 18 | 19 | 20 | class MLP(object): 21 | 22 | ''' 23 | Wrapper for creating fully-connected Multi-Layer Perceptron models 24 | that parameterize certain distributions 25 | ''' 26 | 27 | def __init__(self, input_batch, layer_sizes, latent_dim, fn_activate=tf.nn.relu, 28 | fn_init_w=tfh.xavier, fn_init_b=tf.zeros): 29 | ''' 30 | Params: 31 | - input_batch: The input tensor to the MLP 32 | - NOTE: This should be a 'tf.placeholder' 33 | - layer_sizes: The number of weights in each hidden layer 34 | - latent_dim: dimension of the resulting probability distribution 35 | - fn_activate: function to use for perceptron activation 36 | - fn_init_w: function to use for weight initialization 37 | - fn_init_b: function to use for bias initialization 38 | ''' 39 | self.batch_size, self.input_dim = tfh.shape(input_batch) 40 | self.input_batch = input_batch 41 | self.layer_sizes = layer_sizes 42 | self.out_dim = latent_dim 43 | self.activate, self.w_init, self.b_init = fn_activate, fn_init_w, fn_init_b 44 | 45 | # Generate input layer 46 | self.weights = [tf.Variable( 47 | self.w_init([self.input_dim, layer_sizes[0]]))] 48 | self.biases = [tf.Variable( 49 | self.b_init([layer_sizes[0]]))] 50 | 51 | # Keeps track of the hidden layer output as we build 52 | self.hidden_out = self.activate( 53 | tf.matmul(self.input_batch, self.weights[-1]) + self.biases[-1]) 54 | 55 | # Generate arbitrarily deep hidden layers 56 | for in_dim,out_dim in zip(layer_sizes, layer_sizes[1:]): 57 | self.weights.append(tf.Variable(self.w_init([in_dim, out_dim]))) 58 | self.biases.append(tf.Variable(self.b_init([out_dim]))) 59 | self.hidden_out = self.activate( 60 | tf.matmul(self.hidden_out, self.weights[-1]) + self.biases[-1]) 61 | 62 | # Generate output layer (parameters of a distribution) 63 | self.out_params = self._gen_params() 64 | 65 | 66 | class BernoulliMLP(MLP): 67 | 68 | def __init__(self, input_batch, 69 | layer_sizes=DEFAULT_LAYER_SIZES, dim=DEFAULT_LATENT_DIM): 70 | super(self.__class__, self).__init__(input_batch, layer_sizes, dim) 71 | 72 | def _gen_params(self): 73 | ''' 74 | Setup TF computation graph (neural net) to compute the value that 75 | parameterizes the Bernoulli distribution 76 | ''' 77 | self.bias_out = tf.Variable(self.b_init([self.out_dim])) 78 | self.weights_out = tf.Variable(self.w_init([self.layer_sizes[-1], self.out_dim])) 79 | 80 | # The output is a (multivariate) probability vector that represents 81 | # the "success probabilities" in a Bernoulli dist. 
82 | p = tf.nn.sigmoid(tf.matmul(self.hidden_out, self.weights_out) + self.bias_out) 83 | 84 | return BernoulliParam(p) 85 | 86 | 87 | class GaussianMLP(MLP): 88 | 89 | def __init__(self, input_batch, 90 | layer_sizes=DEFAULT_LAYER_SIZES, dim=DEFAULT_LATENT_DIM): 91 | super(self.__class__, self).__init__(input_batch, layer_sizes, dim) 92 | 93 | def _gen_params(self): 94 | ''' 95 | Setup TF computation graph (neural net) to compute the values that 96 | parameterize the Gaussian distribution 97 | 98 | NOTE: The variance parameter output is actually log(variance) as 99 | suggested by [1] 100 | ''' 101 | self.bias_mu = tf.Variable(self.b_init([self.out_dim])) 102 | self.weights_mu = tf.Variable(self.w_init([self.layer_sizes[-1], self.out_dim])) 103 | 104 | self.bias_logvar = tf.Variable(self.b_init([self.out_dim])) 105 | self.weights_logvar = tf.Variable(self.w_init([self.layer_sizes[-1], self.out_dim])) 106 | 107 | mu = tf.matmul(self.hidden_out, self.weights_mu) + self.bias_mu 108 | # TODO: Why do we use log? Vanishing weights phenomenon when training? 109 | log_var = tf.matmul(self.hidden_out, self.weights_logvar) + self.bias_logvar 110 | 111 | return GaussianParam(mu,log_var) 112 | -------------------------------------------------------------------------------- /tests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | """ tests.py: Quantitative and qualitative tests on the effectiveness of the 4 | VAE implementation 5 | 6 | TODO: 7 | - Use nosetests-like framework 8 | - Auto-save to graphs/ 9 | - Click CLI 10 | """ 11 | 12 | import tensorflow as tf 13 | import click as cl 14 | import numpy as np 15 | import yaml 16 | 17 | import vae 18 | import vis 19 | 20 | FCONFIG = 'config/nn_config.yaml' 21 | 22 | try: 23 | config = yaml.load(file(FCONFIG, 'r')) 24 | except yaml.YAMLError, exc: 25 | cl.secho("Error in configuration file: {}".format(exc), fg='red') 26 | 27 | ARCH = config["architecture"] 28 | DIMS = config["dims"] 29 | OPT = config["optimization"] 30 | PARAMS = config["AEVB"] 31 | TRAIN = config["training"] 32 | 33 | # TF backend initializations 34 | tf.set_random_seed(0) 35 | 36 | 37 | def simple_test(): 38 | ''' 39 | Train VAE using default parameters specified in config/nn_config.yaml and 40 | plot progress over training iterations 41 | ''' 42 | # Instantiate and train vanilla autoencoder 43 | nn = vae.VAE(ARCH, DIMS, OPT, PARAMS, TRAIN) 44 | results = nn.train() 45 | iters = results["iters"] 46 | ELBOs = results["ELBO"] 47 | KLs = results["KL"] 48 | LLs = results["LL"] 49 | 50 | # Graph training results 51 | titles = ["$\mathcal{L}(\phi,\\theta;x)$", 52 | "$KL(q_{\phi}(z|x)||p_{\\theta}(z))$", 53 | "$\log(p_{\\theta}(x|z))$"] 54 | params = {"$\eta_{%s}$" % OPT["type"]: OPT["{}_rate".format(OPT["type"])], 55 | "$Activation$": ARCH["activation"], 56 | "$Batch$ $Size$": TRAIN["batch_size"], 57 | "$MCE$ $Samples$": PARAMS["L"], 58 | "$Latent$ $Dim$": DIMS["data"]} 59 | vis.basic_multiplot([iters]*3, [[ELBOs], [KLs], [LLs]], titles, 60 | show_legend=False, params=params) 61 | 62 | def bivariate_latent_space(): 63 | ''' 64 | Train VAE using default parameters specified in config/nn_config.yaml and 65 | plot examples of latent space distribution generated by each data point 66 | ''' 67 | # Instantiate and train vanilla autoencoder 68 | nn = vae.VAE(ARCH, DIMS, OPT, PARAMS, TRAIN) 69 | nn.train() 70 | 71 | # Do single forward pass on the encoder 72 | new_data = vae.mnist.test.next_batch(TRAIN["batch_size"])[0] 73 | mu_q, 
stddev_q, rho_p = nn.full_fp(new_data) 74 | 75 | vis.full_pass_vis( 76 | [im.reshape(28,28) for im in new_data[:3]], 77 | [im.reshape(28,28) for im in rho_p[:3]], 78 | mu_q[:3], stddev_q[:3]) 79 | 80 | 81 | def multivariate_latent_space_PCA(): 82 | ''' 83 | Do same as bivariate latent space visualization except perform PCA 84 | to project onto lower dimension before visualizing 85 | ''' 86 | pass 87 | 88 | 89 | def reconstruction_test(): 90 | ''' 91 | Reconstruct samples after training VAE 92 | ''' 93 | # Instantiate and train vanilla autoencoder 94 | nn = vae.VAE(ARCH, DIMS, OPT, PARAMS, TRAIN) 95 | nn.train() 96 | 97 | # Do single forward pass on the encoder 98 | new_data = vae.mnist.test.next_batch(TRAIN["batch_size"])[0] 99 | mu_q, stddev_q, rho_p = nn.full_fp(new_data) 100 | 101 | vis.juxtapose_images( 102 | [im.reshape(28,28) for im in new_data[:5]], 103 | [im.reshape(28,28) for im in rho_p[:5]]) 104 | 105 | 106 | if __name__ == "__main__": 107 | #simple_test() 108 | #bivariate_latent_space() 109 | reconstruction_test() 110 | -------------------------------------------------------------------------------- /tf_helpers.py: -------------------------------------------------------------------------------- 1 | """ tf_helpers.py: Random relevant TensorFlow augmentations 2 | """ 3 | 4 | __author__ = "shraman-rc" 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | tf.set_random_seed(0) 10 | 11 | X_INIT=tf.contrib.layers.xavier_initializer(uniform=True, seed=0) 12 | 13 | def xavier(dims): 14 | return X_INIT(dims) 15 | 16 | def gaussian(dims): 17 | return tf.random_normal(dims) 18 | 19 | def shape(t): 20 | return t.get_shape().as_list() 21 | -------------------------------------------------------------------------------- /vae.py: -------------------------------------------------------------------------------- 1 | """ vae.py: Run the Variational Auto-encoder. 2 | 3 | TODO: 4 | - Incorporate L (MC samples) without blowing up the decoder variable count 5 | - There's a blow-up of the log_var output of the encoder which makes 6 | the KL-divergence term of the error function go to infinity since there 7 | is a var term (where var = e^{log_var}). This seems to be irreversible 8 | when the learning rate is high. 9 | 10 | UPDATE: The network weights seem to be such that at the beginning, there 11 | is high variance output (exploding gradients when NN still malleable) 12 | 13 | UPDATE: Also exacerbated by batch size. Gradient clipping ineffective 14 | when gradient becomes nan! 15 | 16 | UPDATE: This problem is *super* sensitive to learning rate and highly 17 | nondeterministic. At a learning of 0.016 (using Adam) gradients will 18 | sometimes blow up to 3e+28 and when it doesn't, they will go no higher 19 | than 100! However, there is much more stability across runs even at 0.015! 20 | 21 | UPDATE: Also sensitive to latent space dimensionality changes (reduction 22 | from 10 -> 2 made it go haywire even with Adam rate of 0.01) 23 | 24 | TODO: Try value clipping of the KL divergence. Norm clipping too? 25 | - We use output of sigmoid to parameterize the multivariate Bernoulli. 26 | Does its steepeness affect learning? 
27 | - Apply batch normalization as in [3] 28 | - Try different reconstruction loss instead of log-likelihood: 29 | - Cross entropy 30 | - Quantitatively assess Adagrad, SGD, Adam performance 31 | - Use tf.nn.dropout to perform Variational Dropout 32 | - Other datasets besides MNIST 33 | - Examples: 34 | - DRAW network 35 | - Generative adversarial network 36 | - Music composition network (similar to DRAW) 37 | - Picasso network (similar to DRAW) 38 | """ 39 | 40 | __author__ = "shraman-rc" 41 | 42 | import tensorflow as tf 43 | import numpy as np 44 | import click as cl 45 | 46 | from nets import BernoulliMLP, GaussianMLP 47 | import likelihoods as lh 48 | 49 | from tensorflow.examples.tutorials.mnist import input_data 50 | mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 51 | 52 | tf.set_random_seed(0) 53 | 54 | 55 | class VAE(object): 56 | 57 | def __init__(self, ARCH, DIMS, OPT, PARAMS, TRAIN): 58 | ''' 59 | Initializes VAE with parameter dictionaries that should follow the 60 | same format as config/nn_config.yaml 61 | ''' 62 | self.ARCH, self.DIMS, self.OPT, self.PARAMS, self.TRAIN = \ 63 | ARCH, DIMS, OPT, PARAMS, TRAIN 64 | 65 | # Inputs - mini-batches of (flattened) images 66 | self.x_batch = tf.placeholder(tf.float32, 67 | shape=[None, self.DIMS["data"]]) 68 | 69 | # Encoder parameterizes posterior Gaussian approximation q(z|x) 70 | self.encoder = GaussianMLP(self.x_batch, 71 | self.ARCH["encoder"]["n_units"], self.DIMS["latent"]) 72 | self.latent = {} 73 | self.latent["mu"] = self.encoder.out_params.mu 74 | self.latent["log_var"] = self.encoder.out_params.log_var 75 | self.latent["var"] = tf.exp(self.latent["log_var"]) 76 | self.latent["stddev"] = tf.sqrt(self.latent["var"]) 77 | 78 | # Reparameterize latent space (z = g(ep,x); ep ~ p(ep)) 79 | # Note: Element-wise univariate Gaussian sampling <=> 80 | # multivariate Gaussian sampling 81 | self.ep = tf.random_normal([self.TRAIN["batch_size"], 82 | self.DIMS["latent"]], mean=0, stddev=1) 83 | self.z_batch = self.latent["mu"] + self.latent["stddev"]*self.ep 84 | 85 | # Decoder samples from latent distribution, parameterizes likelihood 86 | # (in this case, a multivariate Bernoulli - working with images) 87 | self.decoder = BernoulliMLP(self.z_batch, 88 | self.ARCH["decoder"]["n_units"], self.DIMS["data"]) 89 | 90 | # The (negative) KL divergence between the variational approx. and the *prior* 91 | # p_theta(z) acts as a regularizing term so that the latent distribution 92 | # doesn't overfit. The closed-form eq. is derived in [1]: Appedix B 93 | self.neg_KL_pr = 0.5 * tf.reduce_sum(1 + self.latent["log_var"] 94 | - self.latent["mu"]**2 - self.latent["var"], 1) 95 | 96 | # The 'reconstruction error' (predictive likelihood): log p_theta(x_batch|z) 97 | self.ll = lh.ll_bernoulli(self.x_batch, self.decoder.out_params.p) 98 | 99 | # ELBO estimate (total reward function) 100 | self.ELBO_est = tf.reduce_mean(self.neg_KL_pr + self.ll) # Mean over minibatch 101 | 102 | # Pick a flavor of gradient descent 103 | if self.OPT["type"].lower() == "adagrad": 104 | self.optimizer = tf.train.AdagradOptimizer(self.OPT["Adagrad_rate"]) 105 | elif self.OPT["type"].lower() == "adam": 106 | self.optimizer = tf.train.AdamOptimizer(self.OPT["Adam_rate"]) 107 | 108 | # Notice that we are minimizing the negative (i.e. 
maximizing) the ELBO 109 | # We also clip the gradients to prevent blowup during first few 110 | # training phases 111 | #self.train_op = self.optimizer.minimize(-self.ELBO_est) 112 | #self.vi_train_op = self.optimizer.minimize(-self.neg_KL_pr) 113 | #self.ll_train_op = self.optimizer.minimize(-self.ll) 114 | # ...with clipped gradients: 115 | gvs = self.optimizer.compute_gradients(-self.ELBO_est) 116 | capped_gvs = [(tf.clip_by_value( 117 | grad, -self.OPT["max_grad"], self.OPT["max_grad"]), var) 118 | for grad, var in gvs if grad != None] 119 | flat_grads = [tf.reshape(grad,[-1]) for grad, var in capped_gvs] 120 | self.max_grad = tf.reduce_max(tf.concat(0, flat_grads)) 121 | self.train_op = self.optimizer.apply_gradients(capped_gvs) 122 | 123 | # Default session to use for operations 124 | self.sess = tf.InteractiveSession() 125 | 126 | 127 | def _train_step(self, sess, data, verbose=True): 128 | ''' Common helper to run individual training steps, see _train() 129 | Returns output of salient variables above (e.g. ELBO) after 130 | each optimization iteration 131 | 132 | Params: 133 | - sess,verbose: See _train() 134 | - data: batch of raw data with correct dimensions 135 | ''' 136 | _, ELBO, ll, neg_KL, mu, log_var, ep, max_grad = sess.run([ 137 | self.train_op, 138 | self.ELBO_est, 139 | self.ll, 140 | self.neg_KL_pr, 141 | self.latent["mu"], 142 | self.latent["log_var"], 143 | self.ep, 144 | self.max_grad], 145 | feed_dict={self.x_batch: data}) 146 | # _, ELBO, ll, neg_KL, mu, log_var, ep, max_grad = sess.run([ 147 | # self.vi_train_op, 148 | # self.ELBO_est, 149 | # self.ll, 150 | # self.neg_KL_pr, 151 | # self.latent["mu"], 152 | # self.latent["log_var"], 153 | # self.ep, 154 | # self.max_grad], 155 | # feed_dict={x_batch: data}) 156 | 157 | # Perform some sort of reductions on minibatches if need be 158 | ELBO, neg_KL, ll, mu, log_var, ep = ( 159 | np.mean(ELBO), 160 | np.mean(neg_KL), 161 | np.mean(ll), 162 | mu[0], 163 | log_var[0], 164 | ep[0]) 165 | 166 | # Print stats 167 | if verbose: 168 | cl.secho(( 169 | "ELBO (estimate): {}\n" 170 | "KL Div (prior): {}\n" 171 | "Likelihood: {}\n" 172 | "Mu: {}\n" 173 | "Log var: {}\n" 174 | "Epsilon: {}\n" 175 | "Max grad: {}") 176 | .format(ELBO, -neg_KL, ll, mu, log_var, ep, max_grad), fg='cyan') 177 | 178 | return ELBO, ll, neg_KL, mu, log_var, ep, max_grad 179 | 180 | 181 | def _train(self, iters, mbsize, sess, verbose=True): 182 | ''' Trains the VAE end-to-end on MNIST (handwriting) dataset 183 | Returns progress through training phases on above variables 184 | via numpy arrays. 
185 | 186 | Params: 187 | - iters: Number of training iteration per epoch 188 | - mbsize: Number of datapoints per minibatch 189 | - sess: TF session to use if already instantiated one 190 | if None, will use self.sess (InteractiveSession) 191 | - verbose: Print training progress after each timestep 192 | ''' 193 | # Train on MNIST 194 | # To keep track of progress 195 | progress = {} 196 | progress["ELBO"] = np.zeros(iters) 197 | progress["KL"] = np.zeros(iters) 198 | progress["LL"] = np.zeros(iters) 199 | 200 | # Optimize VAE 201 | sess.run(tf.initialize_all_variables()) 202 | timesteps = xrange(iters) 203 | for t in timesteps: 204 | cl.secho('Minibatch {}'.format(t), fg='green', bold=False) 205 | batch = mnist.train.next_batch(mbsize)[0] 206 | prog = self._train_step(sess, batch, verbose) 207 | progress["ELBO"][t] = prog[0] 208 | progress["LL"][t] = prog[1] 209 | progress["KL"][t] = -prog[2] 210 | 211 | progress["iters"] = timesteps 212 | 213 | cl.secho('Success!', fg='green', bold=True) 214 | return progress 215 | 216 | 217 | def train(self, iters=None, mbsize=None, sess=None, verbose=True): 218 | ''' Wrapper for above _train() function 219 | ''' 220 | iters = iters or self.TRAIN["n_iters"] 221 | mbsize = mbsize or self.TRAIN["batch_size"] 222 | sess = sess or self.sess 223 | 224 | return self._train(iters, mbsize, sess, verbose) 225 | 226 | 227 | def encoder_fp(self, data, sess=None): 228 | ''' Performs one forward pass of the encoder. 229 | Output are the parameters for the latent distribution. 230 | 231 | Params: 232 | - data: Datapoint(s) to be processed by encoder 233 | - sess: TF session to use if already instantiated one 234 | ''' 235 | sess = sess or self.sess 236 | return sess.run(self.latent["mu"], self.latent["stddev"], 237 | feed_dict={self.x_batch: data}) 238 | 239 | 240 | def decoder_fp(self, sess=None): 241 | ''' Performs one forward pass of the decoder. 242 | Output are the parameters for the data distribution. 243 | 244 | Params: 245 | - sess: TF session to use if already instantiated one 246 | ''' 247 | sess = sess or self.sess 248 | return sess.run(self.decoder.out_params.p) 249 | 250 | 251 | def full_fp(self, data, sess=None): 252 | ''' Performs one full forward pass of VAE end-to-end. 253 | 254 | Params: 255 | - data: Datapoint(s) to be processed by encoder 256 | - sess: TF session to use if already instantiated one 257 | ''' 258 | sess = sess or self.sess 259 | return sess.run([self.latent["mu"], self.latent["stddev"], 260 | self.decoder.out_params.p], feed_dict={self.x_batch: data}) 261 | 262 | -------------------------------------------------------------------------------- /vis.py: -------------------------------------------------------------------------------- 1 | """ vis.py: Visualization tools for demos and interactive UI 2 | """ 3 | 4 | __author__ = 'shraman-rc' 5 | 6 | import os 7 | import matplotlib as mpl 8 | import matplotlib.pyplot as plt 9 | import matplotlib.mlab as mlab 10 | import matplotlib.widgets as widget 11 | from mpl_toolkits.mplot3d import Axes3D 12 | from matplotlib import cm 13 | from matplotlib.ticker import LinearLocator, FormatStrFormatter 14 | import numpy as np 15 | 16 | def basic_multiplot(data_xs, data_ys, titles, labels=None, unit_x="Minibatches", show_legend=True, params={}): 17 | """ Easily plot multiple lines (e.g. 
different error signals) 18 | 19 | - data_xs: List of nd.arrays for x-axis for each plot 20 | - data_ys: List of lists of nd.arrays for y-axis for each plot 21 | - titles: List of strings, one for each plot 22 | - labels: List of lists of strings, one for each (x,y) pair 23 | - unit_x: Shared units for x-axis (e.g. epochs, minibatches) 24 | - show_legend: true if should show label for each line 25 | - params: Dictionary of hyperparameters, will appear in textbox 26 | """ 27 | num_plots = len(data_xs) 28 | fig, axes = plt.subplots(num_plots) 29 | if num_plots == 1: 30 | axes = [axes] 31 | 32 | # Looks 33 | plt.rc('axes', color_cycle=['g', 'm', 'k', 'c']) 34 | plt.tight_layout() # So axis no overlap with title 35 | 36 | # Populate default labels if no legend to show 37 | if not show_legend: 38 | labels = [['']*len(arrs) for arrs in data_ys] 39 | else: 40 | assert(labels) 41 | 42 | # Plot each line in each subplot 43 | for i, x in enumerate(data_xs): 44 | for y, l in zip(data_ys[i], labels[i]): 45 | axes[i].plot(x, y, label=l) 46 | axes[i].set_title(titles[i], fontsize=22) 47 | axes[i].set_ylabel("Values", fontsize=16) 48 | axes[i].grid() 49 | if show_legend: 50 | axes[i].legend(loc="upper right", 51 | ncol=1, 52 | shadow=True, 53 | title="Heuristics", 54 | fancybox=True, 55 | prop={'size':15}) 56 | 57 | # Label common x-axis 58 | axes[-1].set_xlabel("{}".format(unit_x), fontsize=16) 59 | 60 | # Parameter legend 61 | paramtex = '\n'.join(['{}: ${}$'.format(k,v) for k,v in params.items()]) 62 | props = dict(boxstyle='round', facecolor='wheat', alpha=0.95) 63 | plt.text(0.90, 0.1, paramtex, transform=axes[0].transAxes, fontsize=20, 64 | verticalalignment='top', horizontalalignment='right', bbox=props) 65 | 66 | plt.show() 67 | 68 | 69 | def basic_multiline(data_x, data_ys, x_axis="Minibatch", y_axis="Error", 70 | title="Convergence Rate of ELBO"): 71 | 72 | for data_y in data_ys: 73 | line = plt.plot(data_x, data_y)[0] 74 | line.set_linewidth(2.0) 75 | 76 | plt.legend(loc="upper left", ncol=1, shadow=True, title="Errors", fancybox=True, prop={'size':25}) 77 | plt.title(title, fontsize=30) 78 | plt.xlabel(x_axis) 79 | plt.ylabel(y_axis) 80 | plt.xticks(fontsize=20) 81 | plt.yticks(fontsize=20) 82 | plt.grid() 83 | 84 | plt.show() 85 | 86 | 87 | def juxtapose_images(imset1, imset2): 88 | assert(len(imset1) == len(imset2)) 89 | N = len(imset1) 90 | 91 | plt.figure(figsize=(8, 12)) 92 | 93 | for i in range(N): 94 | plt.subplot(N, 2, 2*i + 1) 95 | plt.imshow(imset1[i], vmin=0, vmax=1) 96 | plt.axis('off') 97 | plt.title("Original (MNIST)") 98 | plt.subplot(N, 2, 2*i + 2) 99 | plt.imshow(imset2[i], vmin=0, vmax=1) 100 | plt.axis('off') 101 | plt.title("Reconstructed") 102 | 103 | plt.tight_layout() 104 | plt.show() 105 | 106 | 107 | def full_pass_vis(imset1, imset2, mu, stddev): 108 | ''' Visualize reconstructed images and latent distributions 109 | ''' 110 | assert(len(imset1) == len(imset2)) 111 | N = len(imset1) 112 | 113 | fig = plt.figure() 114 | 115 | for i in range(N): 116 | # Plot the images 117 | plt.subplot(N, 3, 3*i + 1) 118 | plt.imshow(imset1[i], vmin=0, vmax=1) 119 | plt.axis('off') 120 | plt.title("Original (MNIST)") 121 | plt.subplot(N, 3, 3*i + 3) 122 | plt.imshow(imset2[i], vmin=0, vmax=1) 123 | plt.axis('off') 124 | plt.title("Reconstructed") 125 | 126 | # Plot the latent distributions in between 127 | ax = plt.subplot(N, 3, 3*i + 2, projection='3d') 128 | mux, muy = mu[i] 129 | sigx, sigy = stddev[i] 130 | X = np.arange(-0.5, 0.5, 0.025) + mux 131 | Y = np.arange(-0.5, 0.5, 0.025) 
+ muy 132 | X, Y = np.meshgrid(X, Y) 133 | Z = mlab.bivariate_normal(X,Y, sigmax=sigx, sigmay=sigy, mux=mux, muy=muy) 134 | surf = ax.plot_surface(X, Y, Z, rstride=2, cstride=2, cmap=cm.PuBu, 135 | linewidth=0.1, antialiased=False) 136 | 137 | # Style the 3D plot 138 | ax.set_zlim(np.min(Z), 1.5*np.max(Z)) 139 | ax.set_title("$q_{\phi}(z|x)$", fontsize=20) 140 | ax.zaxis.set_major_locator(LinearLocator(5)) 141 | ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f')) 142 | 143 | #plt.tight_layout() 144 | plt.show() 145 | --------------------------------------------------------------------------------
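One portability note on `full_pass_vis` above: it relies on `matplotlib.mlab.bivariate_normal`, which has been removed from recent matplotlib releases. If the repo is run against a newer matplotlib, a small replacement along these lines should work (a sketch assuming SciPy is available; the helper name is mine, not the repo's):

```python
import numpy as np
from scipy.stats import multivariate_normal

def bivariate_normal_pdf(X, Y, sigmax, sigmay, mux, muy):
    # Density of an axis-aligned bivariate Gaussian evaluated on a meshgrid,
    # matching how full_pass_vis calls mlab.bivariate_normal
    pos = np.dstack((X, Y))                        # shape (ny, nx, 2)
    cov = [[sigmax ** 2, 0.0], [0.0, sigmay ** 2]]
    return multivariate_normal(mean=[mux, muy], cov=cov).pdf(pos)
```

Swapping `mlab.bivariate_normal(X, Y, sigmax=sigx, sigmay=sigy, mux=mux, muy=muy)` for `bivariate_normal_pdf(X, Y, sigx, sigy, mux, muy)` leaves the rest of the plotting code unchanged.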