├── .env ├── .gitignore ├── LICENSE ├── README.md ├── data.py ├── environment-jax.yml ├── environment-pytorch.yml ├── environment.yml ├── flow.py ├── masks.py ├── setup.cfg ├── train_variational_autoencoder_jax.py ├── train_variational_autoencoder_pytorch.py └── train_variational_autoencoder_tensorflow.py /.env: -------------------------------------------------------------------------------- 1 | # dev.env - development configuration 2 | 3 | # suppress warnings for jax 4 | # JAX_PLATFORM_NAME=cpu 5 | 6 | # suppress tensorflow warnings 7 | TF_CPP_MIN_LOG_LEVEL=3 8 | 9 | # set tensorflow data directory 10 | TFDS_DATA_DIR=/scratch/gpfs/altosaar/tensorflow_datasets 11 | 12 | # disable JIT for debugging 13 | JAX_DISABLE_JIT=1 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | launch.json 3 | settings.json 4 | *.code-workspace -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2016 Jaan Altosaar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Variational Autoencoder in TensorFlow and PyTorch 2 | [![DOI](https://zenodo.org/badge/65744394.svg)](https://zenodo.org/badge/latestdoi/65744394) 3 | 4 | Reference implementation for a variational autoencoder in TensorFlow and PyTorch. 5 | 6 | I recommend the PyTorch version. It includes an example of a more expressive variational family, the [inverse autoregressive flow](https://arxiv.org/abs/1606.04934). 7 | 8 | Variational inference is used to fit the model to binarized MNIST handwritten digit images. An inference network (encoder) is used to amortize the inference and share parameters across datapoints. The likelihood is parameterized by a generative network (decoder). 9 | 10 | Blog post: https://jaan.io/what-is-variational-autoencoder-vae-tutorial/ 11 | 12 | 13 | ## PyTorch implementation 14 | 15 | (anaconda environment is in `environment-pytorch.yml`) 16 | 17 | Importance sampling is used to estimate the marginal likelihood on Hugo Larochelle's Binary MNIST dataset.
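The estimator mirrors the `evaluate` function in `train_variational_autoencoder_pytorch.py`: draw several samples from the approximate posterior q(z | x) and combine the importance weights p(x, z) / q(z | x) in log space. A minimal sketch (following the script's interfaces, where `model(z, x)` returns log p(x, z) and `variational(x, n_samples)` returns samples and log q(z | x)):

```
import numpy as np
import torch

@torch.no_grad()
def estimate_log_p_x(model, variational, x, n_samples=1000):
    z, log_q_z = variational(x, n_samples)   # z: (batch, n_samples, latent)
    log_p_x_and_z = model(z, x)              # joint log-probability log p(x, z)
    # log p(x) ~= logsumexp_k [log p(x, z_k) - log q(z_k | x)] - log K
    # returns one marginal likelihood estimate per example in the batch
    return torch.logsumexp(log_p_x_and_z - log_q_z, dim=1) - np.log(n_samples)
```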
The final marginal likelihood on the test set of `-97.10` nats is comparable to published numbers. 18 | 19 | ``` 20 | $ python train_variational_autoencoder_pytorch.py --variational mean-field --use_gpu --data_dir $DAT --max_iterations 30000 --log_interval 10000 21 | Step 0 Train ELBO estimate: -558.027 Validation ELBO estimate: -384.432 Validation log p(x) estimate: -355.430 Speed: 2.72e+06 examples/s 22 | Step 10000 Train ELBO estimate: -111.323 Validation ELBO estimate: -109.048 Validation log p(x) estimate: -103.746 Speed: 2.64e+04 examples/s 23 | Step 20000 Train ELBO estimate: -103.013 Validation ELBO estimate: -107.655 Validation log p(x) estimate: -101.275 Speed: 2.63e+04 examples/s 24 | Step 29999 Test ELBO estimate: -106.642 Test log p(x) estimate: -100.309 25 | Total time: 2.49 minutes 26 | ``` 27 | 28 | 29 | Using a more expressive, non-mean-field variational posterior approximation (inverse autoregressive flow, https://arxiv.org/abs/1606.04934), the test marginal log-likelihood improves to `-95.33` nats: 30 | 31 | ``` 32 | $ python train_variational_autoencoder_pytorch.py --variational flow 33 | step: 0 train elbo: -578.35 34 | step: 0 valid elbo: -407.06 valid log p(x): -367.88 35 | step: 10000 train elbo: -106.63 36 | step: 10000 valid elbo: -110.12 valid log p(x): -104.00 37 | step: 20000 train elbo: -101.51 38 | step: 20000 valid elbo: -105.02 valid log p(x): -99.11 39 | step: 30000 train elbo: -98.70 40 | step: 30000 valid elbo: -103.76 valid log p(x): -97.71 41 | ``` 42 | 43 | ## JAX implementation 44 | 45 | Using JAX (anaconda environment is in `environment-jax.yml`) gives a roughly 3x speedup over PyTorch: 46 | ``` 47 | $ python train_variational_autoencoder_jax.py --variational mean-field 48 | Step 0 Train ELBO estimate: -566.059 Validation ELBO estimate: -565.755 Validation log p(x) estimate: -557.914 Speed: 2.56e+11 examples/s 49 | Step 10000 Train ELBO estimate: -98.560 Validation ELBO estimate: -105.725 Validation log p(x) estimate: -98.973 Speed: 7.03e+04 examples/s 50 | Step 20000 Train ELBO estimate: -109.794 Validation ELBO estimate: -105.756 Validation log p(x) estimate: -97.914 Speed: 4.26e+04 examples/s 51 | Step 29999 Test ELBO estimate: -104.867 Test log p(x) estimate: -96.716 52 | Total time: 0.810 minutes 53 | ``` 54 | 55 | Inverse autoregressive flow in JAX: 56 | ``` 57 | $ python train_variational_autoencoder_jax.py --variational flow 58 | Step 0 Train ELBO estimate: -727.404 Validation ELBO estimate: -726.977 Validation log p(x) estimate: -713.389 Speed: 2.56e+11 examples/s 59 | Step 10000 Train ELBO estimate: -100.093 Validation ELBO estimate: -106.985 Validation log p(x) estimate: -99.565 Speed: 2.57e+04 examples/s 60 | Step 20000 Train ELBO estimate: -113.073 Validation ELBO estimate: -108.057 Validation log p(x) estimate: -98.841 Speed: 3.37e+04 examples/s 61 | Step 29999 Test ELBO estimate: -106.803 Test log p(x) estimate: -97.620 62 | Total time: 2.350 minutes 63 | ``` 64 | 65 | (The gap between the mean-field and inverse autoregressive flow results may be due to several factors, chief among them the lack of convolutions in the implementation. Residual blocks are used in https://arxiv.org/pdf/1606.04934.pdf to get the ELBO closer to -80 nats.) 66 | 67 | ## Generating the GIFs 68 | 69 | 1. Run `python train_variational_autoencoder_tensorflow.py` 70 | 2. Install ImageMagick (Homebrew on Mac: https://formulae.brew.sh/formula/imagemagick or Chocolatey on Windows: https://community.chocolatey.org/packages/imagemagick.app) 71 | 3. 
Go to the directory where the jpg files are saved, and run the imagemagick command to generate the .gif: `convert -delay 20 -loop 0 *.jpg latent-space.gif` 72 | 4. 73 | 74 | ## TODO (help needed - feel free to send a PR!) 75 | - add multiple GPU / TPU option 76 | - add jaxtyping support for PyTorch and Jax implementations :) for runtime static type checking (using @beartype decorators) 77 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | """Get the binarized MNIST dataset and convert to hdf5. 2 | From https://github.com/yburda/iwae/blob/master/datasets.py 3 | """ 4 | import urllib.request 5 | import os 6 | import numpy as np 7 | import h5py 8 | import torch 9 | 10 | 11 | def parse_binary_mnist(data_dir): 12 | def lines_to_np_array(lines): 13 | return np.array([[int(i) for i in line.split()] for line in lines]) 14 | 15 | with open(os.path.join(data_dir, "binarized_mnist_train.amat")) as f: 16 | lines = f.readlines() 17 | train_data = lines_to_np_array(lines).astype("float32") 18 | with open(os.path.join(data_dir, "binarized_mnist_valid.amat")) as f: 19 | lines = f.readlines() 20 | validation_data = lines_to_np_array(lines).astype("float32") 21 | with open(os.path.join(data_dir, "binarized_mnist_test.amat")) as f: 22 | lines = f.readlines() 23 | test_data = lines_to_np_array(lines).astype("float32") 24 | return train_data, validation_data, test_data 25 | 26 | 27 | def download_binary_mnist(fname): 28 | data_dir = "/tmp/" 29 | subdatasets = ["train", "valid", "test"] 30 | for subdataset in subdatasets: 31 | filename = "binarized_mnist_{}.amat".format(subdataset) 32 | url = "http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat".format( 33 | subdataset 34 | ) 35 | local_filename = os.path.join(data_dir, filename) 36 | urllib.request.urlretrieve(url, local_filename) 37 | 38 | train, validation, test = parse_binary_mnist(data_dir) 39 | 40 | data_dict = {"train": train, "valid": validation, "test": test} 41 | f = h5py.File(fname, "w") 42 | f.create_dataset("train", data=data_dict["train"]) 43 | f.create_dataset("valid", data=data_dict["valid"]) 44 | f.create_dataset("test", data=data_dict["test"]) 45 | f.close() 46 | print(f"Saved binary MNIST data to: {fname}") 47 | 48 | 49 | def load_binary_mnist(fname, batch_size, test_batch_size, use_gpu): 50 | f = h5py.File(fname, "r") 51 | x_train = f["train"][::] 52 | x_val = f["valid"][::] 53 | x_test = f["test"][::] 54 | train = torch.utils.data.TensorDataset(torch.from_numpy(x_train)) 55 | kwargs = {"num_workers": 4, "pin_memory": True} if use_gpu else {} 56 | train_loader = torch.utils.data.DataLoader( 57 | train, batch_size=batch_size, shuffle=True, **kwargs 58 | ) 59 | validation = torch.utils.data.TensorDataset(torch.from_numpy(x_val)) 60 | val_loader = torch.utils.data.DataLoader( 61 | validation, batch_size=test_batch_size, shuffle=False, **kwargs 62 | ) 63 | test = torch.utils.data.TensorDataset(torch.from_numpy(x_test)) 64 | test_loader = torch.utils.data.DataLoader( 65 | test, batch_size=test_batch_size, shuffle=False, **kwargs 66 | ) 67 | return train_loader, val_loader, test_loader 68 | -------------------------------------------------------------------------------- /environment-jax.yml: -------------------------------------------------------------------------------- 1 | name: /scratch/gpfs/altosaar/environment-jax 2 | channels: 3 | - defaults 4 | dependencies: 5 | - 
_libgcc_mutex=0.1=main 6 | - ca-certificates=2021.4.13=h06a4308_1 7 | - certifi=2020.12.5=py39h06a4308_0 8 | - ld_impl_linux-64=2.33.1=h53a641e_7 9 | - libffi=3.3=he6710b0_2 10 | - libgcc-ng=9.1.0=hdf63c60_0 11 | - libstdcxx-ng=9.1.0=hdf63c60_0 12 | - ncurses=6.2=he6710b0_1 13 | - openssl=1.1.1k=h27cfd23_0 14 | - python=3.9.5=hdb3f193_3 15 | - readline=8.1=h27cfd23_0 16 | - setuptools=52.0.0=py39h06a4308_0 17 | - sqlite=3.35.4=hdfb4753_0 18 | - tk=8.6.10=hbc83047_0 19 | - tzdata=2020f=h52ac0ba_0 20 | - wheel=0.36.2=pyhd3eb1b0_0 21 | - xz=5.2.5=h7b6447c_0 22 | - zlib=1.2.11=h7b6447c_3 23 | - pip: 24 | - absl-py==0.12.0 25 | - astunparse==1.6.3 26 | - attrs==21.2.0 27 | - cachetools==4.2.2 28 | - chardet==4.0.0 29 | - chex==0.0.7 30 | - cloudpickle==1.6.0 31 | - decorator==5.0.9 32 | - dill==0.3.3 33 | - dm-haiku==0.0.5.dev0 34 | - dm-tree==0.1.6 35 | - flatbuffers==1.12 36 | - future==0.18.2 37 | - gast==0.4.0 38 | - google-auth==1.30.1 39 | - google-auth-oauthlib==0.4.4 40 | - google-pasta==0.2.0 41 | - googleapis-common-protos==1.53.0 42 | - grpcio==1.38.0 43 | - h5py==3.1.0 44 | - idna==2.10 45 | - jax==0.2.13 46 | - jaxlib==0.1.67+cuda111 47 | - jmp==0.0.2 48 | - keras-nightly==2.6.0.dev2021052500 49 | - keras-preprocessing==1.1.2 50 | - markdown==3.3.4 51 | - numpy==1.19.5 52 | - oauthlib==3.1.0 53 | - opt-einsum==3.3.0 54 | - optax==0.0.6 55 | - pip==21.1.2 56 | - promise==2.3 57 | - protobuf==3.17.1 58 | - pyasn1==0.4.8 59 | - pyasn1-modules==0.2.8 60 | - requests==2.25.1 61 | - requests-oauthlib==1.3.0 62 | - rsa==4.7.2 63 | - scipy==1.6.3 64 | - six==1.15.0 65 | - tabulate==0.8.9 66 | - tb-nightly==2.6.0a20210525 67 | - tensorboard-data-server==0.6.1 68 | - tensorboard-plugin-wit==1.8.0 69 | - tensorflow-datasets==4.3.0 70 | - tensorflow-metadata==1.0.0 71 | - termcolor==1.1.0 72 | - tf-estimator-nightly==2.5.0.dev2021032601 73 | - tf-nightly==2.6.0.dev20210525 74 | - tfp-nightly==0.14.0.dev20210525 75 | - toolz==0.11.1 76 | - tqdm==4.61.0 77 | - typing-extensions==3.7.4.3 78 | - urllib3==1.26.4 79 | - werkzeug==2.0.1 80 | - wrapt==1.12.1 81 | prefix: /scratch/gpfs/altosaar/environment-jax 82 | -------------------------------------------------------------------------------- /environment-pytorch.yml: -------------------------------------------------------------------------------- 1 | name: /scratch/gpfs/altosaar/environment-pytorch 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - blas=1.0=mkl 9 | - bzip2=1.0.8=h7b6447c_0 10 | - ca-certificates=2021.4.13=h06a4308_1 11 | - certifi=2020.12.5=py38h06a4308_0 12 | - cudatoolkit=11.1.74=h6bb024c_0 13 | - ffmpeg=4.3=hf484d3e_0 14 | - freetype=2.10.4=h5ab3b9f_0 15 | - gmp=6.2.1=h2531618_2 16 | - gnutls=3.6.15=he1e5248_0 17 | - h5py=2.10.0=py38hd6299e0_1 18 | - hdf5=1.10.6=hb1b8bf9_0 19 | - intel-openmp=2021.2.0=h06a4308_610 20 | - jpeg=9b=h024ee3a_2 21 | - lame=3.100=h7b6447c_0 22 | - lcms2=2.12=h3be6417_0 23 | - ld_impl_linux-64=2.33.1=h53a641e_7 24 | - libffi=3.3=he6710b0_2 25 | - libgcc-ng=9.1.0=hdf63c60_0 26 | - libgfortran-ng=7.3.0=hdf63c60_0 27 | - libiconv=1.15=h63c8f33_5 28 | - libidn2=2.3.1=h27cfd23_0 29 | - libpng=1.6.37=hbc83047_0 30 | - libstdcxx-ng=9.1.0=hdf63c60_0 31 | - libtasn1=4.16.0=h27cfd23_0 32 | - libtiff=4.1.0=h2733197_1 33 | - libunistring=0.9.10=h27cfd23_0 34 | - libuv=1.40.0=h7b6447c_0 35 | - lz4-c=1.9.3=h2531618_0 36 | - mkl=2021.2.0=h06a4308_296 37 | - mkl-service=2.3.0=py38h27cfd23_1 38 | - mkl_fft=1.3.0=py38h42c9631_2 39 | - 
mkl_random=1.2.1=py38ha9443f7_2 40 | - ncurses=6.2=he6710b0_1 41 | - nettle=3.7.2=hbbd107a_1 42 | - ninja=1.10.2=hff7bd54_1 43 | - numpy=1.20.2=py38h2d18471_0 44 | - numpy-base=1.20.2=py38hfae3a4d_0 45 | - olefile=0.46=py_0 46 | - openh264=2.1.0=hd408876_0 47 | - openssl=1.1.1k=h27cfd23_0 48 | - pillow=8.2.0=py38he98fc37_0 49 | - pip=21.1.1=py38h06a4308_0 50 | - python=3.8.10=hdb3f193_7 51 | - pytorch=1.8.1=py3.8_cuda11.1_cudnn8.0.5_0 52 | - readline=8.1=h27cfd23_0 53 | - setuptools=52.0.0=py38h06a4308_0 54 | - six=1.15.0=py38h06a4308_0 55 | - sqlite=3.35.4=hdfb4753_0 56 | - tk=8.6.10=hbc83047_0 57 | - torchaudio=0.8.1=py38 58 | - torchvision=0.9.1=py38_cu111 59 | - typing_extensions=3.7.4.3=pyha847dfd_0 60 | - wheel=0.36.2=pyhd3eb1b0_0 61 | - xz=5.2.5=h7b6447c_0 62 | - zlib=1.2.11=h7b6447c_3 63 | - zstd=1.4.9=haebb681_0 64 | prefix: /scratch/gpfs/altosaar/environment-pytorch 65 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dev 2 | channels: 3 | - defaults 4 | dependencies: 5 | - blas=1.0=mkl 6 | - ca-certificates=2019.5.15=1 7 | - certifi=2019.6.16=py37_1 8 | - freetype=2.9.1=hb4e5f40_0 9 | - imageio=2.5.0=py37_0 10 | - intel-openmp=2019.4=233 11 | - jpeg=9b=he5867d9_2 12 | - libcxx=4.0.1=hcfea43d_1 13 | - libcxxabi=4.0.1=hcfea43d_1 14 | - libedit=3.1.20181209=hb402a30_0 15 | - libffi=3.2.1=h475c297_4 16 | - libgfortran=3.0.1=h93005f0_2 17 | - libpng=1.6.37=ha441bb4_0 18 | - libtiff=4.0.10=hcb84e12_2 19 | - mkl=2019.4=233 20 | - mkl-service=2.3.0=py37hfbe908c_0 21 | - mkl_fft=1.0.14=py37h5e564d8_0 22 | - mkl_random=1.0.2=py37h27c97d8_0 23 | - ncurses=6.1=h0a44026_1 24 | - numpy=1.16.5=py37hacdab7b_0 25 | - numpy-base=1.16.5=py37h6575580_0 26 | - olefile=0.46=py37_0 27 | - openssl=1.1.1d=h1de35cc_1 28 | - pillow=6.1.0=py37hb68e598_0 29 | - python=3.7.4=h359304d_1 30 | - readline=7.0=h1de35cc_5 31 | - setuptools=41.0.1=py37_0 32 | - six=1.12.0=py37_0 33 | - sqlite=3.29.0=ha441bb4_0 34 | - tk=8.6.8=ha441bb4_0 35 | - wheel=0.33.4=py37_0 36 | - xz=5.2.4=h1de35cc_4 37 | - zlib=1.2.11=h1de35cc_3 38 | - zstd=1.3.7=h5bba6e5_0 39 | - pip: 40 | - absl-py==0.8.0 41 | - astor==0.8.0 42 | - attrs==19.1.0 43 | - chardet==3.0.4 44 | - cloudpickle==1.2.2 45 | - decorator==4.4.0 46 | - dill==0.3.0 47 | - future==0.17.1 48 | - gast==0.3.2 49 | - google-pasta==0.1.7 50 | - googleapis-common-protos==1.6.0 51 | - grpcio==1.23.0 52 | - h5py==2.10.0 53 | - idna==2.8 54 | - keras-applications==1.0.8 55 | - keras-preprocessing==1.1.0 56 | - markdown==3.1.1 57 | - pip==19.2.3 58 | - promise==2.2.1 59 | - protobuf==3.9.1 60 | - psutil==5.6.3 61 | - requests==2.22.0 62 | - tensorboard==1.14.0 63 | - tensorflow==1.14.0 64 | - tensorflow-datasets==1.2.0 65 | - tensorflow-estimator==1.14.0 66 | - tensorflow-metadata==0.14.0 67 | - tensorflow-probability==0.7.0 68 | - termcolor==1.1.0 69 | - tqdm==4.36.0 70 | - urllib3==1.25.3 71 | - werkzeug==0.15.6 72 | - wrapt==1.11.2 73 | prefix: /usr/local/anaconda3/envs/dev 74 | 75 | -------------------------------------------------------------------------------- /flow.py: -------------------------------------------------------------------------------- 1 | """Credit: mostly based on Ilya's excellent implementation here: https://github.com/ikostrikov/pytorch-flows""" 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | import masks 8 | 9 | 10 | class 
InverseAutoregressiveFlow(nn.Module): 11 | """Inverse Autoregressive Flows with LSTM-type update. One block. 12 | 13 | Eq 11-14 of https://arxiv.org/abs/1606.04934 14 | """ 15 | 16 | def __init__(self, num_input, num_hidden, num_context): 17 | super().__init__() 18 | self.made = MADE( 19 | num_input=num_input, 20 | num_outputs_per_input=2, 21 | num_hidden=num_hidden, 22 | num_context=num_context, 23 | ) 24 | # init such that sigmoid(s) is close to 1 for stability 25 | self.sigmoid_arg_bias = nn.Parameter(torch.ones(num_input) * 2) 26 | self.sigmoid = nn.Sigmoid() 27 | self.log_sigmoid = nn.LogSigmoid() 28 | 29 | def forward(self, input, context=None): 30 | m, s = torch.chunk(self.made(input, context), chunks=2, dim=-1) 31 | s = s + self.sigmoid_arg_bias 32 | sigmoid = self.sigmoid(s) 33 | z = sigmoid * input + (1 - sigmoid) * m 34 | return z, -self.log_sigmoid(s) 35 | 36 | 37 | class FlowSequential(nn.Sequential): 38 | """Forward pass.""" 39 | 40 | def forward(self, input, context=None): 41 | total_log_prob = torch.zeros_like(input, device=input.device) 42 | for block in self._modules.values(): 43 | input, log_prob = block(input, context) 44 | total_log_prob += log_prob 45 | return input, total_log_prob 46 | 47 | 48 | class MaskedLinear(nn.Module): 49 | """Linear layer with some input-output connections masked.""" 50 | 51 | def __init__( 52 | self, in_features, out_features, mask, context_features=None, bias=True 53 | ): 54 | super().__init__() 55 | self.linear = nn.Linear(in_features, out_features, bias) 56 | self.register_buffer("mask", mask) 57 | if context_features is not None: 58 | self.cond_linear = nn.Linear(context_features, out_features, bias=False) 59 | 60 | def forward(self, input, context=None): 61 | output = F.linear(input, self.mask * self.linear.weight, self.linear.bias) 62 | if context is None: 63 | return output 64 | else: 65 | return output + self.cond_linear(context) 66 | 67 | 68 | class MADE(nn.Module): 69 | """Implements MADE: Masked Autoencoder for Distribution Estimation. 70 | 71 | Follows https://arxiv.org/abs/1502.03509 72 | 73 | This is used to build MAF: Masked Autoregressive Flow (https://arxiv.org/abs/1705.07057). 
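    A shape-only usage sketch (hypothetical sizes; this mirrors how InverseAutoregressiveFlow
    above calls this module to obtain a shift m and scale argument s per latent dimension):

        made = MADE(num_input=4, num_outputs_per_input=2, num_hidden=16, num_context=4)
        out = made(torch.randn(8, 4), context=torch.randn(8, 4))   # shape (8, 2 * 4)
        m, s = torch.chunk(out, chunks=2, dim=-1)                  # each of shape (8, 4)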
74 | """ 75 | 76 | def __init__(self, num_input, num_outputs_per_input, num_hidden, num_context): 77 | super().__init__() 78 | # m corresponds to m(k), the maximum degree of a node in the MADE paper 79 | self._m = [] 80 | degrees = masks.create_degrees( 81 | input_size=num_input, 82 | hidden_units=[num_hidden] * 2, 83 | input_order="left-to-right", 84 | hidden_degrees="equal", 85 | ) 86 | self._masks = masks.create_masks(degrees) 87 | self._masks[-1] = np.hstack( 88 | [self._masks[-1] for _ in range(num_outputs_per_input)] 89 | ) 90 | self._masks = [torch.from_numpy(m.T) for m in self._masks] 91 | modules = [] 92 | self.input_context_net = MaskedLinear( 93 | num_input, num_hidden, self._masks[0], num_context 94 | ) 95 | self.net = nn.Sequential( 96 | nn.ReLU(), 97 | MaskedLinear(num_hidden, num_hidden, self._masks[1], context_features=None), 98 | nn.ReLU(), 99 | MaskedLinear( 100 | num_hidden, 101 | num_outputs_per_input * num_input, 102 | self._masks[2], 103 | context_features=None, 104 | ), 105 | ) 106 | 107 | def forward(self, input, context=None): 108 | # first hidden layer receives input and context 109 | hidden = self.input_context_net(input, context) 110 | # rest of the network is conditioned on both input and context 111 | return self.net(hidden) 112 | 113 | 114 | class Reverse(nn.Module): 115 | """An implementation of a reversing layer from 116 | Density estimation using Real NVP 117 | (https://arxiv.org/abs/1605.08803). 118 | 119 | From https://github.com/ikostrikov/pytorch-flows/blob/master/main.py 120 | """ 121 | 122 | def __init__(self, num_input): 123 | super(Reverse, self).__init__() 124 | self.perm = np.array(np.arange(0, num_input)[::-1]) 125 | self.inv_perm = np.argsort(self.perm) 126 | 127 | def forward(self, inputs, context=None, mode="forward"): 128 | if mode == "forward": 129 | return inputs[:, :, self.perm], torch.zeros_like( 130 | inputs, device=inputs.device 131 | ) 132 | elif mode == "inverse": 133 | return inputs[:, :, self.inv_perm], torch.zeros_like( 134 | inputs, device=inputs.device 135 | ) 136 | else: 137 | raise ValueError("Mode must be one of {forward, inverse}.") 138 | -------------------------------------------------------------------------------- /masks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | """Use utility functions from https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/masked_autoregressive.py 4 | """ 5 | 6 | 7 | def create_input_order(input_size, input_order="left-to-right"): 8 | """Returns a degree vectors for the input.""" 9 | if input_order == "left-to-right": 10 | return np.arange(start=1, stop=input_size + 1) 11 | elif input_order == "right-to-left": 12 | return np.arange(start=input_size, stop=0, step=-1) 13 | elif input_order == "random": 14 | ret = np.arange(start=1, stop=input_size + 1) 15 | np.random.shuffle(ret) 16 | return ret 17 | 18 | 19 | def create_degrees( 20 | input_size, hidden_units, input_order="left-to-right", hidden_degrees="equal" 21 | ): 22 | input_order = create_input_order(input_size, input_order) 23 | degrees = [input_order] 24 | for units in hidden_units: 25 | if hidden_degrees == "random": 26 | # samples from: [low, high) 27 | degrees.append( 28 | np.random.randint( 29 | low=min(np.min(degrees[-1]), input_size - 1), 30 | high=input_size, 31 | size=units, 32 | ) 33 | ) 34 | elif hidden_degrees == "equal": 35 | min_degree = min(np.min(degrees[-1]), input_size - 1) 36 | degrees.append( 37 | np.maximum( 38 
| min_degree, 39 | # Evenly divide the range `[1, input_size - 1]` in to `units + 1` 40 | # segments, and pick the boundaries between the segments as degrees. 41 | np.ceil( 42 | np.arange(1, units + 1) * (input_size - 1) / float(units + 1) 43 | ).astype(np.int32), 44 | ) 45 | ) 46 | return degrees 47 | 48 | 49 | def create_masks(degrees): 50 | """Returns a list of binary mask matrices enforcing autoregressivity.""" 51 | return [ 52 | # Create input->hidden and hidden->hidden masks. 53 | inp[:, np.newaxis] <= out 54 | for inp, out in zip(degrees[:-1], degrees[1:]) 55 | ] + [ 56 | # Create hidden->output mask. 57 | degrees[-1][:, np.newaxis] 58 | < degrees[0] 59 | ] 60 | 61 | 62 | def check_masks(masks): 63 | """Check that the connectivity matrix between layers is lower triangular.""" 64 | # (num_input, num_hidden) 65 | prev = masks[0].t() 66 | for i in range(1, len(masks)): 67 | # num_hidden is second axis 68 | prev = prev @ masks[i].t() 69 | final = prev.numpy() 70 | num_input = masks[0].shape[1] 71 | num_output = masks[-1].shape[0] 72 | assert final.shape == (num_input, num_output) 73 | if num_output == num_input: 74 | assert np.triu(final).all() == 0 75 | else: 76 | for submat in np.split( 77 | final, indices_or_sections=num_output // num_input, axis=1 78 | ): 79 | assert np.triu(submat).all() == 0 80 | 81 | 82 | def build_random_masks(num_input, num_output, num_hidden, num_layers): 83 | """Build the masks according to Eq 12 and 13 in the MADE paper.""" 84 | # assign input units a number between 1 and D 85 | rng = np.random.RandomState(0) 86 | m_list, masks = [], [] 87 | m_list.append(np.arange(1, num_input + 1)) 88 | for i in range(1, num_layers + 1): 89 | if i == num_layers: 90 | # assign output layer units a number between 1 and D 91 | m = np.arange(1, num_input + 1) 92 | assert ( 93 | num_output % num_input == 0 94 | ), "num_output must be multiple of num_input" 95 | m_list.append(np.hstack([m for _ in range(num_output // num_input)])) 96 | else: 97 | # assign hidden layer units a number between 1 and D-1 98 | # i.e. 
randomly assign maximum number of input nodes to connect to 99 | m_list.append(rng.randint(1, num_input, size=num_hidden)) 100 | if i == num_layers: 101 | mask = m_list[i][None, :] > m_list[i - 1][:, None] 102 | else: 103 | # input to hidden & hidden to hidden 104 | mask = m_list[i][None, :] >= m_list[i - 1][:, None] 105 | # need to transpose for torch linear layer, shape (num_output, num_input) 106 | masks.append(mask.astype(np.float32).T) 107 | return masks 108 | 109 | 110 | def _compute_neighborhood(system_size): 111 | """Compute (system_size, neighborhood_size) array.""" 112 | num_variables = system_size ** 2 113 | arange = np.arange(num_variables) 114 | grid = arange.reshape((system_size, system_size)) 115 | self_and_neighbors = np.zeros((system_size, system_size, 5), dtype=int) 116 | # four nearest-neighbors 117 | self_and_neighbors = np.zeros((system_size, system_size, 5), dtype=int) 118 | self_and_neighbors[..., 0] = grid 119 | neighbor_index = 1 120 | for axis in [0, 1]: 121 | for shift in [-1, 1]: 122 | self_and_neighbors[..., neighbor_index] = np.roll( 123 | grid, shift=shift, axis=axis 124 | ) 125 | neighbor_index += 1 126 | # reshape to (num_latent, num_neighbors) 127 | self_and_neighbors = self_and_neighbors.reshape(num_variables, -1) 128 | return self_and_neighbors 129 | 130 | 131 | def build_neighborhood_indicator(system_size): 132 | """Boolean indicator of (num_variables, num_variables) for whether nodes are neighbors.""" 133 | neighborhood = _compute_neighborhood(system_size) 134 | num_variables = system_size ** 2 135 | mask = np.zeros((num_variables, num_variables), dtype=bool) 136 | for i in range(len(mask)): 137 | mask[i, neighborhood[i]] = True 138 | return mask 139 | 140 | 141 | def build_deterministic_mask(num_variables, num_input, num_output, mask_type): 142 | if mask_type == "input": 143 | in_degrees = np.arange(num_input) % num_variables 144 | else: 145 | in_degrees = np.arange(num_input) % (num_variables - 1) 146 | 147 | if mask_type == "output": 148 | out_degrees = np.arange(num_output) % num_variables 149 | mask = np.expand_dims(out_degrees, -1) > np.expand_dims(in_degrees, 0) 150 | else: 151 | out_degrees = np.arange(num_output) % (num_variables - 1) 152 | mask = np.expand_dims(out_degrees, -1) >= np.expand_dims(in_degrees, 0) 153 | 154 | return mask, in_degrees, out_degrees 155 | 156 | 157 | def build_masks(num_variables, num_input, num_output, num_hidden, mask_fn): 158 | input_mask, _, _ = mask_fn(num_variables, num_input, num_hidden, "input") 159 | hidden_mask, _, _ = mask_fn(num_variables, num_hidden, num_hidden, "hidden") 160 | output_mask, _, _ = mask_fn(num_variables, num_hidden, num_output, "output") 161 | masks = [input_mask, hidden_mask, output_mask] 162 | masks = [torch.from_numpy(x.astype(np.float32)) for x in masks] 163 | return masks 164 | 165 | 166 | def build_neighborhood_mask(num_variables, num_input, num_output, mask_type): 167 | system_size = int(np.sqrt(num_variables)) 168 | # return context mask for input, with same assignment of m(k) maximum node degree 169 | mask, in_degrees, out_degrees = build_deterministic_mask( 170 | system_size ** 2, num_input, num_output, mask_type 171 | ) 172 | neighborhood = _compute_neighborhood(system_size) 173 | neighborhood_mask = np.zeros_like(mask) # shape len(out_degrees), len(in_degrees) 174 | for i in range(len(neighborhood_mask)): 175 | neighborhood_indicator = np.isin(in_degrees, neighborhood[out_degrees[i]]) 176 | neighborhood_mask[i, neighborhood_indicator] = True 177 | return mask * 
neighborhood_mask, in_degrees, out_degrees 178 | 179 | 180 | def checkerboard(shape): 181 | return (np.indices(shape).sum(0) % 2).astype(np.float32) 182 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 -------------------------------------------------------------------------------- /train_variational_autoencoder_jax.py: -------------------------------------------------------------------------------- 1 | """Train variational autoencoder or binary MNIST data. 2 | 3 | Largely follows https://github.com/deepmind/dm-haiku/blob/master/examples/vae.py""" 4 | 5 | import time 6 | import argparse 7 | from typing import Generator, Mapping, Sequence, Tuple, Optional 8 | 9 | import numpy as np 10 | import jax 11 | from jax import lax 12 | import haiku as hk 13 | import jax.numpy as jnp 14 | import optax 15 | import tensorflow_datasets as tfds 16 | from tensorflow_probability.substrates import jax as tfp 17 | 18 | import masks 19 | 20 | tfd = tfp.distributions 21 | tfb = tfp.bijectors 22 | 23 | Batch = Mapping[str, np.ndarray] 24 | MNIST_IMAGE_SHAPE: Sequence[int] = (28, 28, 1) 25 | PRNGKey = jnp.ndarray 26 | 27 | 28 | def add_args(parser): 29 | parser.add_argument("--variational", choices=["flow", "mean-field"]) 30 | parser.add_argument("--latent_size", type=int, default=128) 31 | parser.add_argument("--hidden_size", type=int, default=512) 32 | parser.add_argument("--learning_rate", type=float, default=0.001) 33 | parser.add_argument("--batch_size", type=int, default=128) 34 | parser.add_argument("--training_steps", type=int, default=30000) 35 | parser.add_argument("--log_interval", type=int, default=10000) 36 | parser.add_argument("--num_importance_samples", type=int, default=1000) 37 | parser.add_argument("--random_seed", type=int, default=42) 38 | 39 | 40 | def load_dataset( 41 | split: str, batch_size: int, seed: int, repeat: bool = False 42 | ) -> Generator[Batch, None, None]: 43 | ds = tfds.load( 44 | "binarized_mnist", 45 | split=split, 46 | shuffle_files=True, 47 | read_config=tfds.ReadConfig(shuffle_seed=seed), 48 | ) 49 | ds = ds.shuffle(buffer_size=10 * batch_size, seed=seed) 50 | ds = ds.batch(batch_size) 51 | ds = ds.prefetch(buffer_size=5) 52 | if repeat: 53 | ds = ds.repeat() 54 | return iter(tfds.as_numpy(ds)) 55 | 56 | 57 | class Model(hk.Module): 58 | """Deep latent Gaussian model or variational autoencoder.""" 59 | 60 | def __init__( 61 | self, 62 | latent_size: int, 63 | hidden_size: int, 64 | output_shape: Sequence[int] = MNIST_IMAGE_SHAPE, 65 | ): 66 | super().__init__(name="model") 67 | self._latent_size = latent_size 68 | self._hidden_size = hidden_size 69 | self._output_shape = output_shape 70 | self.generative_network = hk.Sequential( 71 | [ 72 | hk.Linear(self._hidden_size), 73 | jax.nn.relu, 74 | hk.Linear(self._hidden_size), 75 | jax.nn.relu, 76 | hk.Linear(np.prod(self._output_shape)), 77 | hk.Reshape(self._output_shape, preserve_dims=2), 78 | ] 79 | ) 80 | 81 | def __call__(self, x: jnp.ndarray, z: jnp.ndarray) -> jnp.ndarray: 82 | """Compute log probability""" 83 | p_z = tfd.Normal( 84 | loc=jnp.zeros(self._latent_size, dtype=jnp.float32), 85 | scale=jnp.ones(self._latent_size, dtype=jnp.float32), 86 | ) 87 | # sum over latent dimensions 88 | log_p_z = p_z.log_prob(z).sum(-1) 89 | logits = self.generative_network(z) 90 | p_x_given_z = tfd.Bernoulli(logits=logits) 91 | # sum over last three image dimensions 
(width, height, channels) 92 | log_p_x_given_z = p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1)) 93 | return log_p_z + log_p_x_given_z 94 | 95 | 96 | class VariationalMeanField(hk.Module): 97 | """Mean field variational distribution q(z | x) parameterized by inference network.""" 98 | 99 | def __init__(self, latent_size: int, hidden_size: int): 100 | super().__init__(name="variational") 101 | self._latent_size = latent_size 102 | self._hidden_size = hidden_size 103 | self.inference_network = hk.Sequential( 104 | [ 105 | hk.Flatten(), 106 | hk.Linear(self._hidden_size), 107 | jax.nn.relu, 108 | hk.Linear(self._hidden_size), 109 | jax.nn.relu, 110 | hk.Linear(self._latent_size * 2), 111 | ] 112 | ) 113 | 114 | def condition(self, inputs): 115 | """Compute parameters of a multivariate independent Normal distribution based on the inputs.""" 116 | out = self.inference_network(inputs) 117 | loc, scale_arg = jnp.split(out, 2, axis=-1) 118 | scale = jax.nn.softplus(scale_arg) 119 | return loc, scale 120 | 121 | def __call__( 122 | self, x: jnp.ndarray, num_samples: int 123 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 124 | """Compute sample and log probability""" 125 | loc, scale = self.condition(x) 126 | # IMPORTANT: need to check in source code that reparameterization_type=tfd.FULLY_REPARAMETERIZED for this class 127 | q_z = tfd.Normal(loc=loc, scale=scale) 128 | z = q_z.sample(sample_shape=[num_samples], seed=hk.next_rng_key()) 129 | # sum over latent dimension 130 | log_q_z = q_z.log_prob(z).sum(-1) 131 | return z, log_q_z 132 | 133 | 134 | class VariationalFlow(hk.Module): 135 | """Uses masked autoregressive networks and a shift scale transform. 136 | 137 | Follows Algorithm 1 from the Inverse Autoregressive Flow paper, Kingma et al. (2016) https://arxiv.org/abs/1606.04934. 138 | """ 139 | 140 | def __init__(self, latent_size: int, hidden_size: int): 141 | super().__init__(name="variational") 142 | self.encoder = hk.Sequential( 143 | [ 144 | hk.Flatten(), 145 | hk.Linear(hidden_size), 146 | jax.nn.relu, 147 | hk.Linear(hidden_size), 148 | jax.nn.relu, 149 | hk.Linear(latent_size * 3, w_init=jnp.zeros, b_init=jnp.zeros), 150 | ] 151 | ) 152 | self.first_block = InverseAutoregressiveFlow(latent_size, hidden_size) 153 | self.second_block = InverseAutoregressiveFlow(latent_size, hidden_size) 154 | 155 | def __call__( 156 | self, x: jnp.ndarray, num_samples: int 157 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 158 | """Compute sample and log probability.""" 159 | loc, scale_arg, h = jnp.split(self.encoder(x), 3, axis=-1) 160 | q_z0 = tfd.Normal(loc=loc, scale=jax.nn.softplus(scale_arg)) 161 | z0 = q_z0.sample(sample_shape=[num_samples], seed=hk.next_rng_key()) 162 | h = jnp.expand_dims(h, axis=0) # needed for the new sample dimension in z0 163 | log_q_z0 = q_z0.log_prob(z0).sum(-1) 164 | z1, log_det_q_z1 = self.first_block(z0, context=h) 165 | z2, log_det_q_z2 = self.second_block(z1, context=h) 166 | return z2, log_q_z0 + log_det_q_z1 + log_det_q_z2 167 | 168 | 169 | class MaskedLinear(hk.Module): 170 | """Masked Linear module. 171 | 172 | TODO: fix initialization according to number of inputs per unit 173 | (can compute this from the mask). 
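    One possible approach (a sketch only, not implemented here): derive the per-unit
    fan-in from the mask, which has shape [input_size, output_size] as used in __call__,
    and scale the truncated-normal initializer with it, e.g.

        fan_in = jnp.maximum(jnp.sum(mask, axis=0), 1.0)
        w_init = hk.initializers.TruncatedNormal(stddev=1.0 / jnp.sqrt(fan_in))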
174 | """ 175 | 176 | def __init__( 177 | self, 178 | mask: jnp.ndarray, 179 | output_size: int, 180 | with_bias: bool = True, 181 | w_init: Optional[hk.initializers.Initializer] = None, 182 | b_init: Optional[hk.initializers.Initializer] = None, 183 | name: Optional[str] = None, 184 | ): 185 | super().__init__(name=name) 186 | self.input_size = None 187 | self.output_size = output_size 188 | self.with_bias = with_bias 189 | self.w_init = w_init 190 | self.b_init = b_init or jnp.zeros 191 | self._mask = mask 192 | 193 | def __call__( 194 | self, 195 | inputs: jnp.ndarray, 196 | *, 197 | precision: Optional[lax.Precision] = None, 198 | ) -> jnp.ndarray: 199 | """Computes a masked linear transform of the input.""" 200 | if not inputs.shape: 201 | raise ValueError("Input must not be scalar.") 202 | 203 | input_size = self.input_size = inputs.shape[-1] 204 | output_size = self.output_size 205 | dtype = inputs.dtype 206 | 207 | w_init = self.w_init 208 | if w_init is None: 209 | stddev = 1.0 / np.sqrt(self.input_size) 210 | w_init = hk.initializers.TruncatedNormal(stddev=stddev) 211 | w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init) 212 | 213 | out = jnp.dot(inputs, w * self._mask, precision=precision) 214 | 215 | if self.with_bias: 216 | b = hk.get_parameter("b", [self.output_size], dtype, init=self.b_init) 217 | b = jnp.broadcast_to(b, out.shape) 218 | out = out + b 219 | 220 | return out 221 | 222 | 223 | class MaskedAndConditionalLinear(hk.Module): 224 | """Assumes the conditional inputs have same size as inputs.""" 225 | 226 | def __init__(self, mask: jnp.ndarray, output_size: int, **kwargs): 227 | super().__init__() 228 | self.masked_linear = MaskedLinear(mask, output_size, **kwargs) 229 | self.conditional_linear = hk.Linear(output_size, with_bias=False, **kwargs) 230 | 231 | def __call__( 232 | self, inputs: jnp.ndarray, conditional_inputs: jnp.ndarray 233 | ) -> jnp.ndarray: 234 | return self.masked_linear(inputs) + self.conditional_linear(conditional_inputs) 235 | 236 | 237 | class MADE(hk.Module): 238 | """Masked Autoregressive Distribution Estimator. 239 | 240 | From https://arxiv.org/abs/1502.03509 241 | 242 | conditional_input specifies whether every layer of the network will be 243 | conditioned on an additional input. 
244 | The additional input is conditioned on using a linear transformation 245 | (that does not use a mask) 246 | """ 247 | 248 | def __init__(self, input_size: int, hidden_size: int, num_outputs_per_input: int): 249 | super().__init__() 250 | self._num_outputs_per_input = num_outputs_per_input 251 | degrees = masks.create_degrees( 252 | input_size=input_size, 253 | hidden_units=[hidden_size] * 2, 254 | input_order="left-to-right", 255 | hidden_degrees="equal", 256 | ) 257 | self._masks = masks.create_masks(degrees) 258 | self._masks[-1] = np.hstack( 259 | [self._masks[-1] for _ in range(num_outputs_per_input)] 260 | ) 261 | self._input_size = input_size 262 | self._first_net = MaskedAndConditionalLinear(self._masks[0], hidden_size) 263 | self._second_net = MaskedAndConditionalLinear(self._masks[1], hidden_size) 264 | # multiply by two for the shift and log scale 265 | # initialize weights and biases to zero to init to the identity function 266 | self._final_net = MaskedAndConditionalLinear( 267 | self._masks[2], 268 | input_size * num_outputs_per_input, 269 | w_init=jnp.zeros, 270 | b_init=jnp.zeros, 271 | ) 272 | 273 | def __call__(self, inputs, conditional_inputs): 274 | outputs = jax.nn.relu(self._first_net(inputs, conditional_inputs)) 275 | outputs = outputs[::-1] # reverse 276 | outputs = jax.nn.relu(self._second_net(outputs, conditional_inputs)) 277 | outputs = outputs[::-1] # reverse 278 | outputs = self._final_net(outputs, conditional_inputs) 279 | return jnp.split(outputs, self._num_outputs_per_input, axis=-1) 280 | 281 | 282 | class InverseAutoregressiveFlow(hk.Module): 283 | def __init__(self, latent_size: int, hidden_size: int): 284 | super().__init__() 285 | # two outputs per latent input: shift and log scale parameter 286 | self._made = MADE( 287 | input_size=latent_size, hidden_size=hidden_size, num_outputs_per_input=2 288 | ) 289 | 290 | def __call__(self, inputs: jnp.ndarray, context: jnp.ndarray): 291 | m, s = self._made(inputs, conditional_inputs=context) 292 | # initialize sigmoid argument bias so the output is close to 1 293 | sigmoid = jax.nn.sigmoid(s + 2.0) 294 | z = sigmoid * inputs + (1 - sigmoid) * m 295 | return z, -jax.nn.log_sigmoid(s).sum(-1) 296 | 297 | 298 | def main(): 299 | start_time = time.time() 300 | parser = argparse.ArgumentParser() 301 | add_args(parser) 302 | args = parser.parse_args() 303 | print(args) 304 | print("Is jax using @jit decorators?", not jax.config.read("jax_disable_jit")) 305 | rng_seq = hk.PRNGSequence(args.random_seed) 306 | p_log_prob = hk.transform( 307 | lambda x, z: Model(args.latent_size, args.hidden_size, MNIST_IMAGE_SHAPE)( 308 | x=x, z=z 309 | ) 310 | ) 311 | if args.variational == "mean-field": 312 | variational = VariationalMeanField 313 | elif args.variational == "flow": 314 | variational = VariationalFlow 315 | q_sample_and_log_prob = hk.transform( 316 | lambda x, num_samples: variational(args.latent_size, args.hidden_size)( 317 | x, num_samples 318 | ) 319 | ) 320 | p_params = p_log_prob.init( 321 | next(rng_seq), 322 | z=np.zeros((1, args.latent_size), dtype=np.float32), 323 | x=np.zeros((1, *MNIST_IMAGE_SHAPE), dtype=np.float32), 324 | ) 325 | q_params = q_sample_and_log_prob.init( 326 | next(rng_seq), 327 | x=np.zeros((1, *MNIST_IMAGE_SHAPE), dtype=np.float32), 328 | num_samples=1, 329 | ) 330 | optimizer = optax.rmsprop(args.learning_rate) 331 | params = (p_params, q_params) 332 | opt_state = optimizer.init(params) 333 | 334 | @jax.jit 335 | def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> 
jnp.ndarray: 336 | """Objective function is negative ELBO.""" 337 | x = batch["image"] 338 | p_params, q_params = params 339 | z, log_q_z = q_sample_and_log_prob.apply(q_params, rng_key, x=x, num_samples=1) 340 | log_p_x_z = p_log_prob.apply(p_params, rng_key, x=x, z=z) 341 | elbo = log_p_x_z - log_q_z 342 | # average elbo over number of samples 343 | elbo = elbo.mean(axis=0) 344 | # sum elbo over batch 345 | elbo = elbo.sum(axis=0) 346 | return -elbo 347 | 348 | @jax.jit 349 | def train_step( 350 | params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch 351 | ) -> Tuple[hk.Params, optax.OptState]: 352 | """Single update step to maximize the ELBO.""" 353 | grads = jax.grad(objective_fn)(params, rng_key, batch) 354 | updates, new_opt_state = optimizer.update(grads, opt_state) 355 | new_params = optax.apply_updates(params, updates) 356 | return new_params, new_opt_state 357 | 358 | @jax.jit 359 | def importance_weighted_estimate( 360 | params: hk.Params, rng_key: PRNGKey, batch: Batch 361 | ) -> Tuple[jnp.ndarray, jnp.ndarray]: 362 | """Estimate marginal log p(x) using importance sampling.""" 363 | x = batch["image"] 364 | p_params, q_params = params 365 | z, log_q_z = q_sample_and_log_prob.apply( 366 | q_params, rng_key, x=x, num_samples=args.num_importance_samples 367 | ) 368 | log_p_x_z = p_log_prob.apply(p_params, rng_key, x, z) 369 | elbo = log_p_x_z - log_q_z 370 | # importance sampling of approximate marginal likelihood with q(z) 371 | # as the proposal, and logsumexp in the sample dimension 372 | log_p_x = jax.nn.logsumexp(elbo, axis=0) - jnp.log(jnp.shape(elbo)[0]) 373 | # sum over the elements of the minibatch 374 | log_p_x = log_p_x.sum(0) 375 | # average elbo over number of samples 376 | elbo = elbo.mean(axis=0) 377 | # sum elbo over batch 378 | elbo = elbo.sum(axis=0) 379 | return elbo, log_p_x 380 | 381 | def evaluate( 382 | dataset: Generator[Batch, None, None], 383 | params: hk.Params, 384 | rng_seq: hk.PRNGSequence, 385 | ) -> Tuple[float, float]: 386 | total_elbo = 0.0 387 | total_log_p_x = 0.0 388 | dataset_size = 0 389 | for batch in dataset: 390 | elbo, log_p_x = importance_weighted_estimate(params, next(rng_seq), batch) 391 | total_elbo += elbo 392 | total_log_p_x += log_p_x 393 | dataset_size += len(batch["image"]) 394 | return total_elbo / dataset_size, total_log_p_x / dataset_size 395 | 396 | train_ds = load_dataset( 397 | tfds.Split.TRAIN, args.batch_size, args.random_seed, repeat=True 398 | ) 399 | test_ds = load_dataset(tfds.Split.TEST, args.batch_size, args.random_seed) 400 | 401 | def print_progress(step: int, examples_per_sec: float): 402 | valid_ds = load_dataset( 403 | tfds.Split.VALIDATION, args.batch_size, args.random_seed 404 | ) 405 | elbo, log_p_x = evaluate(valid_ds, params, rng_seq) 406 | train_elbo = ( 407 | -objective_fn(params, next(rng_seq), next(train_ds)) / args.batch_size 408 | ) 409 | print( 410 | f"Step {step:<10d}\t" 411 | f"Train ELBO estimate: {train_elbo:<5.3f}\t" 412 | f"Validation ELBO estimate: {elbo:<5.3f}\t" 413 | f"Validation log p(x) estimate: {log_p_x:<5.3f}\t" 414 | f"Speed: {examples_per_sec:<5.2e} examples/s" 415 | ) 416 | 417 | t0 = time.time() 418 | for step in range(args.training_steps): 419 | if step % args.log_interval == 0: 420 | t1 = time.time() 421 | examples_per_sec = args.log_interval * args.batch_size / (t1 - t0) 422 | print_progress(step, examples_per_sec) 423 | t0 = t1 424 | params, opt_state = train_step(params, next(rng_seq), opt_state, next(train_ds)) 425 | 426 | test_ds = 
load_dataset(tfds.Split.TEST, args.batch_size, args.random_seed) 427 | elbo, log_p_x = evaluate(test_ds, params, rng_seq) 428 | print( 429 | f"Step {step:<10d}\t" 430 | f"Test ELBO estimate: {elbo:<5.3f}\t" 431 | f"Test log p(x) estimate: {log_p_x:<5.3f}\t" 432 | ) 433 | print(f"Total time: {(time.time() - start_time) / 60:.3f} minutes") 434 | 435 | 436 | if __name__ == "__main__": 437 | main() 438 | -------------------------------------------------------------------------------- /train_variational_autoencoder_pytorch.py: -------------------------------------------------------------------------------- 1 | """Train variational autoencoder on binary MNIST data.""" 2 | 3 | import numpy as np 4 | import random 5 | import time 6 | 7 | import torch 8 | import torch.utils 9 | import torch.utils.data 10 | from torch import nn 11 | 12 | import data 13 | import flow 14 | import argparse 15 | import pathlib 16 | 17 | 18 | def add_args(parser): 19 | parser.add_argument("--latent_size", type=int, default=128) 20 | parser.add_argument("--variational", choices=["flow", "mean-field"]) 21 | parser.add_argument("--flow_depth", type=int, default=2) 22 | parser.add_argument("--data_size", type=int, default=784) 23 | parser.add_argument("--learning_rate", type=float, default=0.001) 24 | parser.add_argument("--batch_size", type=int, default=128) 25 | parser.add_argument("--test_batch_size", type=int, default=512) 26 | parser.add_argument("--max_iterations", type=int, default=30000) 27 | parser.add_argument("--log_interval", type=int, default=10000) 28 | parser.add_argument("--n_samples", type=int, default=1000) 29 | parser.add_argument("--use_gpu", action="store_true") 30 | parser.add_argument("--seed", type=int, default=582838) 31 | parser.add_argument("--train_dir", type=pathlib.Path, default="/tmp") 32 | parser.add_argument("--data_dir", type=pathlib.Path, default="/tmp") 33 | 34 | 35 | class Model(nn.Module): 36 | """Variational autoencoder, parameterized by a generative network.""" 37 | 38 | def __init__(self, latent_size, data_size): 39 | super().__init__() 40 | self.register_buffer("p_z_loc", torch.zeros(latent_size)) 41 | self.register_buffer("p_z_scale", torch.ones(latent_size)) 42 | self.log_p_z = NormalLogProb() 43 | self.log_p_x = BernoulliLogProb() 44 | self.generative_network = NeuralNetwork( 45 | input_size=latent_size, output_size=data_size, hidden_size=latent_size * 2 46 | ) 47 | 48 | def forward(self, z, x): 49 | """Return log probability of model.""" 50 | log_p_z = self.log_p_z(self.p_z_loc, self.p_z_scale, z).sum(-1, keepdim=True) 51 | logits = self.generative_network(z) 52 | # unsqueeze sample dimension 53 | logits, x = torch.broadcast_tensors(logits, x.unsqueeze(1)) 54 | log_p_x = self.log_p_x(logits, x).sum(-1, keepdim=True) 55 | return log_p_z + log_p_x 56 | 57 | 58 | class VariationalMeanField(nn.Module): 59 | """Approximate posterior parameterized by an inference network.""" 60 | 61 | def __init__(self, latent_size, data_size): 62 | super().__init__() 63 | self.inference_network = NeuralNetwork( 64 | input_size=data_size, 65 | output_size=latent_size * 2, 66 | hidden_size=latent_size * 2, 67 | ) 68 | self.log_q_z = NormalLogProb() 69 | self.softplus = nn.Softplus() 70 | 71 | def forward(self, x, n_samples=1): 72 | """Return sample of latent variable and log prob.""" 73 | loc, scale_arg = torch.chunk( 74 | self.inference_network(x).unsqueeze(1), chunks=2, dim=-1 75 | ) 76 | scale = self.softplus(scale_arg) 77 | eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), 
device=loc.device) 78 | z = loc + scale * eps # reparameterization 79 | log_q_z = self.log_q_z(loc, scale, z).sum(-1, keepdim=True) 80 | return z, log_q_z 81 | 82 | 83 | class VariationalFlow(nn.Module): 84 | """Approximate posterior parameterized by a flow (https://arxiv.org/abs/1606.04934).""" 85 | 86 | def __init__(self, latent_size, data_size, flow_depth): 87 | super().__init__() 88 | hidden_size = latent_size * 2 89 | self.inference_network = NeuralNetwork( 90 | input_size=data_size, 91 | # loc, scale, and context 92 | output_size=latent_size * 3, 93 | hidden_size=hidden_size, 94 | ) 95 | modules = [] 96 | for _ in range(flow_depth): 97 | modules.append( 98 | flow.InverseAutoregressiveFlow( 99 | num_input=latent_size, 100 | num_hidden=hidden_size, 101 | num_context=latent_size, 102 | ) 103 | ) 104 | modules.append(flow.Reverse(latent_size)) 105 | self.q_z_flow = flow.FlowSequential(*modules) 106 | self.log_q_z_0 = NormalLogProb() 107 | self.softplus = nn.Softplus() 108 | 109 | def forward(self, x, n_samples=1): 110 | """Return sample of latent variable and log prob.""" 111 | loc, scale_arg, h = torch.chunk( 112 | self.inference_network(x).unsqueeze(1), chunks=3, dim=-1 113 | ) 114 | scale = self.softplus(scale_arg) 115 | eps = torch.randn((loc.shape[0], n_samples, loc.shape[-1]), device=loc.device) 116 | z_0 = loc + scale * eps # reparameterization 117 | log_q_z_0 = self.log_q_z_0(loc, scale, z_0) 118 | z_T, log_q_z_flow = self.q_z_flow(z_0, context=h) 119 | log_q_z = (log_q_z_0 + log_q_z_flow).sum(-1, keepdim=True) 120 | return z_T, log_q_z 121 | 122 | 123 | class NeuralNetwork(nn.Module): 124 | def __init__(self, input_size, output_size, hidden_size): 125 | super().__init__() 126 | modules = [ 127 | nn.Linear(input_size, hidden_size), 128 | nn.ReLU(), 129 | nn.Linear(hidden_size, hidden_size), 130 | nn.ReLU(), 131 | nn.Linear(hidden_size, output_size), 132 | ] 133 | self.net = nn.Sequential(*modules) 134 | 135 | def forward(self, input): 136 | return self.net(input) 137 | 138 | 139 | class NormalLogProb(nn.Module): 140 | def __init__(self): 141 | super().__init__() 142 | 143 | def forward(self, loc, scale, z): 144 | var = torch.pow(scale, 2) 145 | return -0.5 * torch.log(2 * np.pi * var) - torch.pow(z - loc, 2) / (2 * var) 146 | 147 | 148 | class BernoulliLogProb(nn.Module): 149 | def __init__(self): 150 | super().__init__() 151 | self.bce_with_logits = nn.BCEWithLogitsLoss(reduction="none") 152 | 153 | def forward(self, logits, target): 154 | # bernoulli log prob is equivalent to negative binary cross entropy 155 | return -self.bce_with_logits(logits, target) 156 | 157 | 158 | def cycle(iterable): 159 | while True: 160 | for x in iterable: 161 | yield x 162 | 163 | 164 | @torch.no_grad() 165 | def evaluate(n_samples, model, variational, eval_data): 166 | model.eval() 167 | total_log_p_x = 0.0 168 | total_elbo = 0.0 169 | for batch in eval_data: 170 | x = batch[0].to(next(model.parameters()).device) 171 | z, log_q_z = variational(x, n_samples) 172 | log_p_x_and_z = model(z, x) 173 | # importance sampling of approximate marginal likelihood with q(z) 174 | # as the proposal, and logsumexp in the sample dimension 175 | elbo = log_p_x_and_z - log_q_z 176 | log_p_x = torch.logsumexp(elbo, dim=1) - np.log(n_samples) 177 | # average over sample dimension, sum over minibatch 178 | total_elbo += elbo.cpu().numpy().mean(1).sum() 179 | # sum over minibatch 180 | total_log_p_x += log_p_x.cpu().numpy().sum() 181 | n_data = len(eval_data.dataset) 182 | return total_elbo / n_data, total_log_p_x / 
n_data 183 | 184 | 185 | if __name__ == "__main__": 186 | start_time = time.time() 187 | parser = argparse.ArgumentParser() 188 | add_args(parser) 189 | cfg = parser.parse_args() 190 | 191 | device = torch.device("cuda:0" if cfg.use_gpu else "cpu") 192 | torch.manual_seed(cfg.seed) 193 | np.random.seed(cfg.seed) 194 | random.seed(cfg.seed) 195 | 196 | model = Model(latent_size=cfg.latent_size, data_size=cfg.data_size) 197 | if cfg.variational == "flow": 198 | variational = VariationalFlow( 199 | latent_size=cfg.latent_size, 200 | data_size=cfg.data_size, 201 | flow_depth=cfg.flow_depth, 202 | ) 203 | elif cfg.variational == "mean-field": 204 | variational = VariationalMeanField( 205 | latent_size=cfg.latent_size, data_size=cfg.data_size 206 | ) 207 | else: 208 | raise ValueError( 209 | "Variational distribution not implemented: %s" % cfg.variational 210 | ) 211 | 212 | model.to(device) 213 | variational.to(device) 214 | 215 | optimizer = torch.optim.RMSprop( 216 | list(model.parameters()) + list(variational.parameters()), 217 | lr=cfg.learning_rate, 218 | centered=True, 219 | ) 220 | 221 | fname = cfg.data_dir / "binary_mnist.h5" 222 | if not fname.exists(): 223 | print("Downloading binary MNIST data...") 224 | data.download_binary_mnist(fname) 225 | train_data, valid_data, test_data = data.load_binary_mnist( 226 | fname, cfg.batch_size, cfg.test_batch_size, cfg.use_gpu 227 | ) 228 | 229 | best_valid_elbo = -np.inf 230 | num_no_improvement = 0 231 | train_ds = cycle(train_data) 232 | t0 = time.time() 233 | 234 | for step in range(cfg.max_iterations): 235 | batch = next(train_ds) 236 | x = batch[0].to(device) 237 | model.zero_grad() 238 | variational.zero_grad() 239 | z, log_q_z = variational(x, n_samples=1) 240 | log_p_x_and_z = model(z, x) 241 | # average over sample dimension 242 | elbo = (log_p_x_and_z - log_q_z).mean(1) 243 | # sum over batch dimension 244 | loss = -elbo.sum(0) 245 | loss.backward() 246 | optimizer.step() 247 | 248 | if step % cfg.log_interval == 0: 249 | t1 = time.time() 250 | examples_per_sec = cfg.log_interval * cfg.batch_size / (t1 - t0) 251 | with torch.no_grad(): 252 | valid_elbo, valid_log_p_x = evaluate( 253 | cfg.n_samples, model, variational, valid_data 254 | ) 255 | print( 256 | f"Step {step:<10d}\t" 257 | f"Train ELBO estimate: {elbo.detach().cpu().numpy().mean():<5.3f}\t" 258 | f"Validation ELBO estimate: {valid_elbo:<5.3f}\t" 259 | f"Validation log p(x) estimate: {valid_log_p_x:<5.3f}\t" 260 | f"Speed: {examples_per_sec:<5.2e} examples/s" 261 | ) 262 | if valid_elbo > best_valid_elbo: 263 | num_no_improvement = 0 264 | best_valid_elbo = valid_elbo 265 | states = { 266 | "model": model.state_dict(), 267 | "variational": variational.state_dict(), 268 | } 269 | torch.save(states, cfg.train_dir / "best_state_dict") 270 | t0 = t1 271 | 272 | checkpoint = torch.load(cfg.train_dir / "best_state_dict") 273 | model.load_state_dict(checkpoint["model"]) 274 | variational.load_state_dict(checkpoint["variational"]) 275 | test_elbo, test_log_p_x = evaluate(cfg.n_samples, model, variational, test_data) 276 | print( 277 | f"Step {step:<10d}\t" 278 | f"Test ELBO estimate: {test_elbo:<5.3f}\t" 279 | f"Test log p(x) estimate: {test_log_p_x:<5.3f}\t" 280 | ) 281 | 282 | print(f"Total time: {(time.time() - start_time) / 60:.2f} minutes") 283 | -------------------------------------------------------------------------------- /train_variational_autoencoder_tensorflow.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 
import numpy as np 3 | import os 4 | import tensorflow as tf 5 | import tensorflow.keras as tfk 6 | import tensorflow.contrib.slim as slim 7 | import time 8 | import tensorflow_datasets as tfds 9 | import tensorflow_probability as tfp 10 | from imageio import imwrite 11 | from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets 12 | tfkl = tfk.layers 13 | tfc = tf.compat.v1 14 | 15 | flags = tf.app.flags 16 | flags.DEFINE_string('data_dir', '/tmp/dat/', 'Directory for data') 17 | flags.DEFINE_string('logdir', '/tmp/log/', 'Directory for logs') 18 | flags.DEFINE_integer('latent_dim', 100, 'Latent dimensionality of model') 19 | flags.DEFINE_integer('batch_size', 64, 'Minibatch size') 20 | flags.DEFINE_integer('n_samples', 1, 'Number of samples to save') 21 | flags.DEFINE_integer('print_every', 1000, 'Print every n iterations') 22 | flags.DEFINE_integer('hidden_size', 200, 'Hidden size for neural networks') 23 | flags.DEFINE_integer('n_iterations', 100000, 'number of iterations') 24 | 25 | FLAGS = flags.FLAGS 26 | 27 | 28 | def inference_network(x, latent_dim, hidden_size): 29 | """Construct an inference network parametrizing a Gaussian. 30 | 31 | Args: 32 | x: A batch of MNIST digits. 33 | latent_dim: The latent dimensionality. 34 | hidden_size: The size of the neural net hidden layers. 35 | 36 | Returns: 37 | mu: Mean parameters for the variational family Normal 38 | sigma: Standard deviation parameters for the variational family Normal 39 | """ 40 | inference_net = tfk.Sequential([ 41 | tfkl.Flatten(), 42 | tfkl.Dense(hidden_size, activation=tf.nn.relu), 43 | tfkl.Dense(hidden_size, activation=tf.nn.relu), 44 | tfkl.Dense(latent_dim * 2, activation=None) 45 | ]) 46 | gaussian_params = inference_net(x) 47 | # The mean parameter is unconstrained 48 | mu = gaussian_params[:, :latent_dim] 49 | # The standard deviation must be positive. 
Parametrize with a softplus 50 | sigma = tf.nn.softplus(gaussian_params[:, latent_dim:]) 51 | return mu, sigma 52 | 53 | 54 | def generative_network(z, hidden_size): 55 | """Build a generative network parametrizing the likelihood of the data 56 | 57 | Args: 58 | z: Samples of latent variables 59 | hidden_size: Size of the hidden state of the neural net 60 | 61 | Returns: 62 | bernoulli_logits: logits for the Bernoulli likelihood of the data 63 | """ 64 | generative_net = tfk.Sequential([ 65 | tfkl.Dense(hidden_size, activation=tf.nn.relu), 66 | tfkl.Dense(hidden_size, activation=tf.nn.relu), 67 | tfkl.Dense(28 * 28, activation=None) 68 | ]) 69 | bernoulli_logits = generative_net(z) 70 | return tf.reshape(bernoulli_logits, [-1, 28, 28, 1]) 71 | 72 | 73 | def train(): 74 | # Train a Variational Autoencoder on MNIST 75 | 76 | # Input placeholders 77 | with tf.name_scope('data'): 78 | x = tfc.placeholder(tf.float32, [None, 28, 28, 1]) 79 | tfc.summary.image('data', x) 80 | 81 | with tfc.variable_scope('variational'): 82 | q_mu, q_sigma = inference_network(x=x, 83 | latent_dim=FLAGS.latent_dim, 84 | hidden_size=FLAGS.hidden_size) 85 | # The variational distribution is a Normal with mean and standard 86 | # deviation given by the inference network 87 | q_z = tfp.distributions.Normal(loc=q_mu, scale=q_sigma) 88 | assert q_z.reparameterization_type == tfp.distributions.FULLY_REPARAMETERIZED 89 | 90 | with tfc.variable_scope('model'): 91 | # The likelihood is Bernoulli-distributed with logits given by the 92 | # generative network 93 | p_x_given_z_logits = generative_network(z=q_z.sample(), 94 | hidden_size=FLAGS.hidden_size) 95 | p_x_given_z = tfp.distributions.Bernoulli(logits=p_x_given_z_logits) 96 | posterior_predictive_samples = p_x_given_z.sample() 97 | tfc.summary.image('posterior_predictive', 98 | tf.cast(posterior_predictive_samples, tf.float32)) 99 | 100 | # Take samples from the prior 101 | with tfc.variable_scope('model', reuse=True): 102 | p_z = tfp.distributions.Normal(loc=np.zeros(FLAGS.latent_dim, dtype=np.float32), 103 | scale=np.ones(FLAGS.latent_dim, dtype=np.float32)) 104 | p_z_sample = p_z.sample(FLAGS.n_samples) 105 | p_x_given_z_logits = generative_network(z=p_z_sample, 106 | hidden_size=FLAGS.hidden_size) 107 | prior_predictive = tfp.distributions.Bernoulli(logits=p_x_given_z_logits) 108 | prior_predictive_samples = prior_predictive.sample() 109 | tfc.summary.image('prior_predictive', 110 | tf.cast(prior_predictive_samples, tf.float32)) 111 | 112 | # Take samples from the prior with a placeholder 113 | with tfc.variable_scope('model', reuse=True): 114 | z_input = tf.placeholder(tf.float32, [None, FLAGS.latent_dim]) 115 | p_x_given_z_logits = generative_network(z=z_input, 116 | hidden_size=FLAGS.hidden_size) 117 | prior_predictive_inp = tfp.distributions.Bernoulli(logits=p_x_given_z_logits) 118 | prior_predictive_inp_sample = prior_predictive_inp.sample() 119 | 120 | # Build the evidence lower bound (ELBO) or the negative loss 121 | kl = tf.reduce_sum(tfp.distributions.kl_divergence(q_z, p_z), 1) 122 | expected_log_likelihood = tf.reduce_sum(p_x_given_z.log_prob(x), 123 | [1, 2, 3]) 124 | 125 | elbo = tf.reduce_sum(expected_log_likelihood - kl, 0) 126 | optimizer = tfc.train.RMSPropOptimizer(learning_rate=0.001) 127 | train_op = optimizer.minimize(-elbo) 128 | 129 | # Merge all the summaries 130 | summary_op = tfc.summary.merge_all() 131 | 132 | init_op = tfc.global_variables_initializer() 133 | 134 | # Run training 135 | sess = tfc.InteractiveSession() 136 | 
sess.run(init_op) 137 | 138 | mnist_data = tfds.load(name='binary_mnist', split='train', shuffle_files=False) 139 | dataset = mnist_data.repeat().shuffle(buffer_size=1024).batch(FLAGS.batch_size) 140 | 141 | print('Saving TensorBoard summaries and images to: %s' % FLAGS.logdir) 142 | train_writer = tfc.summary.FileWriter(FLAGS.logdir, sess.graph) 143 | 144 | t0 = time.time() 145 | for i, batch in enumerate(tfds.as_numpy(dataset)): 146 | np_x = batch['image'] 147 | sess.run(train_op, {x: np_x}) 148 | if i % FLAGS.print_every == 0: 149 | np_elbo, summary_str = sess.run([elbo, summary_op], {x: np_x}) 150 | train_writer.add_summary(summary_str, i) 151 | print('Iteration: {0:d} ELBO: {1:.3f} s/iter: {2:.3e}'.format( 152 | i, 153 | np_elbo / FLAGS.batch_size, 154 | (time.time() - t0) / FLAGS.print_every)) 155 | # Save samples 156 | np_posterior_samples, np_prior_samples = sess.run( 157 | [posterior_predictive_samples, prior_predictive_samples], {x: np_x}) 158 | for k in range(FLAGS.n_samples): 159 | f_name = os.path.join( 160 | FLAGS.logdir, 'iter_%d_posterior_predictive_%d_data.jpg' % (i, k)) 161 | imwrite(f_name, np_x[k, :, :, 0].astype(np.uint8)) 162 | f_name = os.path.join( 163 | FLAGS.logdir, 'iter_%d_posterior_predictive_%d_sample.jpg' % (i, k)) 164 | imwrite(f_name, np_posterior_samples[k, :, :, 0].astype(np.uint8)) 165 | f_name = os.path.join( 166 | FLAGS.logdir, 'iter_%d_prior_predictive_%d.jpg' % (i, k)) 167 | imwrite(f_name, np_prior_samples[k, :, :, 0].astype(np.uint8)) 168 | t0 = time.time() 169 | 170 | if __name__ == '__main__': 171 | train() 172 | --------------------------------------------------------------------------------