├── requirements.txt ├── src ├── util.py ├── jvicount.py ├── data.py ├── updater.py ├── jvi.py ├── elbo_updater.py ├── numfun.py ├── mnist.py ├── evaluate.py ├── train.py ├── ais.py └── model.py ├── LICENSE ├── .gitignore └── README.md /requirements.txt: -------------------------------------------------------------------------------- 1 | chainer==3.1.0 2 | numpy==1.11.0 3 | h5py==2.6.0 4 | cupy==2.1.0 5 | docopt==0.6.2 6 | scipy==1.0.0 7 | PyYAML==3.12 8 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT license. 4 | 5 | import numpy as np 6 | import subprocess 7 | from subprocess import PIPE 8 | 9 | from chainer import cuda 10 | from chainer.cuda import cupy 11 | 12 | def print_compute_graph(file, g): 13 | format = file.split('.')[-1] 14 | cmd = 'dot -T%s -o %s'%(format,file) 15 | p=subprocess.Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=PIPE, shell=True) 16 | p.stdin.write(g.dump()) 17 | p.communicate() 18 | return p.returncode 19 | 20 | -------------------------------------------------------------------------------- /src/jvicount.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT license. 4 | 5 | """Count number of JVI terms for a given sample size and JVI order. 6 | 7 | Usage: 8 | jvicount.py 9 | 10 | The sample size is n and JVI order zero corresponds to the IWAE bound. 11 | The tool returns the total number of terms in the JVI objective. 
12 | """ 13 | 14 | from docopt import docopt 15 | import jvi 16 | 17 | args = docopt(__doc__, version='jvicount 0.1') 18 | 19 | n = int(args['']) 20 | order = int(args['']) 21 | 22 | count = jvi.jvi_size(n,order) 23 | print count 24 | 25 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT license. 4 | 5 | import sys 6 | 7 | import mnist 8 | 9 | # The only function to access data sets and models for particular datasets. 10 | def prepare_dataset_and_model(dataset, nhidden, nlatent, fancy=False): 11 | if dataset == "mnist": 12 | print "Using dynamically binarized MNIST handwritten digit dataset" 13 | train, val, test = mnist.get_mnist_vae() 14 | din = len(train[0]) 15 | 16 | encoder = mnist.MNISTEncoder(din, nhidden, nlatent) 17 | decoder = mnist.MNISTDecoder(din, nhidden, nlatent) 18 | else: 19 | sys.exit("Unknown dataset name ('%s') supplied to option '-d'." % dataset) 20 | 21 | return train, val, test, encoder, decoder 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. All rights reserved. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /src/updater.py: -------------------------------------------------------------------------------- 1 | 2 | import chainer 3 | import chainer.functions as F 4 | from chainer import Variable 5 | 6 | import model 7 | 8 | class ISVAEUpdater(chainer.training.StandardUpdater): 9 | def __init__(self, num_zsamples=8, *args, **kwargs): 10 | self.encoder, self.decoder = kwargs.pop('models') 11 | self.num_zsamples = num_zsamples 12 | super(ISVAEUpdater, self).__init__(*args, **kwargs) 13 | self.elbo = model.ELBOObjective(self.encoder, self.decoder, self.num_zsamples) 14 | self.isobjective = model.ISObjective(self.encoder, self.decoder, self.num_zsamples) 15 | 16 | def encoder_objective(self, encoder, x): 17 | obj_elbo = self.elbo(x) 18 | chainer.report({'elbo': obj_elbo}, encoder) 19 | return obj_elbo 20 | 21 
| def decoder_objective(self, decoder, x): 22 | obj_is = self.isobjective(x) 23 | chainer.report({'is': obj_is}, decoder) 24 | return obj_is 25 | 26 | def update_core(self): 27 | batch = self.get_iterator('main').next() 28 | x = Variable(self.converter(batch, self.device)) 29 | xp = chainer.cuda.get_array_module(x.data) 30 | 31 | self.encoder.cleargrads() 32 | enc_optimizer = self.get_optimizer('encoder') 33 | enc_optimizer.update(self.encoder_objective, self.encoder, x) 34 | 35 | self.decoder.cleargrads() 36 | dec_optimizer = self.get_optimizer('decoder') 37 | dec_optimizer.update(self.decoder_objective, self.decoder, x) 38 | 39 | -------------------------------------------------------------------------------- /src/jvi.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT license. 4 | 5 | import math 6 | import numpy as np 7 | import scipy.special as sp 8 | from scipy import misc 9 | import functools 10 | import itertools 11 | 12 | # Compute the Sharot coefficients a_{j,k,n,r}. 13 | def sharot(k,n,r): 14 | p = float(r)/float(n) 15 | res = np.zeros(k+1) 16 | for j in xrange(0,k+1): 17 | c = sp.binom(k,j)*((1.0-j*p)**k) 18 | if (j % 2) == 1: 19 | c *= -1 20 | res[j] = c / ((p**k)*math.factorial(k)) 21 | return res 22 | 23 | def jvi_size(n,k): 24 | return sum([int(sp.binom(n,j)) for j in xrange(0,k+1)]) 25 | 26 | # Compute the JVI weighting vector and weighting matrix. 27 | # 28 | # Return A,B, where: 29 | # A: (M,) vector of weights, one for each of M sets. 30 | # B: (M,n) matrix; each row corresponds to a subset of n samples, 31 | # withing a row, the weights correspond to the weighted sum over samples. 
32 | def jvi_matrix(n,k): 33 | if n <= k: 34 | raise ValueError("JVI order must be smaller than number of samples.") 35 | 36 | sc = jvi_size(n,k) 37 | B = np.zeros((sc,n)) 38 | A = np.zeros(sc) 39 | SH = sharot(k,n,1) 40 | 41 | j = 0 42 | for setsize in xrange(0,k+1): 43 | for index_set in itertools.combinations(range(n), n-setsize): 44 | B[j,index_set] = 1.0/float(n-setsize) 45 | A[j] = SH[setsize] / sp.binom(n, n-setsize) 46 | j += 1 47 | 48 | return A, B 49 | -------------------------------------------------------------------------------- /src/elbo_updater.py: -------------------------------------------------------------------------------- 1 | 2 | import chainer 3 | import chainer.functions as F 4 | 5 | from chainer import Variable 6 | from chainer import training 7 | 8 | class ELBOUpdater(training.StandardUpdater): 9 | def __init__(self, *args, **kwargs): 10 | self.elbo, self.p_obj = kwargs.pop('models') 11 | self.encode = self.elbo.encode 12 | self.decode = self.p_obj.decode 13 | 14 | super(ELBOUpdater, self).__init__(*args, **kwargs) 15 | 16 | def compute_elbo(self, elbo, x): 17 | obj = elbo(x) 18 | chainer.report({'elbo': obj}, elbo) 19 | return obj 20 | 21 | def compute_pobj(self, p_obj, x): 22 | obj = p_obj(x) 23 | chainer.report({'obj': obj}, p_obj) 24 | return obj 25 | 26 | def update_core(self): 27 | elbo_optimizer = self.get_optimizer('elbo') 28 | pobj_optimizer = self.get_optimizer('p_obj') 29 | 30 | batch = self.get_iterator('main').next() 31 | x = Variable(self.converter(batch, self.device)) 32 | xp = chainer.cuda.get_array_module(x.data) 33 | 34 | elbo, p_obj = self.elbo, self.p_obj 35 | 36 | # Update q, hold p fixed 37 | self.encode.enable_update() 38 | self.decode.disable_update() 39 | elbo.cleargrads() 40 | elbo_optimizer.update(self.compute_elbo, elbo, x) 41 | 42 | # Fix q, update p 43 | self.encode.disable_update() 44 | self.decode.enable_update() 45 | p_obj.cleargrads() 46 | pobj_optimizer.update(self.compute_pobj, p_obj, x) 47 | 48 | 
-------------------------------------------------------------------------------- /src/numfun.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT license. 4 | 5 | import numpy 6 | 7 | from chainer import cuda 8 | from chainer import function 9 | from chainer import utils 10 | from chainer.utils import type_check 11 | 12 | 13 | class Log1mExp(function.Function): 14 | def check_type_forward(self, in_types): 15 | type_check.expect(in_types.size() == 1) 16 | x_type, = in_types 17 | 18 | type_check.expect(x_type.dtype.kind == 'f') 19 | 20 | def forward_cpu(self, inputs): 21 | x, = inputs 22 | # y = log(1 - exp(x)) 23 | y = numpy.log1p(-numpy.exp(x)) 24 | return utils.force_array(y, x.dtype), 25 | 26 | def forward_gpu(self, inputs): 27 | x, = inputs 28 | y = cuda.elementwise( 29 | 'T x', 'T y', 30 | ''' 31 | y = log1p(-exp(x)); 32 | ''', 33 | 'log1mexp_fwd' 34 | )(x) 35 | return y, 36 | 37 | def backward_cpu(self, inputs, grads): 38 | x, = inputs 39 | g, = grads 40 | gx = (-1 / (numpy.exp(-x) - 1)) * g 41 | return utils.force_array(gx, x.dtype), 42 | 43 | def backward_gpu(self, inputs, grads): 44 | x, = inputs 45 | g, = grads 46 | gx = cuda.elementwise( 47 | 'T x, T g', 'T gx', 48 | 'gx = - 1 / (exp(-x) - 1) * g', 49 | 'log1mexp_bwd' 50 | )(x, g) 51 | return gx, 52 | 53 | 54 | def log1mexp(x): 55 | return Log1mExp()(x) 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | 
.installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /src/mnist.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright (c) Microsoft Corporation. All rights reserved. 3 | # Licensed under the MIT license. 4 | 5 | import numpy as np 6 | import math 7 | import chainer 8 | import chainer.functions as F 9 | from chainer.functions.loss.vae import bernoulli_nll 10 | import chainer.links as L 11 | from chainer.datasets import mnist 12 | from chainer.datasets import TransformDataset 13 | 14 | import model 15 | 16 | # Dynamic binarization. x is an array with values 0 <= v <= 1. 
17 | def mnist_transform(x): 18 | x = np.copy(x) 19 | U = np.random.uniform(size=x.shape) 20 | # e.g. U = 0.6, x = 0.92 21 | x[U <= x] = 1.0 22 | x[U > x] = 0.0 23 | return x 24 | 25 | # Binarized MNIST with dynamic binarization, see 26 | # [Salakhutdinov and Murray, ICML 2008] 27 | def get_mnist_vae(): 28 | train, test = mnist.get_mnist(withlabel=False) 29 | val = train[50000:60000] 30 | train = train[0:50000] 31 | train = TransformDataset(train, mnist_transform) 32 | val = TransformDataset(val, mnist_transform) 33 | test = TransformDataset(test, mnist_transform) 34 | 35 | return train, val, test 36 | 37 | def bernoulli_logp_inst(x, h): 38 | """log B(x; p=sigmoid(h))""" 39 | L = bernoulli_nll(x, h, reduce='no') 40 | return -F.sum(L,axis=1) 41 | 42 | # MNIST encoder 43 | class MNISTEncoder(chainer.Chain): 44 | def __init__(self, dim_in, dim_hidden, dim_latent): 45 | super(MNISTEncoder, self).__init__( 46 | # encoder 47 | qlin0 = L.Linear(dim_in, dim_hidden), 48 | qlin1 = L.Linear(2*dim_hidden, dim_hidden), 49 | qlin_mu = L.Linear(2*dim_hidden, dim_latent), 50 | qlin_ln_var = L.Linear(2*dim_hidden, dim_latent), 51 | ) 52 | 53 | def __call__(self, x): 54 | h = F.crelu(self.qlin0(x)) 55 | h = F.crelu(self.qlin1(h)) 56 | qmu = self.qlin_mu(h) 57 | qln_var = self.qlin_ln_var(h) 58 | 59 | return qmu, qln_var 60 | 61 | class MNISTLikelihood: 62 | def __init__(self, ph): 63 | self.ph = ph 64 | 65 | def __call__(self, x): 66 | return bernoulli_logp_inst(x, self.ph) 67 | 68 | class MNISTDecoder(chainer.Chain): 69 | def __init__(self, dim_in, dim_hidden, dim_latent): 70 | super(MNISTDecoder, self).__init__( 71 | # decoder 72 | plin0 = L.Linear(dim_latent, dim_hidden), 73 | plin1 = L.Linear(2*dim_hidden, dim_hidden), 74 | plin2 = L.Linear(2*dim_hidden, dim_in), 75 | ) 76 | self.nz = dim_latent 77 | 78 | def __call__(self, z): 79 | h = F.crelu(self.plin0(z)) 80 | h = F.crelu(self.plin1(h)) 81 | ph = self.plin2(h) 82 | 83 | return MNISTLikelihood(ph) 84 | 85 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Jackknife Variational Inference, Python implementation 3 | 4 | This repository contains code related to the following 5 | [ICLR 2018](https://iclr.cc/Conferences/2018) paper: 6 | 7 | * _Sebastian Nowozin_, "Debiasing Evidence Approximations: On 8 | Importance-weighted Autoencoders and Jackknife Variational Inference", 9 | [Forum](https://openreview.net/forum?id=HyZoi-WRb), 10 | [PDF](https://openreview.net/pdf?id=HyZoi-WRb). 11 | 12 | 13 | ## Citation 14 | 15 | If you use this code or build upon it, please cite the following paper (BibTeX 16 | format): 17 | 18 | ``` 19 | @InProceedings{nowozin2018debiasing, 20 | title = "Debiasing Evidence Approximations: On Importance-weighted Autoencoders and Jackknife Variational Inference", 21 | author = "Sebastian Nowozin", 22 | booktitle = "International Conference on Learning Representations (ICLR 2018)", 23 | year = "2018" 24 | } 25 | ``` 26 | 27 | ## Installation 28 | 29 | Install the required Python 2 prerequisites by running: 30 | 31 | ``` 32 | pip install -r requirements.txt 33 | ``` 34 | 35 | Currently this installs: 36 | 37 | * [Chainer](http://chainer.org/), the deep learning framework, version 3.1.0 38 | * [CuPy](http://cupy.chainer.org/), a CUDA linear algebra framework compatible 39 | with NumPy, version 2.1.0 40 | * [NumPy](http://www.numpy.org/), numerical linear algebra for Python, version 1.11.0 41 | * [SciPy](http://www.scipy.org/), scientific computing framework for Python, version 1.0.0 42 | * [H5py](http://www.h5py.org/), an HDF5 interface for Python, version 2.6.0 43 | * [docopt](http://docopt.org/), Pythonic command line arguments parser, version 0.6.2 44 | * [PyYAML](https://github.com/yaml/pyyaml), Python library for 45 | [YAML](http://yaml.org) data language, version 3.12 46 | 47 | ## Running the MNIST experiment 48 | 49 | To train the MNIST
model from the paper, use the following parameters: 50 | 51 | ``` 52 | python ./train.py -g 0 -d mnist -e 1000 -b 2048 --opt adam \ 53 | --vae-type jvi --vae-samples 8 --jvi-order 1 --nhidden 300 --nlatent 40 \ 54 | -o modeloutput 55 | ``` 56 | 57 | Here the parameters are: 58 | 59 | * `-g 0`: train on GPU device 0 60 | * `-d mnist`: use the dynamically binarized MNIST data set 61 | * `-e 1000`: train for 1000 epochs 62 | * `-b 2048`: use a batch size of 2048 samples 63 | * `--opt adam`: use the Adam optimizer 64 | * `--vae-type jvi`: use _jackknife_ variational inference 65 | * `--vae-samples 8`: use eight Monte Carlo samples 66 | * `--jvi-order 1`: use first-order JVI bias correction 67 | * `--nhidden 300`: in each hidden layer use 300 hidden neurons 68 | * `--nlatent 40`: use 40 dimensions for the VAE latent variable 69 | 70 | The training process creates a file `modeloutput.meta.yaml` containing the 71 | training parameters as well as a directory `modeloutput/` which contains a log 72 | file and the serialized model which performed best on the validation set. 73 | 74 | To evaluate the trained model on the test set, use 75 | 76 | ``` 77 | python ./evaluate.py -g 0 -d mnist -E iwae -s 256 modeloutput 78 | ``` 79 | 80 | This evaluates the model trained previously using the following test-time 81 | evaluation setup: 82 | 83 | * `-g 0`: use GPU device 0 for evaluation 84 | * `-d mnist`: evaluate on the mnist data set 85 | * `-E iwae`: use the IWAE objective for evaluation 86 | * `-s 256`: use 256 Monte Carlo samples in the IWAE objective 87 | 88 | Because test-time evaluation does not require backpropagation, we can evaluate 89 | the IWAE and JVI objectives accurately using a large number of samples, e.g. 90 | `-s 65536`. 91 | 92 | The `evaluate.py` script also supports a `--reps 10` parameter which would 93 | evaluate the same model ten times to investigate variance in the Monte Carlo 94 | approximation to the evaluation objective.
95 | 96 | 97 | ## Choosing different objectives 98 | 99 | As illustrated in the paper, the JVI objective generalizes both the ELBO and 100 | the IWAE objectives. 101 | 102 | For example, you can train on the importance-weighted autoencoder (IWAE) 103 | objective using the parameter `--jvi-order 0` instead of `--jvi-order 1`. 104 | 105 | You can train using the regular evidence lower bound (ELBO) by using the 106 | special case of JVI, `--jvi-order 0 --vae-samples 1`, or directly via 107 | `--vae-type vae`. 108 | 109 | # Counting JVI sets 110 | 111 | We include a small utility to count the number of subsets used by the 112 | different JVI approximations. There are two parameters, `n` and `order`, 113 | where `n` is the number of samples of latent space variables per instance, and 114 | `order` is the order of the JVI approximation (order zero corresponds to the 115 | IWAE). 116 | 117 | To run the utility, use: 118 | 119 | ``` 120 | python ./jvicount.py 16 2 121 | ``` 122 | 123 | This utility is useful because the set size can grow very rapidly for larger 124 | JVI orders. Therefore we can use the utility to assess the total number of 125 | terms quickly and make informed choices about batch sizes and order of the 126 | approximation. 127 | 128 | 129 | # Contact 130 | 131 | _Sebastian Nowozin_, `Sebastian.Nowozin@microsoft.com` 132 | 133 | 134 | # Contributing 135 | 136 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 137 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 138 | the rights to use your contribution. For details, visit https://cla.microsoft.com. 139 | 140 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide 141 | a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions 142 | provided by the bot. You will only need to do this once across all repos using our CLA. 
143 | 144 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 145 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 146 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 147 | 148 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Microsoft Corporation. All rights reserved. 4 | # Licensed under the MIT license. 5 | 6 | """Evaluate log-likelihood of VAE models. 7 | 8 | Usage: 9 | evaluate.py (-h | --help) 10 | evaluate.py [options] 11 | 12 | Options: 13 | -h --help Show this help screen. 14 | -d , --dataset Dataset to use, must be "mnist" in public release [default: mnist]. 15 | -g , --device GPU id to train model on. Use -1 for CPU [default: -1]. 16 | -s , --samples Number of evaluation samples [default: 64]. 17 | -b , --batchsize Evaluation minibatch size [default: 4096]. 18 | -r , --reps Replications [default: 1]. 19 | --bootstrap Perform resampling with replacement from test set. 20 | --resultfile Write logp estimate to file, or "+" to auto-filename. 21 | --logw Write logw array for first 1024 samples. 22 | -E , --eval Evaluation type, "vae", "iwae", "iwae++", or "ais" [default: iwae] 23 | --jvi-order Order of jackknife, zero is IWAE [default: 1]. 24 | --ais-temps Number of temperatures [default: 3000]. 25 | --ais-prior Use p(z) instead of q(z|x) for initial z sample. 26 | --ais-sigma Moment parameter standard deviation [default: 1.0]. 27 | --ais-steps Number of leapfrog steps [default: 10]. 28 | --ais-stepsize Leapfrog step size [default: 0.03]. 29 | 30 | The modelprefix.meta.yaml file must exist. 
31 | """ 32 | 33 | import sys 34 | import time 35 | import timeit 36 | import yaml 37 | import math 38 | import numpy as np 39 | import scipy.io as sio 40 | from docopt import docopt 41 | 42 | import chainer 43 | from chainer.training import extensions 44 | from chainer import reporter 45 | from chainer import serializers 46 | from chainer import optimizers 47 | from chainer import cuda 48 | from chainer import computational_graph 49 | import chainer.functions as F 50 | import cupy 51 | 52 | import data 53 | import model 54 | import util 55 | import ais 56 | 57 | args = docopt(__doc__, version='evaluate 0.4') 58 | print(args) 59 | 60 | mpref = args[''] 61 | model_file = mpref+'/snapshot_bestvalobj' 62 | yaml_file = mpref+'.meta.yaml' 63 | 64 | # Read meta data 65 | with open(yaml_file, 'r') as f: 66 | argsy = yaml.load(f) 67 | 68 | nhidden = int(argsy['--nhidden']) 69 | nlatent = int(argsy['--nlatent']) 70 | zcount_train = int(argsy['--vae-samples']) 71 | print "Model was trained using %d vae samples" % zcount_train 72 | 73 | batchsize = int(args['--batchsize']) 74 | print "Using a batchsize of %d instances for evaluation" % batchsize 75 | 76 | reps = int(args['--reps']) 77 | print "Performing %d replications" % reps 78 | 79 | # Load data and instantiate matching encoder/decoder models 80 | dataset = args['--dataset'] 81 | train, val, test, encoder, decoder = data.prepare_dataset_and_model(dataset, nhidden, nlatent) 82 | 83 | vae_type_train = argsy['--vae-type'] 84 | zcount = int(args['--samples']) 85 | print "Using %d samples for likelihood evaluation" % zcount 86 | 87 | # Check GPU 88 | gpu_id = int(args['--device']) 89 | print "Running on GPU %d" % gpu_id 90 | if gpu_id >= 0: 91 | cuda.check_cuda_available() 92 | xp = cuda.cupy 93 | else: 94 | xp = np 95 | 96 | # Evaluation type 97 | vae_type = args['--eval'] 98 | jvi_order = int(args['--jvi-order']) 99 | if vae_type == "vae": 100 | vae = model.ELBOObjective(encoder, decoder, zcount) 101 | elif vae_type == 
"iwae": 102 | vae = model.IWAEObjective(encoder, decoder, zcount) 103 | elif vae_type == "iwae++": 104 | vae = model.ImprovedIWAEObjective(encoder, decoder, zcount) 105 | elif vae_type == "jvi": 106 | vae = model.JVIObjective(encoder, decoder, zcount, jvi_order, device=gpu_id) 107 | elif vae_type == "is": 108 | vae = model.ISObjective(encoder, decoder, zcount) 109 | elif vae_type == "ais": 110 | ais_temps = int(args['--ais-temps']) 111 | ais_sigma = float(args['--ais-sigma']) 112 | ais_steps = int(args['--ais-steps']) 113 | ais_stepsize = float(args['--ais-stepsize']) 114 | print "Annealed importance sampling, %d samples, %d temperatures" % (zcount, ais_temps) 115 | print "Leapfrog integrator, %d steps, %f stepsize" % (ais_steps, ais_stepsize) 116 | if args['--ais-prior']: 117 | vae = ais.AIS(decoder, M=zcount, T=ais_temps, 118 | steps=ais_steps, stepsize=ais_stepsize, sigma=ais_sigma, 119 | encoder=None) 120 | else: 121 | vae = ais.AIS(decoder, M=zcount, T=ais_temps, 122 | steps=ais_steps, stepsize=ais_stepsize, sigma=ais_sigma, 123 | encoder=encoder) 124 | else: 125 | sys.exit("Unsupported VAE type") 126 | 127 | #serializers.load_hdf5(model_file, vae) 128 | #serializers.load_npz(model_file, vae, path='updater/model:main/') 129 | try: 130 | with np.load(model_file) as f: 131 | d = serializers.NpzDeserializer(f,path='updater/model:main/') 132 | d.load(vae) 133 | except: 134 | with np.load(model_file) as f: 135 | d = serializers.NpzDeserializer(f,path='updater/model:elbo/') 136 | d.load(vae) 137 | 138 | print "Deserialized model '%s' of type '%s'" % (model_file, vae_type_train) 139 | 140 | if gpu_id >= 0: 141 | vae.to_gpu(gpu_id) 142 | print "Moved model to GPU %d" % gpu_id 143 | 144 | # For debugging purposes, optionally, obtain and write logw value for the 145 | # first few test samples 146 | if '--logw' in args and args['--logw'] is not None: 147 | logw_file = args['--logw'] 148 | with cupy.cuda.Device(gpu_id): 149 | with chainer.no_backprop_mode(): 150 | xt = 
test[0:256,:] 151 | print xt.shape 152 | xt = chainer.Variable(xp.asarray(xt, dtype=np.float32)) 153 | logw = vae.compute_logw(xt) 154 | logw.to_cpu() 155 | sio.savemat(logw_file, {"logw": logw.data}) 156 | 157 | print "Wrote logw for first 256 samples to file '%s'." % logw_file 158 | 159 | print "Evaluating..." 160 | 161 | for ri in xrange(reps): 162 | if args['--bootstrap']: 163 | print "Bootstrap resampling..." 164 | test_idx = np.random.choice(len(test), size=(len(test),), replace=True) 165 | test_cur = test[test_idx,:] 166 | print test_cur.shape 167 | else: 168 | test_cur = test 169 | 170 | test_iter = chainer.iterators.SerialIterator(test_cur, batchsize, repeat=False, shuffle=False) 171 | obs = {} 172 | reprt = reporter.Reporter() 173 | reprt.add_observer('main', vae) 174 | with cupy.cuda.Device(gpu_id): 175 | start_time = timeit.default_timer() 176 | with reprt.scope(obs): 177 | teval = extensions.Evaluator(test_iter, vae, device=gpu_id) 178 | res = teval.evaluate() 179 | 180 | runtime = timeit.default_timer() - start_time 181 | print "Evaluation took %.2fs" % runtime 182 | print res 183 | obj_mean = -res['main/obj'] 184 | obj_sem = res['main/obj_var'] 185 | obj_sem = math.sqrt(obj_sem/len(test)) 186 | print "%.8f +/- %.8f # logp(%s) %d" % (obj_mean, obj_sem, vae_type, zcount) 187 | 188 | if '--resultfile' in args and args['--resultfile'] is not None: 189 | resultfile = args['--resultfile'] 190 | if resultfile == "+": 191 | resultfile = mpref+".test.perf" 192 | 193 | print "Writing test set results to '%s'." % resultfile 194 | with open(resultfile, "a") as rf: 195 | rf.write("%.8f,%.8f,%.8f\n" % (obj_mean, obj_sem, runtime)) 196 | 197 | sys.exit() 198 | 199 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Microsoft Corporation. All rights reserved. 
4 | # Licensed under the MIT license. 5 | 6 | """Train a variational autoencoder (VAE/IWAE/JVI) model. 7 | 8 | Usage: 9 | train.py (-h | --help) 10 | train.py [options] 11 | 12 | Options: 13 | -h --help Show this help screen. 14 | -d , --dataset Dataset to use, one of "mnist", "msrc12", "celeb1m", "celeb1m-small" [default: msrc12]. 15 | -o Write trained model to given file.h5 [default: output]. 16 | -g , --device GPU id to train model on. Use -1 for CPU [default: -1]. 17 | -e , --epochs Number of epochs to train [default: 1000]. 18 | -b , --batchsize Minibatch size [default: 8192]. 19 | --fancy Use fancy model variant. 20 | --lr Initial learning rate [default: 1.0e-4]. 21 | --opt Optimizer to use, one of "sgd", "smorms3" or "adam" [default: smorms3]. 22 | --vae-type VAE type, one of "vae", "iwae", "iwae++", "jvi", "jvi+elbo" [default: vae]. 23 | --vae-samples Number of samples in VAE z [default: 1]. 24 | --jvi-order Order of jackknife, zero is IWAE [default: 1]. 25 | --nhidden Number of hidden dimensions [default: 128]. 26 | --nlatent Number of latent VAE dimensions [default: 16]. 27 | --vis Visualize computation graph. 28 | 29 | For "msrc12" the data.mat file must contain a (N,d) array of N instances, d 30 | dimensions each. 
31 | """ 32 | 33 | import sys 34 | import time 35 | import yaml 36 | import numpy as np 37 | from docopt import docopt 38 | 39 | import chainer 40 | from chainer import training 41 | from chainer.training import extensions 42 | from chainer.datasets import tuple_dataset 43 | from chainer import reporter 44 | from chainer import serializers 45 | from chainer import optimizers 46 | from chainer import cuda 47 | from chainer import computational_graph 48 | import chainer.functions as F 49 | import cupy 50 | 51 | import model 52 | import updater 53 | import util 54 | import data 55 | from elbo_updater import ELBOUpdater 56 | 57 | class TestModeEvaluator(extensions.Evaluator): 58 | def evaluate(self): 59 | model = self.get_target('main') 60 | #use_elbo = model.use_elbo 61 | #model.use_elbo = True 62 | model.train = False 63 | ret = super(TestModeEvaluator, self).evaluate() 64 | model.train = True 65 | #model.use_elbo = use_elbo 66 | return ret 67 | 68 | def save_model(args, vae): 69 | # Save model 70 | if args['-o'] is not None: 71 | modelmeta = args['-o'] + '.meta.yaml' 72 | print "Writing model metadata to '%s' ..." % modelmeta 73 | with open(modelmeta, 'w') as outfile: 74 | outfile.write(yaml.dump(dict(args), default_flow_style=False)) 75 | 76 | modelfile = args['-o'] + '.h5' 77 | print "Writing model to '%s' ..." % modelfile 78 | serializers.save_hdf5(modelfile, vae) 79 | 80 | 81 | args = docopt(__doc__, version='train 0.2') 82 | print(args) 83 | 84 | print "Using chainer version %s" % chainer.__version__ 85 | nhidden = int(args['--nhidden']) 86 | print "%d hidden dimensions" % nhidden 87 | nlatent = int(args['--nlatent']) 88 | print "%d latent VAE dimensions" % nlatent 89 | 90 | fancy = False 91 | if args['--fancy']: 92 | fancy = True 93 | print "Using fancy model version." 
# Load data and instantiate matching encoder/decoder models
dataset = args['--dataset']
train, val, test, encoder, decoder = data.prepare_dataset_and_model(
    dataset, nhidden, nlatent, fancy)

print("Training set size: %d instances" % len(train))
# The validation split is optional (it is checked against None below), so
# guard before taking its length.
if val is not None:
    print("Validation set size: %d instances" % len(val))

epochs = int(args['--epochs'])
print("Training for %d epochs" % epochs)

# Setup model
zcount = int(args['--vae-samples'])
print("Using %d VAE samples per instance" % zcount)

gpu_id = int(args['--device'])

vae_type = args['--vae-type']
jvi_order = int(args['--jvi-order'])
elbo = None
print("Training using '%s' objective" % vae_type)
if vae_type == "vae":
    vae = model.ELBOObjective(encoder, decoder, zcount)
elif vae_type == "iwae":
    vae = model.IWAEObjective(encoder, decoder, zcount)
elif vae_type == "iwae++":
    vae = model.ImprovedIWAEObjective(encoder, decoder, zcount)
elif vae_type == "jvi":
    vae = model.JVIObjective(encoder, decoder, zcount, jvi_order, device=gpu_id)
elif vae_type == "jvi+elbo":
    # Both objectives share the same encoder/decoder links.
    vae = model.JVIObjective(encoder, decoder, zcount, jvi_order, device=gpu_id)
    elbo = model.ELBOObjective(encoder, decoder, zcount)
elif vae_type == "is":
    vae = model.ISObjective(encoder, decoder, zcount)
else:
    sys.exit("Unsupported VAE type (%s)." % vae_type)

lr = float(args['--lr'])
print("Using initial learning rate %f" % lr)
opt_type = args['--opt']
if opt_type == "adam":
    opt = optimizers.Adam(alpha=lr)
    opt_elbo = optimizers.Adam(alpha=lr)
elif opt_type == "smorms3":
    opt = optimizers.SMORMS3(lr=lr)
    opt_elbo = optimizers.SMORMS3(lr=lr)
elif opt_type == "sgd":
    opt = optimizers.SGD(lr=lr)
    opt_elbo = optimizers.SGD(lr=lr)
else:
    sys.exit("Unsupported optimizer type (%s)." % opt_type)

# Move to GPU before attaching the optimizers (Chainer's recommended order).
if gpu_id >= 0:
    cuda.check_cuda_available()
    xp = cuda.cupy
    vae.to_gpu(gpu_id)
    if elbo is not None:
        elbo.to_gpu(gpu_id)
else:
    xp = np

opt.setup(vae)
opt.add_hook(chainer.optimizer.GradientClipping(4.0))

if elbo is not None:
    opt_elbo.setup(elbo)
    opt_elbo.add_hook(chainer.optimizer.GradientClipping(4.0))

# Setup training parameters
batchsize = int(args['--batchsize'])
print("Using a batchsize of %d instances" % batchsize)

# Save model meta data
if args['-o'] is not None:
    modelmeta = args['-o'] + '.meta.yaml'
    print("Writing model metadata to '%s' ..." % modelmeta)
    with open(modelmeta, 'w') as outfile:
        outfile.write(yaml.dump(dict(args), default_flow_style=False))

train_iter = chainer.iterators.SerialIterator(train, batchsize)
val_iter = None
if val is not None:
    val_iter = chainer.iterators.SerialIterator(val, batchsize,
                                                repeat=False, shuffle=False)

# Reuse the objectives created above: the optimizers were set up on exactly
# these chains, so the updater must drive the same objects.  The name
# `train_updater` avoids shadowing the imported `updater` module.
if elbo is not None:
    train_updater = ELBOUpdater(
        models=(elbo, vae),
        iterator=train_iter,
        optimizer={'elbo': opt_elbo, 'p_obj': opt},
        device=gpu_id)
else:
    train_updater = training.StandardUpdater(train_iter, opt, device=gpu_id)

trainer = training.Trainer(train_updater, (epochs, 'epoch'), out=args['-o'])
if val is not None:
    trainer.extend(TestModeEvaluator(val_iter, vae, device=gpu_id))
trainer.extend(extensions.LogReport())
# 'validation/main/obj' is only reported when an evaluator is attached, and
# MinValueTrigger reads that key from the epoch summary.  Create the
# best-model trigger/snapshot only when a validation set exists.
if val is not None:
    valtrigger = chainer.training.triggers.MinValueTrigger(
        'validation/main/obj', trigger=(1, 'epoch'))
    trainer.extend(extensions.snapshot(filename='snapshot_bestvalobj'),
                   trigger=valtrigger)
else:
    valtrigger = None

if val is not None:
    if elbo is not None:
        # NOTE(review): ELBOObjective reports 'obj', so 'elbo/elbo' only shows
        # up if ELBOUpdater reports that key itself -- confirm.
        trainer.extend(extensions.PrintReport(
            ['epoch', 'iteration', 'elbo/elbo', 'p_obj/obj',
             'validation/main/obj', 'validation/main/obj_elbo', 'elapsed_time']))
        trainer.reporter.add_observer('elbo', elbo)
        trainer.reporter.add_observer('p_obj', vae)
    else:
        trainer.extend(extensions.PrintReport(
            ['epoch', 'iteration', 'main/obj', 'main/obj_elbo',
             'validation/main/obj', 'validation/main/obj_elbo', 'elapsed_time']))
else:
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/obj', 'elapsed_time']))
trainer.extend(extensions.ProgressBar(update_interval=1))

print("Training...")
trainer.run()

best_value = valtrigger._best_value if valtrigger is not None else None
if best_value is None:
    print("No validation loss was recorded.")
else:
    print("Minimum validation loss: %.4f" % best_value)
    if args['-o'] is not None:
        perffilename = args['-o'] + '.perf'
        print("Writing model validation performance to '%s' ..." % perffilename)
        # `with` guarantees the file is flushed and closed.
        with open(perffilename, 'w') as perffile:
            perffile.write("%.6f" % best_value)

# --------------------------------------------------------------------------------
# /src/ais.py:
# --------------------------------------------------------------------------------

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import math
import numpy as np

import cupy
import chainer
import chainer.functions as F
from chainer import cuda
from chainer import reporter

import model

# Convenience wrapper to define the HMC energy function.
# This is the only code that is VAE specific, the code that follows is general
# HMC code.
class EnergyFunction:
    """Annealed energy ``-log p(z) - inv_temp * log p(x|z)`` for HMC.

    zprior: callable whose value on a z batch is used as log p(z).
    decoder: callable mapping z to an object that, called on X, returns
        log p(x|z) per instance.
    X: data batch; ``X.shape[0]`` is taken as the batch size.
    inv_temp: inverse temperature multiplying the likelihood term.
    """

    def __init__(self, zprior, decoder, X, inv_temp):
        self.zprior = zprior
        self.decode = decoder
        self.X = X
        self.batchsize = X.shape[0]
        self.inv_temp = inv_temp

    def E(self, Z):
        """Return the energies of the rows of Z as a flat Variable.

        Z stacks M blocks of shape (batchsize, nlatent) along axis 0; the
        result has one energy per row, in the same order.
        """
        # Floor division: the number of split sections must be an integer
        # (plain '/' produces a float under Python 3).
        M = Z.shape[0] // self.batchsize
        energies = list()

        # Process one (batchsize, nlatent) sample block at a time.
        for z in F.split_axis(Z, M, 0):
            logpz = self.zprior(z)
            logpxz = self.decode(z)(self.X)
            energies.append(-logpz - self.inv_temp*logpxz)

        return F.flatten(F.vstack(energies))

    def grad(self, Z):
        """Return d/dZ of the summed energy, i.e. nabla_z(-log p(z) - inv_temp log p(x|z)).

        NOTE(review): backward() also accumulates gradients into any parameter
        reachable from the energy (e.g. the decoder's) as a side effect.
        """
        with chainer.force_backprop_mode():
            ZV = chainer.Variable(Z, requires_grad=True)
            energy = F.sum(self.E(ZV))
            energy.backward()

        return ZV.grad

    def __call__(self, Z):
        """Return the per-row energies of Z without building a backprop graph."""
        with chainer.no_backprop_mode():
            return self.E(Z)


# Perform Hamiltonian Monte Carlo step with leapfrog integrator.
# n: number of leapfrog steps
# leapfrog_eps: stepsize.
#
# The samples are updated in place in 'sinit' and we return the average
# acceptance rate over samples.
def leapfrog(efun, sinit, n=10, leapfrog_eps=0.1, moment_sigma=1.0):
    """Run one HMC transition with ``n`` leapfrog steps, updating ``sinit`` in place.

    efun: energy object providing ``grad(s)`` and ``efun(s)`` (per-row energy).
    sinit: 2-D array of current states, one per row; rows whose proposal is
        accepted are overwritten.
    Returns the mean Metropolis acceptance rate over the rows.
    """
    xp = cupy.get_array_module(sinit)
    moment_var = moment_sigma**2.0

    # Gaussian momentum with covariance moment_var * I.
    phi = moment_sigma*xp.random.normal(0, 1, sinit.shape).astype(np.float32)
    phi_prev = phi.copy()
    s = sinit.copy()

    # Symmetric leapfrog: half momentum step, n position steps separated by
    # n-1 full momentum steps, final half momentum step.  This scheme is
    # reversible and volume preserving, which the Metropolis test relies on.
    phi -= 0.5*leapfrog_eps * efun.grad(s)
    for m in range(1, n + 1):
        s += leapfrog_eps*phi/moment_var
        if m < n:
            phi -= leapfrog_eps * efun.grad(s)
    phi -= 0.5*leapfrog_eps * efun.grad(s)

    # Metropolis acceptance on total (potential + kinetic) energy.
    log_alpha = efun(sinit) + 0.5*xp.sum(phi_prev*phi_prev, axis=1)/moment_var
    log_alpha -= efun(s) + 0.5*xp.sum(phi*phi, axis=1)/moment_var
    log_uniform = xp.log(xp.random.uniform(size=log_alpha.shape))
    accept = log_uniform <= log_alpha.data
    sinit[accept, :] = s[accept, :]

    return xp.mean(accept.astype(np.float32))


def ais_beta(t, T, beta1=1.0e-4):
    """Geometric inverse-temperature ladder.

    t: temperature index, 1 <= t <= T.
    T: number of stages, T >= 2.
    beta1: inverse temperature at t=2, with 0 < beta1 < 1.0.
    Returns 0 at t=1, then grows geometrically from beta1 (t=2) to 1 (t=T).
    """
    if t == 1:
        return 0.0
    if T == 2:
        # Only one annealed stage, which is the target itself; the geometric
        # formula below would divide by zero.
        return 1.0

    gamma = (1.0/beta1)**(1.0/(T-2))
    return beta1*(gamma**(t-2))


def sigmoid(x):
    """Logistic function, evaluated without overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def ais_beta_sigmoid(t, T, rad=4.0):
    """Sigmoid inverse-temperature ladder from Wu et al., 2016.

    Maps t in [1, T] linearly onto [-rad, rad], applies the sigmoid, and
    rescales so that t=1 gives 0 and t=T gives 1.
    """
    min_s = sigmoid(-rad)
    max_s = sigmoid(rad)
    s = sigmoid(-rad + ((t-1.0)/(T-1.0))*2*rad)
    return (s - min_s) / (max_s - min_s)


class ZDistribution:
    """Interface for the distribution the AIS chains are initialized from."""

    def initial_logw(self, X, Z):
        """Return the initial log importance weight of every row of Z.

        For rows drawn from a sampling distribution r(z), this is
        log p(z) - log r(z), where p(z) is the prior, i.e. the beta=0 end of
        the annealing path.
        """
        raise NotImplementedError

    def __call__(self, Z):
        """Return log p(z) for every row of Z."""
        raise NotImplementedError


class ZPrior(ZDistribution):
    """Standard normal prior N(0, I) over ``nz`` latent dimensions."""

    def __init__(self, nz):
        self.nz = nz

    def initial_logw(self, X, Z):
        # Samples come from the prior itself, so all initial weights are 1.
        xp = cupy.get_array_module(X)
        return xp.zeros((Z.shape[0],))

    def sample(self, X, M):
        """Draw M standard-normal samples per row of X; shape (M*batchsize, nz)."""
        xp = cupy.get_array_module(X)
        batchsize = X.shape[0]
        return xp.random.normal(size=(M*batchsize, self.nz)).astype(np.float32)

    def __call__(self, Z):
        return model.gaussian_logp01_inst(Z)


class ZEncoder(ZDistribution):
    """Initialize the AIS chains from the encoder's Gaussian q(z|x).

    The annealing path still starts at the prior (``__call__`` returns
    log p(z)), so ``initial_logw`` importance-corrects for drawing the
    initial states from q instead of p.
    """

    def __init__(self, encoder, X):
        self.qmu, self.qln_var = encoder(X)  # pre-compute q(z|x)

    def initial_logw(self, X, Z):
        batchsize = X.shape[0]
        M = Z.shape[0] // batchsize
        logw = list()

        # Process one (batchsize, nlatent) sample block at a time.
        for z in F.split_axis(Z, M, 0):
            # log w = log p(z) - log q(z|x): weight of the first AIS
            # distribution (beta=0, the prior) relative to the proposal q.
            logw_i = (model.gaussian_logp01_inst(z)
                      - model.gaussian_logp_inst(z, self.qmu, self.qln_var))
            logw.append(logw_i)

        return F.flatten(F.vstack(logw))

    def sample(self, X, M):
        """Draw M samples per row of X from q(z|x); shape (M*batchsize, nz)."""
        Z = F.vstack([F.gaussian(self.qmu, self.qln_var) for _ in range(M)])
        return Z.data

    def __call__(self, Z):
        return model.gaussian_logp01_inst(Z)


def ais(decoder, X, M=32, T=100, steps=10, stepsize=0.1, sigma=1.0,
        encoder=None):
    """Estimate log p(x) for every row of X with annealed importance sampling.

    decoder: generative model; ``decoder.nz`` is the latent dimension and
        ``decoder(z)`` returns an object that, called on X, gives log p(x|z).
    X: data batch, one instance per row.
    M: number of AIS chains to run per instance.
    T: number of temperatures in the (sigmoid) temperature ladder.
    steps, stepsize, sigma: leapfrog steps, step size and momentum scale of
        each HMC transition.
    encoder: optional; when given, chains start from q(z|x) instead of p(z).
    Returns a (batchsize,) Variable with log((1/M) sum_m w_m) per instance.
    """
    batchsize = X.shape[0]  # number of instances in X
    nz = decoder.nz         # number of latent dimensions

    # Sample initial z and initialize log weights
    if encoder is None:
        print("Using p(z)")
        zprior = ZPrior(nz)
    else:
        print("Using q(z|x)")
        zprior = ZEncoder(encoder, X)

    Z = zprior.sample(X, M)
    logw = zprior.initial_logw(X, Z)

    for t in range(2, T + 1):
        efun_cur = EnergyFunction(zprior, decoder, X, ais_beta_sigmoid(t, T))
        efun_prev = EnergyFunction(zprior, decoder, X, ais_beta_sigmoid(t-1, T))
        # AIS order: the current states are distributed for stage t-1, so
        # weight them by f_t/f_{t-1} = exp(E_{t-1} - E_t) *before* moving
        # them with an HMC kernel that leaves stage t invariant.
        logw += efun_prev(Z).data - efun_cur(Z).data
        accept_rate = leapfrog(efun_cur, Z,
                               n=steps, leapfrog_eps=stepsize, moment_sigma=sigma)
        if t % 100 == 0:
            print("AIS t=%d accept rate %.3f" % (t, accept_rate))

    logw = F.reshape(logw, (M, batchsize))
    return F.logsumexp(logw, axis=0) - math.log(M)
class AIS(chainer.Chain):
    """Chain that reports the AIS estimate of -E[log p(x)] as 'obj'.

    decoder: generative model passed to ais().
    encoder: optional; when given, AIS chains start from q(z|x).
    M, T, steps, stepsize, sigma: forwarded unchanged to ais().
    """

    def __init__(self, decoder, M=32, T=100, steps=10, stepsize=0.1,
                 sigma=1.0, encoder=None):
        super(AIS, self).__init__(
            decode=decoder,
        )
        if encoder is None:
            self.encode = None
        else:
            self.add_link('encode', encoder)

        self.M = M
        self.T = T
        self.steps = steps
        self.stepsize = stepsize
        self.sigma = sigma

    def __call__(self, x):
        """Run AIS on batch ``x``, report 'obj'/'obj_var', return the objective."""
        print("Calling ais()")
        logpx = ais(self.decode, x, M=self.M, T=self.T,
                    steps=self.steps, stepsize=self.stepsize, sigma=self.sigma,
                    encoder=self.encode)
        batchsize = x.shape[0]

        logpx_mean = F.sum(logpx) / batchsize
        print("E[log p(x)] %.5f" % logpx_mean.data)
        obj = -logpx_mean

        # Sample variance of the per-instance estimates.  A batch with a
        # single instance has no spread; clamp the divisor to avoid /0.
        obj_c = logpx - F.broadcast_to(logpx_mean, logpx.shape)
        obj_var = F.sum(obj_c*obj_c) / max(batchsize - 1, 1)
        reporter.report({'obj': obj, 'obj_var': obj_var}, self)

        return obj


# --------------------------------------------------------------------------------
# /src/model.py:
# --------------------------------------------------------------------------------

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import numpy as np
import math
import cupy
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import cuda
from chainer import reporter
import numfun
import jvi


def _sample_variance(values, mean):
    """Return the sample variance of the 1-D Variable ``values``.

    ``mean`` is the scalar mean of ``values``.  The divisor is ``len - 1``,
    clamped to 1 so a single-instance batch yields 0 instead of dividing by
    zero.
    """
    count = values.shape[0]
    centered = values - F.broadcast_to(mean, values.shape)
    return F.sum(centered*centered) / max(count - 1, 1)


def gaussian_kl_divergence_inst(mu, ln_var):
    """D_{KL}(N(mu,var) | N(0,1)) per instance, summed over non-batch axes."""
    axis_sum = tuple(range(1, mu.data.ndim))
    dim = np.prod([mu.data.shape[i] for i in axis_sum])
    S = F.exp(ln_var)

    return 0.5*(F.sum(S, axis=axis_sum) + F.sum(mu*mu, axis=axis_sum)
                - F.sum(ln_var, axis=axis_sum) - dim)


def gaussian_kl_divergence(mu, ln_var):
    """D_{KL}(N(mu,var) | N(0,1)) summed over the batch, divided by batchsize."""
    batchsize = mu.data.shape[0]
    S = F.exp(ln_var)
    D = mu.data.size

    KL_sum = 0.5*(F.sum(S) + F.sum(mu*mu) - F.sum(ln_var) - D)

    return KL_sum / batchsize


def gaussian_logp01_inst(x):
    """log N(x ; 0, 1) per instance, summed over non-batch axes."""
    axis_sum = tuple(range(1, x.ndim))
    dim = np.prod([x.shape[i] for i in axis_sum])

    return -0.5*(F.sum(x*x, axis=axis_sum) + dim*math.log(2.0*math.pi))


def gaussian_logp01(x):
    """log N(x ; 0, 1) summed over the batch, divided by batchsize."""
    batchsize = x.shape[0]
    D = x.size

    logp_sum = -0.5*(F.sum(x*x) + D*math.log(2.0*math.pi))

    return logp_sum / batchsize


def gaussian_logp_inst(x, mu, ln_var):
    """log N(x ; mu, var) per instance with diagonal var = exp(ln_var)."""
    axis_sum = tuple(range(1, mu.data.ndim))
    dim = np.prod([mu.data.shape[i] for i in axis_sum])
    S = F.exp(ln_var)
    xc = x - mu

    return -0.5*(F.sum((xc*xc) / S, axis=axis_sum) + F.sum(ln_var, axis=axis_sum)
                 + dim*math.log(2.0*math.pi))


def gaussian_logp(x, mu, ln_var):
    """log N(x ; mu, var) summed over the batch, divided by batchsize."""
    batchsize = mu.data.shape[0]
    D = x.size
    S = F.exp(ln_var)
    xc = x - mu

    logp_sum = -0.5*(F.sum((xc*xc) / S) + F.sum(ln_var)
                     + D*math.log(2.0*math.pi))

    return logp_sum / batchsize


class GaussianLikelihood:
    """Callable returning per-instance log N(x ; pmu, exp(pln_var))."""

    def __init__(self, pmu, pln_var):
        self.pmu = pmu
        self.pln_var = pln_var

    def __call__(self, x):
        return gaussian_logp_inst(x, self.pmu, self.pln_var)


class IWAEObjective(chainer.Chain):
    """Negative importance-weighted (IWAE) bound with num_zsamples samples."""

    def __init__(self, encoder, decoder, num_zsamples=1):
        super(IWAEObjective, self).__init__(
            encode=encoder,
            decode=decoder,
        )
        self.num_zsamples = num_zsamples

    def compute_elbo(self, logw):
        """Batch-averaged ELBO (1/k) sum_i log w_i from (k, batchsize) logw."""
        k = logw.shape[0]
        batchsize = logw.shape[1]
        elbo = F.sum(logw, axis=0) / k
        return F.sum(elbo) / batchsize

    def compute_logw(self, x):
        """Return log w = log p(x|z) + log p(z) - log q(z|x), (num_zsamples, batchsize)."""
        qmu, qln_var = self.encode(x)

        logw = list()
        for _ in range(self.num_zsamples):
            z = F.gaussian(qmu, qln_var)  # z ~ q(z|x)

            logpxz = self.decode(z)(x)
            logpz = gaussian_logp01_inst(z)
            logqz = gaussian_logp_inst(z, qmu, qln_var)

            logw.append(logpxz + logpz - logqz)

        return F.stack(logw)

    def __call__(self, x):
        batchsize = x.shape[0]

        logw = self.compute_logw(x)

        # IWAE = log (1/k) sum_i w_i
        logp = F.logsumexp(logw, axis=0) - math.log(self.num_zsamples)
        logp_mean = F.sum(logp) / batchsize
        obj = -logp_mean

        obj_var = _sample_variance(logp, logp_mean)
        obj_elbo = -self.compute_elbo(logw)

        reporter.report({'obj': obj, 'obj_var': obj_var, 'obj_elbo': obj_elbo}, self)

        return obj


class JVIObjective(chainer.Chain):
    """Jackknife variational inference objective.

    Combines log-mean-weights over the subsets given by ``jvi.jvi_matrix``
    with coefficients A; reported 'obj' is the negated batch mean.
    """

    def __init__(self, encoder, decoder, num_zsamples=2, jvi_order=0, device=0):
        super(JVIObjective, self).__init__(
            encode=encoder,
            decode=decoder,
        )
        self.num_zsamples = num_zsamples
        self.jvi_order = jvi_order

        # Pre-generate JVI matrix
        A, B = jvi.jvi_matrix(num_zsamples, jvi_order)
        self.A = np.reshape(A, (A.shape[0], 1)).astype(np.float32)  # (M, 1)
        self.logB = np.log(B).T.astype(np.float32)  # (num_zsamples, M)
        M = self.logB.shape[1]
        print("Using %d JVI subsets (%d z-samples, jvi order %d)" % (M, num_zsamples, jvi_order))

        # Copy to the GPU unless a CPU device id (< 0) was requested; the
        # arrays are plain attributes, so to_gpu() on the chain won't move them.
        on_cpu = isinstance(device, (int, np.integer)) and device < 0
        if not on_cpu:
            self.A = cuda.to_gpu(self.A, device=device)
            self.logB = cuda.to_gpu(self.logB, device=device)

    def __call__(self, x):
        batchsize = x.shape[0]

        iwae = IWAEObjective(self.encode, self.decode, self.num_zsamples)
        logw = iwae.compute_logw(x)  # (num_zsamples, batchsize)
        obj_elbo = -iwae.compute_elbo(logw)

        M = self.logB.shape[1]  # number of subsets
        n = self.num_zsamples

        # R[m, b] = log sum_i B[i, m] * w[i, b], via (n, M, batchsize) tensors.
        logw = F.broadcast_to(F.reshape(logw, (n, 1, batchsize)), (n, M, batchsize))
        logB = F.broadcast_to(F.reshape(self.logB, (n, M, 1)), (n, M, batchsize))
        R = F.logsumexp(logw + logB, axis=0)  # (M, batchsize)
        # A^T R has shape (1, batchsize); flatten to one estimate per instance.
        logp = F.reshape(F.matmul(self.A, R, transa=True), (batchsize,))

        logp_mean = F.mean(logp)
        obj = -logp_mean
        obj_var = _sample_variance(logp, logp_mean)

        reporter.report({'obj': obj, 'obj_var': obj_var, 'obj_elbo': obj_elbo}, self)
        return obj


class ImprovedIWAEObjective(chainer.Chain):
    """First-order jackknife bias-corrected IWAE objective (needs k >= 2)."""

    def __init__(self, encoder, decoder, num_zsamples=1):
        super(ImprovedIWAEObjective, self).__init__(
            encode=encoder,
            decode=decoder,
        )
        self.num_zsamples = num_zsamples

    def __call__(self, x):
        if self.num_zsamples < 2:
            # The leave-one-out estimate needs k-1 >= 1 remaining samples
            # (log(k-1) below is undefined for k = 1).
            raise ValueError(
                "ImprovedIWAEObjective requires num_zsamples >= 2 (got %d)"
                % self.num_zsamples)

        batchsize = x.shape[0]

        iwae = IWAEObjective(self.encode, self.decode, self.num_zsamples)
        logw = iwae.compute_logw(x)  # (num_zsamples, batchsize)
        obj_elbo = -iwae.compute_elbo(logw)

        # A = log sum_i w_i per instance; IWAE_k = A - log k.
        A = F.logsumexp(logw, axis=0)
        logp_iwae = F.sum(A - math.log(self.num_zsamples)) / batchsize

        # B = sum_i log(1 - w_i / sum_j w_j), using numfun.log1mexp (presumably
        # log(1 - exp(.))); the 1e-6 shift keeps its argument strictly < 0.
        k = float(self.num_zsamples)
        B = F.sum(numfun.log1mexp(logw - F.broadcast_to(A, logw.shape) - 1.0e-6), axis=0)
        # k*IWAE_k - ((k-1)/k) * sum_i IWAE_{k-1}^{(-i)}, simplified:
        logp_jk = A - ((k-1)/k)*B - k*math.log(k) + (k-1)*math.log(k-1)
        logp_jk_mean = F.sum(logp_jk) / batchsize
        obj = -logp_jk_mean
        correction = logp_jk_mean - logp_iwae

        obj_var = _sample_variance(logp_jk, logp_jk_mean)

        reporter.report({'obj': obj, 'obj_var': obj_var, 'obj_elbo': obj_elbo,
                         'corr': correction}, self)

        return obj


class ELBOObjective(chainer.Chain):
    """Negative ELBO, D(q(z|x) | p(z)) - E_{z ~ q(z|x)}[log p(x|z)].

    The last KL, reconstruction and objective values are kept on
    ``self.kl``, ``self.logp`` and ``self.obj``.
    """

    def __init__(self, encoder, decoder, num_zsamples=1):
        super(ELBOObjective, self).__init__(
            encode=encoder,
            decode=decoder,
        )
        self.num_zsamples = num_zsamples

    def __call__(self, x):
        batchsize = x.shape[0]

        # Compute q(z|x)
        qmu, qln_var = self.encode(x)

        kl_inst = gaussian_kl_divergence_inst(qmu, qln_var)
        self.kl = F.sum(kl_inst)/batchsize

        logp_inst = None
        self.logp = 0
        for _ in range(self.num_zsamples):
            z = F.gaussian(qmu, qln_var)  # z ~ q(z|x)
            logpxz = self.decode(z)(x)
            logp_inst = logpxz if logp_inst is None else logp_inst + logpxz
            self.logp += F.sum(logpxz) / logpxz.shape[0]

        # Per-instance objective, used only for the variance report.
        logp_inst /= self.num_zsamples
        obj_inst = kl_inst - logp_inst
        obj_inst_mean = F.sum(obj_inst) / batchsize
        obj_var = _sample_variance(obj_inst, obj_inst_mean)

        self.logp /= self.num_zsamples
        self.obj = self.kl - self.logp

        reporter.report({'obj': self.obj, 'obj_var': obj_var, 'kl': self.kl, 'logp': self.logp}, self)

        return self.obj


class ISObjective(chainer.Chain):
    """Self-normalized importance-sampling objective.

    The normalized weights are detached from the graph, so gradients flow
    only through log p(x|z) (including through the reparameterized z).
    """

    def __init__(self, encoder, decoder, num_zsamples=1):
        super(ISObjective, self).__init__(
            encode=encoder,
            decode=decoder,
        )
        self.num_zsamples = num_zsamples

    def __call__(self, x):
        # Compute q(z|x)
        qmu, qln_var = self.encode(x)
        batchsize = qmu.data.shape[0]

        # Perform unnormalized importance sampling
        logw = list()
        logpxz = list()
        for _ in range(self.num_zsamples):
            z = F.gaussian(qmu, qln_var)  # z ~ q(z|x)
            logqz = gaussian_logp_inst(z, qmu, qln_var)
            logpz = gaussian_logp01_inst(z)

            logpxz_i = self.decode(z)(x)
            logpxz.append(logpxz_i)

            logw.append(logpz + logpxz_i - logqz)

        # Self-normalize importance weights
        logw = F.stack(logw)  # (num_zsamples, batchsize)
        lse = F.logsumexp(logw, axis=0)
        logw -= F.broadcast_to(lse, logw.shape)
        w = F.exp(logw)

        # Effective sample size, averaged over the batch
        ess = F.sum(1.0 / F.sum(w*w, axis=0)) / batchsize

        logpxz = F.stack(logpxz)  # (num_zsamples, batchsize)

        # Break the dependency of the weights in the computational graph
        w = chainer.Variable(w.data)
        obj = -F.sum(w*logpxz) / batchsize

        reporter.report({'obj': obj, 'ess': ess}, self)

        return obj