# Repository layout (non-code text from the dump, kept verbatim as comments):
#
# ├── main.py
# ├── vae
#     ├── __init__.py
#     ├── prior.py
#     ├── encoder.py
#     ├── decoder.py
#     └── vae.py
# ├── requirements.txt
# ├── README.md
# └── .gitignore
#
# /main.py:          (empty)
# /vae/__init__.py:  (empty)
#
# /requirements.txt:
#   numpy==1.14.1
#   six==1.11.0
#   tensorflow==1.6.0
import tensorflow as tf
from tensorflow import distributions as ds


# ---------------------------------------------------------------------------
# /vae/prior.py: the prior distribution over latent-variable values.
# ---------------------------------------------------------------------------
def prior(latent_size):
    """Build the standard-normal prior over a single latent vector.

    Args:
        latent_size (int): The dimension of the latent space.

    Returns:
        tf.distributions.Normal: Zero-mean, unit-scale normal whose
            batch_shape is (latent_size,).
    """
    shape = [latent_size]
    return ds.Normal(loc=tf.zeros(shape), scale=tf.ones(shape))


# ---------------------------------------------------------------------------
# /vae/encoder.py: builds the encoder network on a given input image.
# ---------------------------------------------------------------------------
def encoder(img, latent_size, units):
    """Build the encoder (recognition) network on the given image tensor.

    Args:
        img (tf.Tensor): batch_size x img_size tensor of flat images.
        latent_size (int): Dimension of the latent space.
        units (int): Width of the single hidden layer.

    Returns:
        tf.distributions.Normal: The batch_shape = (batch_size, latent_size)
            batch of posterior normal distributions.
    """
    # Without a non-linearity the two stacked dense layers collapse into a
    # single affine map; tanh is the activation used by Kingma & Welling.
    hidden = tf.layers.dense(img, units, activation=tf.tanh)

    loc = tf.layers.dense(hidden, latent_size)
    # A raw dense output can be zero or negative, which is not a valid
    # standard deviation (log_prob would be NaN). Softplus maps it to a
    # strictly positive value; the small floor avoids log(0) on underflow.
    scale = tf.nn.softplus(tf.layers.dense(hidden, latent_size)) + 1e-6
    return ds.Normal(loc=loc, scale=scale)


# ---------------------------------------------------------------------------
# /vae/decoder.py: builds the decoder network on a given latent variable.
# ---------------------------------------------------------------------------
def decoder(latent, img_size, units):
    """Build the decoder network on the given latent variable tensor.

    Args:
        latent (tf.Tensor): sample_size x batch_size x latent_size latent
            tensor.
        img_size (int): Flattened dimension of the output image.
        units (int): Width of the single hidden layer.

    Returns:
        tf.distributions.Normal: The batch_shape = (sample x batch x img)
            normal distributions representing the sampled img likelihoods.
    """
    # tanh keeps the decoder from degenerating into one affine map.
    hidden = tf.layers.dense(latent, units, activation=tf.tanh)

    loc = tf.layers.dense(hidden, img_size)
    # Constrain the standard deviation to be strictly positive (see encoder).
    scale = tf.nn.softplus(tf.layers.dense(hidden, img_size)) + 1e-6
    return ds.Normal(loc=loc, scale=scale)


# ---------------------------------------------------------------------------
# /README.md (kept verbatim as comments):
# ---------------------------------------------------------------------------
# # Variational Auto-Encoder (vanilla)
# Replication of [Auto-Encoding Variational Bayes](https://arxiv.org/pdf/1312.6114.pdf) (Kingma & Welling, 2013)
#
#
# ## Quick Start
#
# ```bash
# # Create and activate virtual environment
# virtualenv -p python3.5 venv
# source venv/bin/activate
#
# # Install dependencies with pip
# pip install -r requirements.txt
#
# # Run main.py, which trains vae and saves results to /img
# python main.py
# ```
#
# ## More Details
#
# The variational autoencoder implementation is in vanilla tensorflow, and is in `/vae`.
# Since the same graph can be used in multiple ways, there is a simple `VAE` class that
# constructs the `tf` graph and has useful pointers to important tensors and methods to
# simplify interaction with those tensors.
#
# (being single use code, there are no unit tests here)
#
#
# ## Thoughts
#
# - VAE for MNIST
# - VAE for Frey Face
# - Functions to make encoders, decoders
# - Simple object for full graph
#   - inference method for image sim
#   - accessible input and loss
# - Make Figure 2 for z dim = 10
#
# Chuck all that stuff in package vae.
#
# Then in main.py in root, load those and relevant graphing tools, train the model, make the graphs and images.
# Then save to some gitignored subfolder. So our tf code is nicely separated but we can still easily go
#
# ```bash
# python main.py
# ```
#
# to generate some images.
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # VSCode 104 | .vscode/ 105 | 106 | # Outputs 107 | img/ -------------------------------------------------------------------------------- /vae/vae.py: 
"""VAE contains the variational auto-encoder convenience class."""
import tensorflow as tf

from vae.decoder import decoder
from vae.encoder import encoder
from vae.prior import prior


class VAE:
    """VAE is a wrapper around a full variational auto-encoder graph.

    Attributes:
        input (tf.Tensor): Points to image input placeholder.
        latent (tf.Tensor): Points to latent variable sample tensor.
        loss (tf.Tensor): Points to the ELBO loss tensor.
        prior (tf.distribution.Normal): Prior distribution.
        encoder (tf.distribution.Normal): Encoder / recognition distribution.
        decoder (tf.distribution.Normal): Decoder distribution.
    """

    def __init__(self, img_size=225, latent_size=10, sample_size=1, units=500):
        """Creates a new instance of VAE.

        This creates the complete static graph, which is accessed afterwards
        only through session runs.

        Args:
            img_size (int): Flattened dim of input image.
            latent_size (int): Dimension of the latent normal variable.
            sample_size (int): The sample size drawn from the recognition model.
                Usually 1, since we do stochastic integration.
            units (int): Hidden-layer width passed to both the encoder and
                the decoder.
        """
        self.input = tf.placeholder(tf.float32, [None, img_size])
        self.encoder = encoder(self.input, latent_size, units)
        self.latent = self.encoder.sample(sample_size)
        self.decoder = decoder(self.latent, img_size, units)
        self.prior = prior(latent_size)

        # Built once here so that decode() does not add a fresh sampling op
        # to the graph on every call.
        self._decoded_sample = self.decoder.sample()

        likelihood = self.decoder.log_prob(self.input)
        latent_prior = self.prior.log_prob(self.latent)
        latent_posterior = self.encoder.log_prob(self.latent)

        # Monte-Carlo estimate of the evidence lower bound, summed over the
        # batch and averaged over the latent samples.
        # NOTE(review): this is the ELBO itself (larger is better), so an
        # optimizer would presumably minimise its negative -- confirm against
        # the training code, which is not visible here.
        self.loss = (
            tf.reduce_sum(likelihood) +
            tf.reduce_sum(latent_prior) -
            tf.reduce_sum(latent_posterior)
        ) / sample_size

    def _run(self, fetch, feed_dict, sess=None):
        """Evaluates `fetch` in the most appropriate available session.

        Order of preference: the explicitly supplied `sess`, then the current
        default session, then a temporary session that is closed afterwards.
        """
        if sess is None:
            sess = tf.get_default_session()
        if sess is not None:
            return sess.run(fetch, feed_dict=feed_dict)

        # No session available: use a short-lived one on this model's graph.
        # Its variables are freshly initialised (i.e. untrained), so pass the
        # training session as `sess` to use learned weights.
        with tf.Session(graph=self.input.graph) as tmp_sess:
            tmp_sess.run(tf.global_variables_initializer())
            return tmp_sess.run(fetch, feed_dict=feed_dict)

    def decode(self, latent, sess=None):
        """Decodes the provided latent array, returns a sample from the output.

        Args:
            latent (np.ndarray): A sample_size x batch_size x latent_size
                latent variable array.
            sess (tf.Session, optional): Session holding the model variables.
                Defaults to the current default session; if there is none, a
                temporary session with newly initialised variables is used.

        Returns:
            np.ndarray: A sample_size x batch_size, img_size array of sampled
                and decoded images.
        """
        # Feeding self.latent overrides the encoder sample with the caller's
        # values, so the input placeholder is not needed.
        return self._run(self._decoded_sample, {self.latent: latent}, sess)

    def encode(self, img, sess=None):
        """Encodes the provided images, returns a sample from the latent posterior.

        Args:
            img (np.ndarray): A batch_size x img_size array of flattened images.
            sess (tf.Session, optional): Session holding the model variables.
                Defaults to the current default session; if there is none, a
                temporary session with newly initialised variables is used.

        Returns:
            np.ndarray: A sample_size x batch_size x latent_size ndarray of
                latent variables.
        """
        return self._run(self.latent, {self.input: img}, sess)