├── .gitignore ├── LICENSE ├── README.md ├── bdmc.py ├── cvae.py ├── datasets ├── fashion │ ├── t10k-images-idx3-ubyte.gz │ ├── t10k-labels-idx1-ubyte.gz │ ├── train-images-idx3-ubyte.gz │ └── train-labels-idx1-ubyte.gz └── mnist.pkl.tar.gz ├── loader.py ├── local_ffg.py ├── local_flow.py ├── run.py ├── utils ├── ais.py ├── approx_posts.py ├── helper.py ├── hmc.py ├── hparams.py ├── math_ops.py ├── mnist_reader.py └── simulate.py └── vae.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Xuechen Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # inference-suboptimality 2 | Evaluation code for the paper 3 | *Inference Suboptimality in Variational Autoencoders.* 4 | [[arxiv](https://arxiv.org/abs/1801.03558)] 5 | 6 | ## Dependencies 7 | * `python3` 8 | * `pytorch==0.2.0` 9 | * `tqdm` 10 | 11 | ## Training 12 | To train on MNIST and Fashion, first unzip the compressed files in the `datasets/` folder. 13 | 14 | `python run.py --train --dataset <dataset> (--lr-schedule --warmup --early-stopping)` 15 | 16 | To train on CIFAR, set the dataset flag to `cifar`. The dataset is 17 | downloaded automatically if it is not already present. 18 | 19 | ## Evaluation 20 | * IWAE: `python run.py --eval-iwae --dataset <dataset> --eval-path <ckpt path>` 21 | * AIS: `python run.py --eval-ais --dataset <dataset> --eval-path <ckpt path>` 22 | * Local FFG: `python local_ffg.py --dataset <dataset> --eval-path <ckpt path>` 23 | * Local Flow: `python local_flow.py --dataset <dataset> --eval-path <ckpt path>` 24 | * BDMC: `python bdmc.py --eval-path <ckpt path> --n-ais-iwae <n importance samples> --n-ais-dist <n intermediate distributions>` 25 | 26 | ## Other Experiments 27 | For experiments on decoder size, the effect of flows on amortization, the test-set gap, and more, refer to [this](https://github.com/chriscremer/Inference-Suboptimality). 28 | 29 | ## Citation 30 | If you use our code, please consider citing the following: 31 | Chris Cremer, Xuechen Li, David Duvenaud. 32 | Inference Suboptimality in Variational Autoencoders. 33 | 34 | ``` 35 | @article{cremer2018inference, 36 | title={Inference Suboptimality in Variational Autoencoders}, 37 | author={Cremer, Chris and Li, Xuechen and Duvenaud, David}, 38 | journal={ICML}, 39 | year={2018} 40 | } 41 | ``` 42 | -------------------------------------------------------------------------------- /bdmc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | import time 4 | import argparse 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | from torch.autograd import grad as torchgrad 9 | import torch.nn.functional as F 10 | 11 | from utils.ais import ais_trajectory 12 | from utils.simulate import simulate_data 13 | from utils.hparams import HParams 14 | from utils.math_ops import sigmoidial_schedule 15 | from utils.helper import get_model 16 | 17 | 18 | parser = argparse.ArgumentParser(description='bidirectional_mc') 19 | # action configuration flags 20 | parser.add_argument('--n-ais-iwae', '-nai', type=int, default=100, 21 | help='number of IMPORTANCE samples for AIS evaluation (default: 100). 
\ 22 | This is different from MC samples.') 23 | parser.add_argument('--n-ais-dist', '-nad', type=int, default=10000, 24 | help='number of distributions for AIS evaluation (default: 10000)') 25 | parser.add_argument('--no-cuda', '-nc', action='store_true', help='force not use CUDA') 26 | 27 | # model configuration flags 28 | parser.add_argument('--z-size', '-zs', type=int, default=50, 29 | help='dimensionality of latent code (default: 50)') 30 | parser.add_argument('--batch-size', '-bs', type=int, default=100, 31 | help='batch size (default: 100)') 32 | parser.add_argument('--n-batch', '-nb', type=int, default=10, 33 | help='total number of batches (default: 10)') 34 | parser.add_argument('--eval-path', '-ep', type=str, default='model.pth', 35 | help='path to load evaluation ckpt (default: model.pth)') 36 | parser.add_argument('--dataset', '-d', type=str, default='mnist', choices=['mnist', 'fashion', 'cifar'], 37 | help='dataset to train and evaluate on (default: mnist)') 38 | parser.add_argument('--wide-encoder', '-we', action='store_true', 39 | help='use wider layer (more hidden units for FC, more channels for CIFAR)') 40 | parser.add_argument('--has-flow', '-hf', action='store_true', 41 | help='use flow for training and eval') 42 | parser.add_argument('--hamiltonian-flow', '-hamil-f', action='store_true') 43 | parser.add_argument('--n-flows', '-nf', type=int, default=2, help='number of flows') 44 | 45 | args = parser.parse_args() 46 | args.cuda = not args.no_cuda and torch.cuda.is_available() 47 | 48 | 49 | def get_default_hparams(): 50 | return HParams( 51 | z_size=args.z_size, 52 | act_func=F.elu, 53 | has_flow=args.has_flow, 54 | hamiltonian_flow=args.hamiltonian_flow, 55 | n_flows=args.n_flows, 56 | wide_encoder=args.wide_encoder, 57 | cuda=args.cuda, 58 | ) 59 | 60 | 61 | def bdmc(model, loader, forward_schedule=np.linspace(0., 1., 500), n_sample=100): 62 | """Bidirectional Monte Carlo. Integrate forward and backward AIS. 63 | The backward schedule is the reverse of the forward. 
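The forward chain gives a stochastic lower bound on log p(x); the backward chain, initialized from the exact posterior samples available for simulated data, gives a stochastic upper bound, so the two together sandwich the true log-likelihood.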
64 | 65 | Args: 66 | model (vae.VAE): VAE model 67 | loader (iterator): iterator to loop over pairs of Variables; the first 68 | entry being `x`, the second being `z` sampled from the true 69 | posterior `p(z|x)` 70 | forward_schedule (list or numpy.ndarray): forward temperature schedule; 71 | backward schedule is used as its reverse 72 | n_sample: number of importance (not simple MC) sample 73 | Returns: 74 | Two lists for forward and backward bounds on batchs of data 75 | """ 76 | 77 | # iterator is exhaustable in py3, so need duplicate 78 | load, load_ = itertools.tee(loader, 2) 79 | 80 | # forward chain 81 | forward_logws = ais_trajectory( 82 | model, load, 83 | mode='forward', schedule=forward_schedule, 84 | n_sample=n_sample 85 | ) 86 | 87 | # backward chain 88 | backward_schedule = np.flip(forward_schedule, axis=0) 89 | backward_logws = ais_trajectory( 90 | model, load_, 91 | mode='backward', 92 | schedule=backward_schedule, 93 | n_sample=n_sample 94 | ) 95 | 96 | upper_bounds = [] 97 | lower_bounds = [] 98 | 99 | for i, (forward, backward) in enumerate(zip(forward_logws, backward_logws)): 100 | lower_bounds.append(forward.mean()) 101 | upper_bounds.append(backward.mean()) 102 | 103 | upper_bounds = np.mean(upper_bounds) 104 | lower_bounds = np.mean(lower_bounds) 105 | 106 | print ('Average bounds on simulated data: lower %.4f, upper %.4f' % (lower_bounds, upper_bounds)) 107 | 108 | return forward_logws, backward_logws 109 | 110 | 111 | def main(): 112 | # sanity check 113 | model = get_model(args.dataset, get_default_hparams()) 114 | model.load_state_dict(torch.load(args.eval_path)['state_dict']) 115 | model.eval() 116 | 117 | loader = simulate_data(model, batch_size=args.batch_size, n_batch=args.n_batch) 118 | schedule = sigmoidial_schedule(args.n_ais_dist) 119 | bdmc(model, loader, forward_schedule=schedule, n_sample=args.n_ais_iwae) 120 | 121 | 122 | if __name__ == '__main__': 123 | main() 124 | -------------------------------------------------------------------------------- /cvae.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import time 6 | import sys 7 | import argparse 8 | 9 | import torch 10 | import torch.utils.data 11 | import torch.optim as optim 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn.utils import weight_norm 15 | from torch.autograd import Variable 16 | from torch.autograd import grad as torchgrad 17 | 18 | from utils.math_ops import log_normal, log_bernoulli, log_mean_exp 19 | from utils.approx_posts import Flow 20 | 21 | 22 | class CVAE(nn.Module): 23 | """Convolutional VAE for CIFAR.""" 24 | def __init__(self, hps): 25 | super(CVAE, self).__init__() 26 | 27 | self.z_size = hps.z_size 28 | self.has_flow = hps.has_flow 29 | self.hamiltonian_flow = hps.hamiltonian_flow 30 | self.n_flows = hps.n_flows 31 | self.use_cuda = hps.cuda 32 | self.act_func = hps.act_func 33 | 34 | self._init_layers(wide_encoder=hps.wide_encoder) 35 | 36 | if self.use_cuda: 37 | self.cuda() 38 | self.dtype = torch.cuda.FloatTensor 39 | else: 40 | self.dtype = torch.FloatTensor 41 | 42 | def _init_layers(self, wide_encoder=False): 43 | 44 | if wide_encoder: 45 | init_channel = 128 46 | else: 47 | init_channel = 64 48 | 49 | # encoder 50 | self.conv1 = nn.Conv2d(3, init_channel, 4, 2) 51 | self.conv2 = nn.Conv2d(init_channel, init_channel*2, 4, 2) 52 | self.conv3 = nn.Conv2d(init_channel*2, 
init_channel*4, 4, 2) 53 | self.fc_enc = nn.Linear(init_channel*4*2*2, self.z_size*2) 54 | 55 | self.bn_enc1 = nn.BatchNorm2d(init_channel) 56 | self.bn_enc2 = nn.BatchNorm2d(init_channel*2) 57 | self.bn_enc3 = nn.BatchNorm2d(init_channel*4) 58 | 59 | self.x_info_layer = nn.Linear(init_channel*4*2*2, self.z_size) 60 | 61 | # decoder 62 | self.fc_dec = nn.Linear(self.z_size, 256*2*2) 63 | self.deconv1 = nn.ConvTranspose2d(256, 128, 4, 2) 64 | self.deconv2 = nn.ConvTranspose2d(128, 64, 4, 2, output_padding=1) 65 | self.deconv3 = nn.ConvTranspose2d(64, 3, 4, 2) 66 | 67 | self.bn_dec1 = nn.BatchNorm2d(128) 68 | self.bn_dec2 = nn.BatchNorm2d(64) 69 | 70 | self.decoder_layers = [] 71 | self.decoder_layers.append(self.deconv1) 72 | self.decoder_layers.append(self.deconv2) 73 | self.decoder_layers.append(self.deconv3) 74 | self.decoder_layers.append(self.fc_dec) 75 | self.decoder_layers.append(self.bn_dec1) 76 | self.decoder_layers.append(self.bn_dec2) 77 | 78 | if self.has_flow: 79 | self.q_dist = Flow(self, n_flows=self.n_flows) 80 | if self.use_cuda: 81 | self.q_dist.cuda() 82 | 83 | def encode(self, net): 84 | 85 | net = self.act_func(self.bn_enc1(self.conv1(net))) 86 | net = self.act_func(self.bn_enc2(self.conv2(net))) 87 | net = self.act_func(self.bn_enc3(self.conv3(net))) 88 | net = net.view(net.size(0), -1) 89 | x_info = self.act_func(self.x_info_layer(net)) 90 | net = self.fc_enc(net) 91 | mean, logvar = net[:, :self.z_size], net[:, self.z_size:] 92 | 93 | return mean, logvar, x_info 94 | 95 | def decode(self, net): 96 | 97 | net = self.act_func(self.fc_dec(net)) 98 | net = net.view(net.size(0), -1, 2, 2) 99 | net = self.act_func(self.bn_dec1(self.deconv1(net))) 100 | net = self.act_func(self.bn_dec2(self.deconv2(net))) 101 | logit = self.deconv3(net) 102 | 103 | return logit 104 | 105 | def sample(self, mu, logvar, grad_fn=lambda x: 1, x_info=None): 106 | # grad_fn default is identity, i.e. 
don't use grad info 107 | eps = Variable(torch.randn(mu.size()).type(self.dtype)) 108 | z = eps.mul(logvar.mul(0.5).exp()).add(mu) 109 | logqz = log_normal(z, mu, logvar) 110 | 111 | if self.has_flow: 112 | z, logprob = self.q_dist.forward(z, grad_fn, x_info) 113 | logqz += logprob 114 | 115 | zeros = Variable(torch.zeros(z.size()).type(self.dtype)) 116 | logpz = log_normal(z, zeros, zeros) 117 | 118 | return z, logpz, logqz 119 | 120 | def forward(self, x, k=1, warmup_const=1.): 121 | 122 | x = x.repeat(k, 1, 1, 1) # for computing iwae bound 123 | mu, logvar, x_info = self.encode(x) 124 | 125 | # posterior-aware inference 126 | def U(z): 127 | logpx = log_bernoulli(self.decode(z), x) 128 | logpz = log_normal(z) 129 | return -logpx - logpz # energy as -log p(x, z) 130 | 131 | def grad_U(z): 132 | grad_outputs = torch.ones(z.size(0)).type(self.dtype) 133 | grad = torchgrad(U(z), z, grad_outputs=grad_outputs, create_graph=True)[0] 134 | # gradient clipping by norm avoid numerical issue 135 | norm = torch.sqrt(torch.norm(grad, p=2, dim=1)) 136 | grad = grad / norm.view(-1, 1) 137 | return grad.detach() 138 | 139 | if self.hamiltonian_flow: 140 | z, logpz, logqz = self.sample(mu, logvar, grad_fn=grad_U, x_info=x_info) 141 | else: 142 | z, logpz, logqz = self.sample(mu, logvar, x_info=x_info) 143 | 144 | logit = self.decode(z) 145 | logpx = log_bernoulli(logit, x) 146 | elbo = logpx + logpz - warmup_const * logqz # custom warmup 147 | # correction for Tensor.repeat 148 | elbo = log_mean_exp(elbo.view(k, -1).transpose(0, 1)) 149 | elbo = torch.mean(elbo) 150 | 151 | logpx = torch.mean(logpx) 152 | logpz = torch.mean(logpz) 153 | logqz = torch.mean(logqz) 154 | 155 | return elbo, logpx, logpz, logqz 156 | 157 | def reconstruct_img(self, x): 158 | 159 | # for visualization 160 | mu, logvar, x_info = self.encode(x) 161 | z, logpz, logqz = self.sample(mu, logvar) 162 | logit = self.decode(z) 163 | x_hat = torch.sigmoid(logit) 164 | 165 | return x_hat 166 | 167 | def freeze_decoder(self): 168 | # freeze so that decoder is not optimized 169 | for layer in self.decoder_layers: 170 | for param_name in layer._parameters: 171 | layer._parameters[param_name].requires_grad = False 172 | 173 | def unfreeze_decoder(self): 174 | # unfreeze so that decoder is optimized 175 | for layer in self.decoder_layers: 176 | for param_name in layer._parameters: 177 | layer._parameters[param_name].requires_grad = True 178 | 179 | -------------------------------------------------------------------------------- /datasets/fashion/t10k-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxuechen/inference-suboptimality/72242d23ab010df8ba59bf5d507c81ebbf464416/datasets/fashion/t10k-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /datasets/fashion/t10k-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxuechen/inference-suboptimality/72242d23ab010df8ba59bf5d507c81ebbf464416/datasets/fashion/t10k-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /datasets/fashion/train-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxuechen/inference-suboptimality/72242d23ab010df8ba59bf5d507c81ebbf464416/datasets/fashion/train-images-idx3-ubyte.gz 
-------------------------------------------------------------------------------- /datasets/fashion/train-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxuechen/inference-suboptimality/72242d23ab010df8ba59bf5d507c81ebbf464416/datasets/fashion/train-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /datasets/mnist.pkl.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxuechen/inference-suboptimality/72242d23ab010df8ba59bf5d507c81ebbf464416/datasets/mnist.pkl.tar.gz -------------------------------------------------------------------------------- /loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import io 3 | import sys 4 | import os 5 | import time 6 | 7 | import torch 8 | import torchvision.datasets as datasets 9 | import torchvision.transforms as transforms 10 | from torch.utils.data import DataLoader 11 | from torch.utils.data.sampler import SubsetRandomSampler 12 | from torch.autograd import Variable 13 | import collections 14 | import pickle 15 | from torch.autograd import Variable 16 | 17 | 18 | class CIFAR10: 19 | 20 | def __init__(self, 21 | part='train', 22 | batch_size=128, 23 | partial=None, 24 | binarize=True, 25 | valid_size=0.1, 26 | num_workers=4, 27 | pin_memory=False): 28 | 29 | transform_list = [transforms.ToTensor()] 30 | if binarize: 31 | transform_list.append(lambda x: x >= 0.5) 32 | transform_list.append(lambda x: x.float()) 33 | 34 | data_transform = transforms.Compose(transform_list) 35 | train_set = datasets.CIFAR10('./datasets', train=True, download=True, transform=data_transform) 36 | valid_set = datasets.CIFAR10('./datasets', train=True, download=True, transform=data_transform) 37 | test_set = datasets.CIFAR10('./datasets', train=False, download=True, transform=data_transform) 38 | 39 | num_train = len(train_set) 40 | indices = list(range(num_train)) 41 | split = int(np.floor(valid_size * num_train)) 42 | train_idx, valid_idx = indices[split:], indices[:split] 43 | 44 | self.loader = { 45 | 'train': DataLoader(train_set, 46 | batch_size=batch_size, sampler=SubsetRandomSampler(train_idx), 47 | num_workers=num_workers, pin_memory=pin_memory, shuffle=False), 48 | 'valid': DataLoader(valid_set, 49 | batch_size=batch_size, sampler=SubsetRandomSampler(valid_idx), 50 | num_workers=num_workers, pin_memory=pin_memory, shuffle=False), 51 | 'test': DataLoader(test_set, batch_size=batch_size, 52 | num_workers=num_workers, pin_memory=pin_memory, shuffle=False) 53 | }[part] 54 | 55 | self.size = len(self.loader) if partial is None else partial // batch_size 56 | self._iter = iter(self.loader) 57 | self.batch_size = batch_size 58 | self.p = 0 59 | 60 | def __iter__(self): 61 | self.p = 0 62 | self._iter = iter(self.loader) 63 | return self 64 | 65 | def __next__(self): 66 | self.p += 1 67 | if self.p > self.size: 68 | raise StopIteration 69 | return next(self._iter) 70 | 71 | # due to inconsistency between py2 and py3 72 | def next(self): 73 | return self.__next__() 74 | 75 | 76 | class Larochelle_MNIST: 77 | 78 | def __init__(self, part='train', batch_size=128, partial=1000): 79 | with open('datasets/mnist.pkl', 'rb') as f: 80 | if sys.version_info[0] < 3: 81 | mnist = pickle.load(f) 82 | else: 83 | mnist = pickle.load(f, encoding='latin1') 84 | self.data = { 85 | 'train': 
np.concatenate((mnist[0][0], mnist[1][0])), 86 | 'test': mnist[2][0], 87 | 'partial_train': mnist[0][0][:partial], 88 | 'partial_test': mnist[2][0][:partial], 89 | }[part] 90 | self.size = self.data.shape[0] 91 | self.batch_size = batch_size 92 | self._construct() 93 | 94 | def __iter__(self): 95 | return iter(self.batch_list) 96 | 97 | def _construct(self): 98 | self.batch_list = [] 99 | for i in range(self.size // self.batch_size): 100 | batch = self.data[self.batch_size*i:self.batch_size*(i+1)] 101 | batch = torch.from_numpy(batch) 102 | # placeholder for second entry 103 | self.batch_list.append((batch, None)) 104 | 105 | 106 | class Binarized_Omniglot: 107 | 108 | def __init__(self, part='train', batch_size=128, partial=1000): 109 | omni_raw = io.loadmat('datasets/chardata.mat') 110 | reshape_data = lambda d: d.reshape( 111 | (-1, 28, 28)).reshape((-1, 28*28), order='fortran') 112 | 113 | def static_binarize(d): 114 | ids = d < 0.5 115 | d[ids] = 0. 116 | d[~ids] = 1. 117 | 118 | train_data = reshape_data(omni_raw['data'].T.astype('float32')) 119 | test_data = reshape_data(omni_raw['testdata'].T.astype('float32')) 120 | static_binarize(train_data) 121 | static_binarize(test_data) 122 | 123 | assert train_data.shape == (24345, 784) 124 | assert test_data.shape == (8070, 784) 125 | 126 | self.data = { 127 | 'train': train_data, 128 | 'test': test_data, 129 | 'partial_train': train_data[:partial], 130 | 'partial_test': test_data[:partial], 131 | }[part] 132 | self.size = self.data.shape[0] 133 | self.batch_size = batch_size 134 | self._construct() 135 | 136 | def __iter__(self): 137 | return iter(self.batch_list) 138 | 139 | def _construct(self): 140 | self.batch_list = [] 141 | for i in range(self.size // self.batch_size): 142 | batch = self.data[self.batch_size*i:self.batch_size*(i+1)] 143 | batch = torch.from_numpy(batch) 144 | self.batch_list.append((batch, None)) 145 | 146 | 147 | class Binarized_Fashion: 148 | 149 | def __init__(self, part='train', batch_size=128, partial=1000): 150 | 151 | from utils.mnist_reader import load_mnist 152 | train_raw, _ = load_mnist('datasets/fashion', kind='train') 153 | test_raw, _ = load_mnist('datasets/fashion', kind='t10k') 154 | 155 | grey_scale = lambda x: np.float32(x / 255.) 156 | 157 | def static_binarize(d): 158 | ids = d < 0.5 159 | d[ids] = 0. 160 | d[~ids] = 1. 
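# pixels are rescaled to [0, 1] below, then thresholded at 0.5 for a fixed (static) binarization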
161 | 162 | train_data = grey_scale(train_raw) 163 | test_data = grey_scale(test_raw) 164 | 165 | static_binarize(train_data) 166 | static_binarize(test_data) 167 | 168 | assert train_data.shape == (60000, 784) 169 | assert test_data.shape == (10000, 784) 170 | 171 | self.data = { 172 | 'train': train_data[:55000], 173 | 'valid': train_data[55000:], 174 | 'test': test_data, 175 | 'partial_train': train_data[:partial], 176 | 'partial_test': test_data[:partial], 177 | }[part] 178 | self.size = self.data.shape[0] 179 | self.batch_size = batch_size 180 | self._construct() 181 | 182 | def __iter__(self): 183 | return iter(self.batch_list) 184 | 185 | def _construct(self): 186 | self.batch_list = [] 187 | for i in range(self.size // self.batch_size): 188 | batch = self.data[self.batch_size*i:self.batch_size*(i+1)] 189 | batch = torch.from_numpy(batch) 190 | self.batch_list.append((batch, None)) 191 | 192 | 193 | def get_default_mnist_loader(): 194 | 195 | kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {} 196 | train_loader = torch.utils.data.DataLoader( 197 | datasets.MNIST('./datasets', train=True, download=True, 198 | transform=transforms.ToTensor()), 199 | batch_size=128, shuffle=True, **kwargs) 200 | 201 | test_loader = torch.utils.data.DataLoader( 202 | datasets.MNIST('./datasets', train=False, 203 | transform=transforms.ToTensor()), 204 | batch_size=100, shuffle=True, **kwargs) 205 | 206 | return train_loader, test_loader 207 | 208 | 209 | def get_cifar10_loader(batch_size=100, partial=False, num=1000): 210 | 211 | if partial: 212 | train_loader = CIFAR10(part='train', batch_size=batch_size, partial=num) 213 | test_loader = CIFAR10(part='test', batch_size=4) 214 | else: 215 | train_loader = CIFAR10(part='train', batch_size=batch_size) 216 | test_loader = CIFAR10(part='valid', batch_size=4) # really validation set 217 | 218 | return train_loader, test_loader 219 | 220 | 221 | def get_Larochelle_MNIST_loader(batch_size=100, partial=False, num=1000): 222 | 223 | if partial: 224 | train_loader = Larochelle_MNIST(part='partial_train', batch_size=batch_size, partial=num) 225 | test_loader = Larochelle_MNIST(part='partial_test') 226 | else: 227 | train_loader = Larochelle_MNIST(part='train', batch_size=batch_size) 228 | test_loader = Larochelle_MNIST(part='test', batch_size=batch_size) 229 | 230 | return train_loader, test_loader 231 | 232 | 233 | def get_omniglot_loader(batch_size=100, partial=False, num=1000): 234 | 235 | if partial: 236 | train_loader = Binarized_Omniglot(part='partial_train', batch_size=batch_size, partial=num) 237 | test_loader = Binarized_Omniglot(part='partial_test') 238 | else: 239 | train_loader = Binarized_Omniglot(part='train', batch_size=batch_size) 240 | test_loader = Binarized_Omniglot(part='valid', batch_size=10) 241 | 242 | return train_loader, test_loader 243 | 244 | 245 | def get_fashion_loader(batch_size=100, partial=False, num=1000): 246 | 247 | if partial: 248 | train_loader = Binarized_Fashion(part='partial_train', batch_size=batch_size, partial=num) 249 | test_loader = Binarized_Fashion(part='partial_test', batch_size=10) 250 | else: 251 | train_loader = Binarized_Fashion(part='train', batch_size=batch_size) 252 | test_loader = Binarized_Fashion(part='valid', batch_size=10) 253 | 254 | return train_loader, test_loader 255 | 256 | 257 | if __name__ == '__main__': 258 | # sanity checking 259 | train_loader, test_loader = get_cifar10_loader() 260 | for i, (batch, _) in enumerate(train_loader): 261 | batch = Variable(batch) 
262 | print (i) 263 | 264 | for i, (batch, _) in enumerate(train_loader): 265 | batch = Variable(batch) 266 | print (i) 267 | -------------------------------------------------------------------------------- /local_ffg.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import time 6 | import sys 7 | from tqdm import tqdm 8 | import argparse 9 | import numpy as np 10 | 11 | import torch 12 | import torch.utils.data 13 | import torch.optim as optim 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch.autograd import Variable 17 | 18 | from utils.math_ops import log_bernoulli, log_normal, log_mean_exp, safe_repeat 19 | from utils.hparams import HParams 20 | from utils.helper import get_model, get_loaders 21 | 22 | 23 | parser = argparse.ArgumentParser(description='local_factorized_gaussian') 24 | # action configuration flags 25 | parser.add_argument('--no-cuda', '-nc', action='store_true') 26 | parser.add_argument('--debug', action='store_true', help='debug mode') 27 | 28 | # model configuration flags 29 | parser.add_argument('--z-size', '-zs', type=int, default=50) 30 | parser.add_argument('--batch-size', '-bs', type=int, default=100) 31 | parser.add_argument('--eval-path', '-ep', type=str, default='model.pth', 32 | help='path to load evaluation ckpt (default: model.pth)') 33 | parser.add_argument('--dataset', '-d', type=str, default='mnist', 34 | choices=['mnist', 'fashion', 'cifar'], 35 | help='dataset to train and evaluate on (default: mnist)') 36 | parser.add_argument('--has-flow', '-hf', action='store_true', help='inference uses FLOW') 37 | parser.add_argument('--n-flows', '-nf', type=int, default=2, help='number of flows') 38 | parser.add_argument('--wide-encoder', '-we', action='store_true', 39 | help='use wider layer (more hidden units for FC, more channels for CIFAR)') 40 | 41 | args = parser.parse_args() 42 | args.cuda = not args.no_cuda and torch.cuda.is_available() 43 | 44 | 45 | def get_default_hparams(): 46 | return HParams( 47 | z_size=args.z_size, 48 | act_func=F.elu, 49 | has_flow=args.has_flow, 50 | n_flows=args.n_flows, 51 | wide_encoder=args.wide_encoder, 52 | cuda=args.cuda, 53 | hamiltonian_flow=False 54 | ) 55 | 56 | 57 | def optimize_local_gaussian( 58 | log_likelihood, 59 | model, 60 | data_var, 61 | k=100, 62 | check_every=100, 63 | sentinel_thres=10, 64 | debug=False 65 | ): 66 | """data_var should be (cuda) variable.""" 67 | 68 | B = data_var.size()[0] 69 | z_size = model.z_size 70 | 71 | data_var = safe_repeat(data_var, k) 72 | zeros = Variable(torch.zeros(B*k, z_size).type(model.dtype)) 73 | mean = Variable(torch.zeros(B*k, z_size).type(model.dtype), requires_grad=True) 74 | logvar = Variable(torch.zeros(B*k, z_size).type(model.dtype), requires_grad=True) 75 | 76 | optimizer = optim.Adam([mean, logvar], lr=1e-3) 77 | best_avg, sentinel, prev_seq = 999999, 0, [] 78 | 79 | # perform local opt 80 | time_ = time.time() 81 | for epoch in range(1, 999999): 82 | 83 | eps = Variable(torch.FloatTensor(mean.size()).normal_().type(model.dtype)) 84 | z = eps.mul(logvar.mul(0.5).exp_()).add_(mean) 85 | x_logits = model.decode(z) 86 | 87 | logpz = log_normal(z, zeros, zeros) 88 | logqz = log_normal(z, mean, logvar) 89 | logpx = log_likelihood(x_logits, data_var) 90 | 91 | optimizer.zero_grad() 92 | loss = -torch.mean(logpx + logpz - logqz) 93 | loss_np = loss.data.cpu().numpy() 94 | 
loss.backward() 95 | optimizer.step() 96 | 97 | prev_seq.append(loss_np) 98 | if epoch % check_every == 0: 99 | last_avg = np.mean(prev_seq) 100 | if debug: # debugging helper 101 | sys.stderr.write( 102 | 'Epoch %d, time elapse %.4f, last avg %.4f, prev best %.4f\n' % \ 103 | (epoch, time.time()-time_, -last_avg, -best_avg) 104 | ) 105 | if last_avg < best_avg: 106 | sentinel, best_avg = 0, last_avg 107 | else: 108 | sentinel += 1 109 | if sentinel > sentinel_thres: 110 | break 111 | prev_seq = [] 112 | time_ = time.time() 113 | 114 | # evaluation 115 | eps = Variable(torch.FloatTensor(B*k, z_size).normal_().type(model.dtype)) 116 | z = eps.mul(logvar.mul(0.5).exp_()).add_(mean) 117 | 118 | logpz = log_normal(z, zeros, zeros) 119 | logqz = log_normal(z, mean, logvar) 120 | logpx = log_likelihood(model.decode(z), data_var) 121 | elbo = logpx + logpz - logqz 122 | 123 | vae_elbo = torch.mean(elbo) 124 | iwae_elbo = torch.mean(log_mean_exp(elbo.view(k, -1).transpose(0, 1))) 125 | 126 | return vae_elbo.data[0], iwae_elbo.data[0] 127 | 128 | 129 | def main(): 130 | train_loader, test_loader = get_loaders( 131 | dataset=args.dataset, 132 | evaluate=True, batch_size=args.batch_size 133 | ) 134 | model = get_model(args.dataset, get_default_hparams()) 135 | model.load_state_dict(torch.load(args.eval_path)['state_dict']) 136 | model.eval() 137 | 138 | vae_record, iwae_record = [], [] 139 | time_ = time.time() 140 | for i, (batch, _) in tqdm(enumerate(train_loader)): 141 | batch = Variable(batch.type(model.dtype)) 142 | elbo, iwae = optimize_local_gaussian(log_bernoulli, model, batch, debug=args.debug) 143 | vae_record.append(elbo) 144 | iwae_record.append(iwae) 145 | print ('Local opt w/ ffg, batch %d, time elapse %.4f, ELBO %.4f, IWAE %.4f' % \ 146 | (i+1, time.time()-time_, elbo, iwae)) 147 | print ('mean of ELBO so far %.4f, mean of IWAE so far %.4f' % \ 148 | (np.nanmean(vae_record), np.nanmean(iwae_record))) 149 | time_ = time.time() 150 | 151 | print ('Finishing...') 152 | print ('Average ELBO %.4f, IWAE %.4f' % (np.nanmean(vae_record), np.nanmean(iwae_record))) 153 | 154 | 155 | if __name__ == '__main__': 156 | main() 157 | -------------------------------------------------------------------------------- /local_flow.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import time 6 | import sys 7 | from tqdm import tqdm 8 | import argparse 9 | import numpy as np 10 | 11 | import torch 12 | import torch.utils.data 13 | import torch.optim as optim 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch.autograd import Variable 17 | 18 | from utils.math_ops import log_bernoulli, log_normal, log_mean_exp, safe_repeat 19 | from utils.hparams import HParams 20 | 21 | from loader import get_Larochelle_MNIST_loader, get_fashion_loader, get_cifar10_loader 22 | from vae import VAE 23 | from cvae import CVAE 24 | 25 | 26 | parser = argparse.ArgumentParser(description='local_expressive') 27 | # action configuration flags 28 | parser.add_argument('--no-cuda', '-nc', action='store_true') 29 | parser.add_argument('--debug', action='store_true', help='debug mode') 30 | 31 | # model configuration flags 32 | parser.add_argument('--z-size', '-zs', type=int, default=50) 33 | parser.add_argument('--batch-size', '-bs', type=int, default=100) 34 | parser.add_argument('--eval-path', '-ep', type=str, default='model.pth', 35 | 
help='path to load evaluation ckpt (default: model.pth)') 36 | parser.add_argument('--dataset', '-d', type=str, default='mnist', 37 | choices=['mnist', 'fashion', 'cifar'], 38 | help='dataset to train and evaluate on (default: mnist)') 39 | parser.add_argument('--has-flow', '-hf', action='store_true', help='inference uses FLOW') 40 | parser.add_argument('--n-flows', '-nf', type=int, default=2, help='number of flows') 41 | parser.add_argument('--wide-encoder', '-we', action='store_true', 42 | help='use wider layer (more hidden units for FC, more channels for CIFAR)') 43 | 44 | args = parser.parse_args() 45 | args.cuda = not args.no_cuda and torch.cuda.is_available() 46 | 47 | 48 | def get_default_hparams(): 49 | return HParams( 50 | z_size=args.z_size, 51 | act_func=F.elu, 52 | has_flow=args.has_flow, 53 | n_flows=args.n_flows, 54 | wide_encoder=args.wide_encoder, 55 | cuda=args.cuda, 56 | hamiltonian_flow=False 57 | ) 58 | 59 | 60 | def optimize_local_expressive( 61 | log_likelihood, 62 | model, 63 | data_var, 64 | k=100, 65 | check_every=100, 66 | sentinel_thres=10, 67 | n_flows=2, 68 | debug=False 69 | ): 70 | """data_var should be (cuda) variable.""" 71 | 72 | def log_joint(x_logits, x, z): 73 | """log p(x,z)""" 74 | zeros = Variable(torch.zeros(z.size()).type(model.dtype)) 75 | logpz = log_normal(z, zeros, zeros) 76 | logpx = log_likelihood(x_logits, x) 77 | 78 | return logpx + logpz 79 | 80 | def norm_flow(params, z, v): 81 | 82 | h = F.tanh(params[0][0](z)) 83 | mew_ = params[0][1](h) 84 | logit_ = params[0][2](h) 85 | sig_ = F.sigmoid(logit_) 86 | 87 | v = v*sig_ + mew_ 88 | # numerically stable: log (sigmoid(logit)) = logit - softplus(logit) 89 | logdet_v = torch.sum(logit_ - F.softplus(logit_), 1) 90 | 91 | h = F.tanh(params[1][0](v)) 92 | mew_ = params[1][1](h) 93 | logit_ = params[1][2](h) 94 | sig_ = F.sigmoid(logit_) 95 | 96 | z = z*sig_ + mew_ 97 | logdet_z = torch.sum(logit_ - F.softplus(logit_), 1) 98 | 99 | logdet = logdet_v + logdet_z 100 | 101 | return z, v, logdet 102 | 103 | def sample(mean_v0, logvar_v0): 104 | 105 | B = mean_v0.size()[0] 106 | eps = Variable(torch.FloatTensor(B, z_size).normal_().type(model.dtype)) 107 | v0 = eps.mul(logvar_v0.mul(0.5).exp_()) + mean_v0 108 | logqv0 = log_normal(v0, mean_v0, logvar_v0) 109 | 110 | out = v0 111 | for i in range(len(qz_weights)-1): 112 | out = act_func(qz_weights[i](out)) 113 | out = qz_weights[-1](out) 114 | mean_z0, logvar_z0 = out[:, :z_size], out[:, z_size:] 115 | 116 | eps = Variable(torch.FloatTensor(B, z_size).normal_().type(model.dtype)) 117 | z0 = eps.mul(logvar_z0.mul(0.5).exp_()) + mean_z0 118 | logqz0 = log_normal(z0, mean_z0, logvar_z0) 119 | 120 | zT, vT = z0, v0 121 | logdetsum = 0. 
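# pass (z, v) through the n_flows RNVP-style steps below, accumulating each step's log-determinant for the density correction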
122 | for i in range(n_flows): 123 | zT, vT, logdet = norm_flow(params[i], zT, vT) 124 | logdetsum += logdet 125 | 126 | # reverse model, r(vT|x,zT) 127 | out = zT 128 | for i in range(len(rv_weights)-1): 129 | out = act_func(rv_weights[i](out)) 130 | out = rv_weights[-1](out) 131 | mean_vT, logvar_vT = out[:, :z_size], out[:, z_size:] 132 | logrvT = log_normal(vT, mean_vT, logvar_vT) 133 | 134 | logq = logqz0 + logqv0 - logdetsum - logrvT 135 | 136 | return zT, logq 137 | 138 | def get_params(): 139 | 140 | all_params = [] 141 | 142 | mean_v = Variable(torch.zeros(B*k, z_size).type(model.dtype), requires_grad=True) 143 | logvar_v = Variable(torch.zeros(B*k, z_size).type(model.dtype), requires_grad=True) 144 | 145 | all_params.append(mean_v) 146 | all_params.append(logvar_v) 147 | 148 | qz_weights = [] # q(z|x,v) 149 | for ins, outs in zip(qz_arch[:-1], qz_arch[1:]): 150 | cur_layer = nn.Linear(ins, outs) 151 | if args.cuda: 152 | cur_layer.cuda() 153 | qz_weights.append(cur_layer) 154 | all_params.append(cur_layer.weight) 155 | 156 | rv_weights = [] # r(v|x,z) 157 | for ins, outs in zip(rv_arch[:-1], rv_arch[1:]): 158 | cur_layer = nn.Linear(ins, outs) 159 | if args.cuda: 160 | cur_layer.cuda() 161 | rv_weights.append(cur_layer) 162 | all_params.append(cur_layer.weight) 163 | 164 | params = [] 165 | for i in range(n_flows): 166 | layers = [ 167 | [nn.Linear(z_size, h_s), 168 | nn.Linear(h_s, z_size), 169 | nn.Linear(h_s, z_size)], 170 | [nn.Linear(z_size, h_s), 171 | nn.Linear(h_s, z_size), 172 | nn.Linear(h_s, z_size)], 173 | ] 174 | 175 | params.append(layers) 176 | 177 | for sublist in layers: 178 | for item in sublist: 179 | all_params.append(item.weight) 180 | if args.cuda: 181 | item.cuda() 182 | 183 | return (mean_v, logvar_v), all_params, params, qz_weights, rv_weights 184 | 185 | # the real shit 186 | B = data_var.size(0) 187 | z_size = args.z_size 188 | qz_arch = rv_arch = [args.z_size, 200, 200, args.z_size*2] 189 | h_s = 200 190 | act_func = F.elu 191 | 192 | data_var = safe_repeat(data_var, k) 193 | (mean_v, logvar_v), all_params, params, qz_weights, rv_weights = get_params() 194 | 195 | # tile input for IS 196 | optimizer = optim.Adam(all_params, lr=1e-3) 197 | best_avg, sentinel, prev_seq = 999999, 0, [] 198 | 199 | # perform local opt 200 | time_ = time.time() 201 | for epoch in range(1, 999999): 202 | z, logqz = sample(mean_v, logvar_v) 203 | x_logits = model.decode(z) 204 | logpxz = log_joint(x_logits, data_var, z) 205 | 206 | optimizer.zero_grad() 207 | loss = -torch.mean(logpxz - logqz) 208 | loss_np = loss.data.cpu().numpy() 209 | loss.backward() 210 | optimizer.step() 211 | 212 | prev_seq.append(loss_np) 213 | if epoch % check_every == 0: 214 | last_avg = np.mean(prev_seq) 215 | if debug: # debugging helper 216 | sys.stderr.write( 217 | 'Epoch %d, time elapse %.4f, last avg %.4f, prev best %.4f\n' % \ 218 | (epoch, time.time()-time_, -last_avg, -best_avg) 219 | ) 220 | if last_avg < best_avg: 221 | sentinel, best_avg = 0, last_avg 222 | else: 223 | sentinel += 1 224 | if sentinel > sentinel_thres: 225 | break 226 | 227 | prev_seq = [] 228 | time_ = time.time() 229 | 230 | # evaluation 231 | z, logqz = sample(mean_v, logvar_v) 232 | x_logits = model.decode(z) 233 | logpxz = log_joint(x_logits, data_var, z) 234 | elbo = logpxz - logqz 235 | 236 | vae_elbo = torch.mean(elbo) 237 | iwae_elbo = torch.mean(log_mean_exp(elbo.view(k, -1).transpose(0, 1))) 238 | 239 | return vae_elbo.data[0], iwae_elbo.data[0] 240 | 241 | 242 | def main(): 243 | train_loader, test_loader = 
get_loaders( 244 | dataset=args.dataset, 245 | evaluate=True, batch_size=1 246 | ) 247 | model = get_model(args.dataset, get_default_hparams()) 248 | model.load_state_dict(torch.load(args.eval_path)['state_dict']) 249 | model.eval() 250 | 251 | vae_record, iwae_record = [], [] 252 | time_ = time.time() 253 | for i, (batch, _) in tqdm(enumerate(train_loader)): 254 | batch = Variable(batch.type(model.dtype)) 255 | elbo, iwae = optimize_local_expressive( 256 | log_bernoulli, 257 | model, 258 | batch, 259 | n_flows=args.n_flows, debug=args.debug 260 | ) 261 | vae_record.append(elbo) 262 | iwae_record.append(iwae) 263 | print ('Local opt w/ flow, batch %d, time elapse %.4f, ELBO %.4f, IWAE %.4f' % \ 264 | (i+1, time.time()-time_, elbo, iwae)) 265 | print ('mean of ELBO so far %.4f, mean of IWAE so far %.4f' % \ 266 | (np.nanmean(vae_record), np.nanmean(iwae_record))) 267 | time_ = time.time() 268 | 269 | print ('Finishing...') 270 | print ('Average ELBO %.4f, IWAE %.4f' % (np.nanmean(vae_record), np.nanmean(iwae_record))) 271 | 272 | 273 | if __name__ == '__main__': 274 | main() 275 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import time 6 | import sys 7 | import os 8 | import argparse 9 | from tqdm import tqdm 10 | import numpy as np 11 | np.set_printoptions(threshold=sys.maxsize) 12 | import matplotlib.pyplot as plt 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | from torch.autograd import Variable 18 | from torchvision import datasets, transforms 19 | import visdom 20 | 21 | from utils.hparams import HParams 22 | from utils.ais import ais_trajectory 23 | from utils.math_ops import sigmoidial_schedule, linear_schedule 24 | from utils.helper import get_model, get_loaders 25 | 26 | 27 | parser = argparse.ArgumentParser(description='VAE') 28 | # action configuration flags 29 | parser.add_argument('--train', '-t', action='store_true') 30 | parser.add_argument('--load-path', '-lp', type=str, default='NA', 31 | help='path to load checkpoint to retrain') 32 | parser.add_argument('--load-epoch', '-le', type=int, default=0, 33 | help='epoch number to start recording when retraining') 34 | parser.add_argument('--display-epoch', '-de', type=int, default=10, 35 | help='print status every so many epochs (default: 10)') 36 | parser.add_argument('--eval-iwae', '-ei', action='store_true') 37 | parser.add_argument('--eval-ais', '-ea', action='store_true') 38 | parser.add_argument('--n-iwae', '-ni', type=int, default=5000, 39 | help='number of samples for IWAE evaluation (default: 5000)') 40 | parser.add_argument('--n-ais-iwae', '-nai', type=int, default=100, 41 | help='number of IMPORTANCE samples for AIS evaluation (default: 100). 
\ 42 | This is different from MC samples.') 43 | parser.add_argument('--n-ais-dist', '-nad', type=int, default=10000, 44 | help='number of distributions for AIS evaluation (default: 10000)') 45 | parser.add_argument('--ais-schedule', type=str, default='linear', help='schedule for AIS') 46 | 47 | parser.add_argument('--no-cuda', '-nc', action='store_true', help='force not use CUDA') 48 | parser.add_argument('--visdom', '-v', action='store_true', help='visualize samples') 49 | parser.add_argument('--port', '-p', type=int, default=8097, help='port for visdom') 50 | parser.add_argument('--save-visdom', default='test', help='visdom save path') 51 | parser.add_argument('--encoder-more', action='store_true', help='train the encoder more (5 vs 1)') 52 | parser.add_argument('--early-stopping', '-es', action='store_true', help='apply early stopping') 53 | parser.add_argument('--epochs', '-e', type=int, default=3280, 54 | help='total num of epochs for training (default: 3280)') 55 | parser.add_argument('--lr-schedule', '-lrs', action='store_true', 56 | help='apply learning rate schedule') 57 | 58 | # model configuration flags 59 | parser.add_argument('--z-size', '-zs', type=int, default=50, 60 | help='dimensionality of latent code (default: 50)') 61 | parser.add_argument('--batch-size', '-bs', type=int, default=100, 62 | help='batch size (default: 100)') 63 | parser.add_argument('--save-name', '-sn', type=str, default='model.pth', 64 | help='name to save trained ckpt (default: model.pth)') 65 | parser.add_argument('--eval-path', '-ep', type=str, default='model.pth', 66 | help='path to load evaluation ckpt (default: model.pth)') 67 | parser.add_argument('--dataset', '-d', type=str, default='mnist', 68 | choices=['mnist', 'fashion', 'cifar'], 69 | help='dataset to train and evaluate on (default: mnist)') 70 | parser.add_argument('--wide-encoder', '-we', action='store_true', 71 | help='use wider layer (more hidden units for FC, more channels for CIFAR)') 72 | parser.add_argument('--has-flow', '-hf', action='store_true', 73 | help='use flow for training and eval') 74 | parser.add_argument('--hamiltonian-flow', '-hamil-f', action='store_true') 75 | parser.add_argument('--n-flows', '-nf', type=int, default=2, help='number of flows') 76 | parser.add_argument('--warmup', '-w', action='store_true', 77 | help='apply warmup during training') 78 | 79 | args = parser.parse_args() 80 | args.cuda = not args.no_cuda and torch.cuda.is_available() 81 | 82 | 83 | def get_default_hparams(): 84 | return HParams( 85 | z_size=args.z_size, 86 | act_func=F.elu, 87 | has_flow=args.has_flow, 88 | hamiltonian_flow=args.hamiltonian_flow, 89 | n_flows=args.n_flows, 90 | wide_encoder=args.wide_encoder, 91 | cuda=args.cuda, 92 | ) 93 | 94 | 95 | def train( 96 | model, 97 | train_loader, 98 | test_loader, 99 | k_train=1, # num iwae sample for training 100 | k_eval=1, # num iwae sample for eval 101 | epochs=3280, 102 | display_epoch=10, 103 | lr_schedule=True, 104 | warmup=True, 105 | warmup_thres=None, 106 | encoder_more=False, 107 | checkpoints=None, 108 | early_stopping=False, 109 | save=True, 110 | save_path='checkpoints/mnist/', 111 | patience=10 # for early-stopping 112 | ): 113 | print('Training') 114 | 115 | if args.load_path != 'NA': 116 | f = args.load_path 117 | model.load_state_dict(torch.load(f)['state_dict']) 118 | 119 | # default warmup schedule 120 | if warmup_thres is None: 121 | if 'cifar' in save_path: 122 | warmup_thres = 50. 123 | elif 'mnist' in save_path or 'fashion' in save_path: 124 | warmup_thres = 400. 
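# when warmup is enabled, warmup_const (computed per epoch below) scales the log q(z|x) term, ramping linearly from ~0 to 1 over warmup_thres epochs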
125 | 126 | if checkpoints is None: # save a checkpoint every display_epoch 127 | checkpoints = [1] + list(range(0, 3280, display_epoch))[1:] + [3280] 128 | 129 | time_ = time.time() 130 | 131 | if lr_schedule: 132 | current_lr = 1e-3 133 | pow = 0 134 | epoch_elapsed = 0 135 | # pth default: beta_1 = .9, beta_2 = .999, eps = 1e-8 136 | optimizer = optim.Adam(model.parameters(), lr=current_lr, eps=1e-4) 137 | else: 138 | optimizer = optim.Adam(model.parameters(), lr=1e-4, eps=1e-4) 139 | 140 | num_worse = 0 # compare against `patience` for early-stopping 141 | prev_valid_err = None 142 | 143 | for epoch in tqdm(range(1, epochs+1)): 144 | warmup_const = min(1., epoch / warmup_thres) if warmup else 1. 145 | # lr schedule from IWAE: https://arxiv.org/pdf/1509.00519.pdf 146 | if lr_schedule: 147 | if epoch_elapsed >= 3 ** pow: 148 | current_lr *= 10. ** (-1. / 7.) 149 | pow += 1 150 | epoch_elapsed = 0 151 | # correct way to do lr decay; also possible w/ `torch.optim.lr_scheduler` 152 | for param_group in optimizer.param_groups: 153 | param_group['lr'] = current_lr 154 | epoch_elapsed += 1 155 | 156 | model.train() # crucial for BN to work properly 157 | for _, (batch, _) in enumerate(train_loader): 158 | batch = Variable(batch) 159 | if args.cuda: 160 | batch = batch.cuda() 161 | 162 | # train the encoder more 163 | if encoder_more: 164 | model.freeze_decoder() 165 | for _ in range(10): 166 | optimizer.zero_grad() 167 | elbo, _, _, _ = model.forward(batch, k_train, warmup_const) 168 | loss = -elbo 169 | loss.backward() 170 | optimizer.step() 171 | model.unfreeze_decoder() 172 | 173 | optimizer.zero_grad() 174 | elbo, _, _, _ = model.forward(batch, k_train, warmup_const) 175 | loss = -elbo 176 | loss.backward() 177 | optimizer.step() 178 | 179 | if epoch % display_epoch == 0: 180 | model.eval() # crucial for BN to work properly 181 | 182 | train_logpx, test_logpx = [], [] 183 | train_logpz, test_logpz = [], [] 184 | train_logqz, test_logqz = [], [] 185 | train_stats, test_stats = [], [] 186 | for _, (batch, _) in enumerate(train_loader): 187 | batch = Variable(batch) 188 | if args.cuda: 189 | batch = batch.cuda() 190 | elbo, logpx, logpz, logqz = model(batch, k=1) 191 | train_stats.append(elbo.data[0]) 192 | train_logpx.append(logpx.data[0]) 193 | train_logpz.append(logpz.data[0]) 194 | train_logqz.append(logqz.data[0]) 195 | 196 | for _, (batch, _) in enumerate(test_loader): 197 | batch = Variable(batch) 198 | if args.cuda: 199 | batch = batch.cuda() 200 | # early stopping with iwae bound 201 | elbo, logpx, logpz, logqz = model(batch, k=k_eval) 202 | test_stats.append(elbo.data[0]) 203 | test_logpx.append(logpx.data[0]) 204 | test_logpz.append(logpz.data[0]) 205 | test_logqz.append(logqz.data[0]) 206 | print ( 207 | 'Train Epoch: [{}/{}]'.format(epoch, epochs), 208 | 'Train set ELBO {:.4f}'.format(np.mean(train_stats)), 209 | 'Test/Validation set IWAE {:.4f}'.format(np.mean(test_stats)), 210 | 'Time: {:.2f}'.format(time.time()-time_), 211 | ) 212 | time_ = time.time() 213 | 214 | if early_stopping: 215 | curr_valid_err = np.mean(test_stats) 216 | 217 | if prev_valid_err is None: # don't have history yet 218 | prev_valid_err = curr_valid_err 219 | elif curr_valid_err >= prev_valid_err: # performance improved 220 | prev_valid_err = curr_valid_err 221 | num_worse = 0 222 | else: 223 | num_worse += 1 224 | 225 | if num_worse >= patience: 226 | break 227 | 228 | if save and epoch in checkpoints: 229 | torch.save({ 230 | 'epoch': epochs + args.load_epoch, 231 | 'state_dict': model.state_dict(), 232 
| }, '%s%d_%s' % (save_path, epoch + args.load_epoch, args.save_name)) 233 | 234 | 235 | def test_iwae( 236 | model, 237 | loader, 238 | k=5000, 239 | f='model.pth', 240 | print_res=True 241 | ): 242 | print('Testing with %d importance samples' % k) 243 | model.load_state_dict(torch.load(f)['state_dict']) 244 | model.eval() 245 | time_ = time.time() 246 | elbos = [] 247 | for i, (batch, _) in enumerate(loader): 248 | batch = Variable(batch) 249 | if args.cuda: 250 | batch = batch.cuda() 251 | elbo, logpx, logpz, logqz = model(batch, k=k) 252 | elbos.append(elbo.data[0]) 253 | 254 | mean_ = np.mean(elbos) 255 | if print_res: 256 | print(mean_, 'T:', time.time()-time_) 257 | return mean_ 258 | 259 | 260 | def run(): 261 | train_loader, test_loader = get_loaders( 262 | dataset=args.dataset, 263 | evaluate=args.eval_iwae or args.eval_ais, 264 | batch_size=args.batch_size 265 | ) 266 | model = get_model(args.dataset, get_default_hparams()) 267 | 268 | if args.train: 269 | save_path = 'checkpoints/%s/%s/%s%s/' % ( 270 | args.dataset, 271 | 'warmup' if args.warmup else 'no_warmup', 272 | 'wide_' if args.wide_encoder else '', 273 | 'hamiltonian_flow' if args.hamiltonian_flow else 274 | 'flow' if args.has_flow else 'ffg' 275 | ) 276 | if not os.path.exists(save_path): 277 | os.makedirs(save_path) 278 | 279 | train( 280 | model, train_loader, test_loader, 281 | display_epoch=args.display_epoch, epochs=args.epochs, 282 | lr_schedule=args.lr_schedule, 283 | warmup=args.warmup, 284 | early_stopping=args.early_stopping, 285 | encoder_more=args.encoder_more, 286 | save=True, save_path=save_path 287 | ) 288 | 289 | if args.visdom: 290 | vis = visdom.Visdom(env=args.save_visdom, port=args.port) 291 | model.load_state_dict(torch.load(args.eval_path)['state_dict']) 292 | 293 | # plot original images 294 | batch, _ = train_loader.next() 295 | images = list(batch.numpy()) 296 | win_samples = vis.images(images, 10, 2, opts={'caption': 'original images'}, win=None) 297 | 298 | # plot reconstructions 299 | batch = Variable(batch.type(model.dtype)) 300 | reconstruction = model.reconstruct_img(batch) 301 | images = list(reconstruction.data.cpu().numpy()) 302 | win_samples = vis.images(images, 10, 2, opts={'caption': 'reconstruction'}, win=None) 303 | 304 | if args.eval_iwae: 305 | # VAE bounds computed w/ 100 MC samples to reduce variance 306 | train_res, test_res = [], [] 307 | for _ in range(100): 308 | train_elbo = test_iwae(model, train_loader, k=1, f=args.eval_path) 309 | test_elbo = test_iwae(model, test_loader, k=1, f=args.eval_path) 310 | train_res.append(train_elbo) 311 | test_res.append(test_elbo) 312 | 313 | print ('Training set VAE ELBO w/ 100 MC samples: %.4f' % np.mean(train_res)) 314 | print ('Test set VAE ELBO w/ 100 MC samples: %.4f' % np.mean(test_res)) 315 | 316 | # IWAE bounds 317 | test_iwae(model, train_loader, k=args.n_iwae, f=args.eval_path) 318 | test_iwae(model, test_loader, k=args.n_iwae, f=args.eval_path) 319 | 320 | if args.eval_ais: 321 | model.load_state_dict(torch.load(args.eval_path)['state_dict']) 322 | schedule_fn = linear_schedule if args.ais_schedule == 'linear' else sigmoidial_schedule 323 | schedule = schedule_fn(args.n_ais_dist) 324 | ais_trajectory( 325 | model, train_loader, 326 | mode='forward', schedule=schedule, n_sample=args.n_ais_iwae 327 | ) 328 | 329 | 330 | if __name__ == '__main__': 331 | run() 332 | -------------------------------------------------------------------------------- /utils/ais.py: -------------------------------------------------------------------------------- 1 | from __future__
import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import sys 6 | import os 7 | import math 8 | import torch 9 | import numpy as np 10 | import time 11 | from tqdm import tqdm 12 | 13 | from torch.autograd import Variable 14 | from torch.autograd import grad as torchgrad 15 | 16 | from utils.math_ops import log_normal, log_bernoulli, log_mean_exp, safe_repeat 17 | from utils.hmc import hmc_trajectory, accept_reject 18 | 19 | 20 | def ais_trajectory( 21 | model, 22 | loader, 23 | mode='forward', 24 | schedule=np.linspace(0., 1., 500), 25 | n_sample=100 26 | ): 27 | """Compute annealed importance sampling trajectories for a batch of data. 28 | Could be used for *both* forward and reverse chain in bidirectional Monte Carlo 29 | (default: forward chain with linear schedule). 30 | 31 | Args: 32 | model (vae.VAE): VAE model 33 | loader (iterator): iterator that returns pairs, with first component being `x`, 34 | second would be `z` or label (will not be used) 35 | mode (string): indicate forward/backward chain; must be either `forward` or 36 | 'backward' schedule (list or 1D np.ndarray): temperature schedule, 37 | i.e. `p(z)p(x|z)^t`; foward chain has increasing values, whereas 38 | backward has decreasing values 39 | n_sample (int): number of importance samples (i.e. number of parallel chains 40 | for each datapoint) 41 | 42 | Returns: 43 | A list where each element is a torch.autograd.Variable that contains the 44 | log importance weights for a single batch of data 45 | """ 46 | 47 | assert mode == 'forward' or mode == 'backward', 'Should have forward/backward mode' 48 | 49 | def log_f_i(z, data, t, log_likelihood_fn=log_bernoulli): 50 | """Unnormalized density for intermediate distribution `f_i`: 51 | f_i = p(z)^(1-t) p(x,z)^(t) = p(z) p(x|z)^t 52 | => log f_i = log p(z) + t * log p(x|z) 53 | """ 54 | zeros = Variable(torch.zeros(B, z_size).type(mdtype)) 55 | log_prior = log_normal(z, zeros, zeros) 56 | log_likelihood = log_likelihood_fn(model.decode(z), data) 57 | 58 | return log_prior + log_likelihood.mul_(t) 59 | 60 | model.eval() 61 | 62 | # shorter aliases 63 | z_size = model.z_size 64 | mdtype = model.dtype 65 | 66 | _time = time.time() 67 | logws = [] # for output 68 | 69 | print ('In %s mode' % mode) 70 | 71 | for i, (batch, post_z) in enumerate(loader): 72 | 73 | B = batch.size(0) * n_sample 74 | batch = Variable(batch.type(mdtype)) 75 | batch = safe_repeat(batch, n_sample) 76 | 77 | # batch of step sizes, one for each chain 78 | epsilon = Variable(torch.ones(B).type(model.dtype)).mul_(0.01) 79 | # accept/reject history for tuning step size 80 | accept_hist = Variable(torch.zeros(B).type(model.dtype)) 81 | # record log importance weight; volatile=True reduces memory greatly 82 | logw = Variable(torch.zeros(B).type(mdtype), volatile=True) 83 | 84 | # initial sample of z 85 | if mode == 'forward': 86 | current_z = Variable(torch.randn(B, z_size).type(mdtype), requires_grad=True) 87 | else: 88 | current_z = Variable(safe_repeat(post_z, n_sample).type(mdtype), requires_grad=True) 89 | 90 | for j, (t0, t1) in tqdm(enumerate(zip(schedule[:-1], schedule[1:]), 1)): 91 | # update log importance weight 92 | log_int_1 = log_f_i(current_z, batch, t0) 93 | log_int_2 = log_f_i(current_z, batch, t1) 94 | logw.add_(log_int_2 - log_int_1) 95 | 96 | # resample speed 97 | current_v = Variable(torch.randn(current_z.size()).type(mdtype)) 98 | 99 | def U(z): 100 | return -log_f_i(z, batch, t1) 101 | 102 | def grad_U(z): 103 | # grad w.r.t. 
outputs; mandatory in this case 104 | grad_outputs = torch.ones(B).type(mdtype) 105 | # torch.autograd.grad default returns volatile 106 | grad = torchgrad(U(z), z, grad_outputs=grad_outputs)[0] 107 | # avoid humongous gradients 108 | grad = torch.clamp(grad, -10000, 10000) 109 | # needs variable wrapper to make differentiable 110 | grad = Variable(grad.data, requires_grad=True) 111 | return grad 112 | 113 | def normalized_kinetic(v): 114 | zeros = Variable(torch.zeros(B, z_size).type(mdtype)) 115 | # this is superior to the unnormalized version 116 | return -log_normal(v, zeros, zeros) 117 | 118 | z, v = hmc_trajectory(current_z, current_v, U, grad_U, epsilon) 119 | 120 | # accept-reject step 121 | current_z, epsilon, accept_hist = accept_reject( 122 | current_z, current_v, 123 | z, v, 124 | epsilon, 125 | accept_hist, j, 126 | U, K=normalized_kinetic 127 | ) 128 | 129 | # IWAE lower bound 130 | logw = log_mean_exp(logw.view(n_sample, -1).transpose(0, 1)) 131 | if mode == 'backward': 132 | logw = -logw 133 | logws.append(logw.data) 134 | 135 | print ('Time elapse %.4f, last batch stats %.4f' % \ 136 | (time.time()-_time, logw.mean().cpu().data.numpy())) 137 | 138 | _time = time.time() 139 | sys.stdout.flush() # for debugging 140 | 141 | return logws 142 | -------------------------------------------------------------------------------- /utils/approx_posts.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch 6 | import torch.utils.data 7 | import torch.optim as optim 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | 12 | from utils.math_ops import log_bernoulli, log_normal, log_mean_exp 13 | 14 | 15 | class Flow(nn.Module): 16 | """A combination of R-NVP and auxiliary variables.""" 17 | 18 | def __init__(self, model, n_flows=2): 19 | super(Flow, self).__init__() 20 | self.z_size = model.z_size 21 | self.n_flows = n_flows 22 | self._construct_weights() 23 | 24 | def forward(self, z, grad_fn=lambda x: 1, x_info=None): 25 | return self._sample(z, grad_fn, x_info) 26 | 27 | def _norm_flow(self, params, z, v, grad_fn, x_info): 28 | h = F.elu(params[0][0](torch.cat((z, x_info), dim=1))) 29 | mu = params[0][1](h) 30 | logit = params[0][2](h) 31 | sig = F.sigmoid(logit) 32 | 33 | # old CIFAR used the one below 34 | # v = v * sig + mu * grad_fn(z) 35 | 36 | # the more efficient one uses the one below 37 | v = v * sig - F.elu(mu) * grad_fn(z) 38 | logdet_v = torch.sum(logit - F.softplus(logit), 1) 39 | 40 | h = F.elu(params[1][0](torch.cat((v, x_info), dim=1))) 41 | mu = params[1][1](h) 42 | logit = params[1][2](h) 43 | sig = F.sigmoid(logit) 44 | 45 | z = z * sig + mu 46 | logdet_z = torch.sum(logit - F.softplus(logit), 1) 47 | logdet = logdet_v + logdet_z 48 | 49 | return z, v, logdet 50 | 51 | def _sample(self, z0, grad_fn, x_info): 52 | B = z0.size(0) 53 | z_size = self.z_size 54 | act_func = F.elu 55 | qv_weights, rv_weights, params = self.qv_weights, self.rv_weights, self.params 56 | 57 | out = torch.cat((z0, x_info), dim=1) 58 | for i in range(len(qv_weights)-1): 59 | out = act_func(qv_weights[i](out)) 60 | out = qv_weights[-1](out) 61 | mean_v0, logvar_v0 = out[:, :z_size], out[:, z_size:] 62 | 63 | eps = Variable(torch.randn(B, z_size).type( type(out.data) )) 64 | v0 = eps.mul(logvar_v0.mul(0.5).exp_()) + mean_v0 65 | logqv0 = log_normal(v0, mean_v0, logvar_v0) 66 
| 67 | zT, vT = z0, v0 68 | logdetsum = 0. 69 | for i in range(self.n_flows): 70 | zT, vT, logdet = self._norm_flow(params[i], zT, vT, grad_fn, x_info) 71 | logdetsum += logdet 72 | 73 | # reverse model, r(vT|x,zT) 74 | out = torch.cat((zT, x_info), dim=1) 75 | for i in range(len(rv_weights)-1): 76 | out = act_func(rv_weights[i](out)) 77 | out = rv_weights[-1](out) 78 | mean_vT, logvar_vT = out[:, :z_size], out[:, z_size:] 79 | logrvT = log_normal(vT, mean_vT, logvar_vT) 80 | 81 | assert logqv0.size() == (B,) 82 | assert logdetsum.size() == (B,) 83 | assert logrvT.size() == (B,) 84 | 85 | logprob = logqv0 - logdetsum - logrvT 86 | 87 | return zT, logprob 88 | 89 | def _construct_weights(self): 90 | z_size = self.z_size 91 | n_flows = self.n_flows 92 | h_s = 200 93 | 94 | qv_arch = rv_arch = [z_size*2, h_s, h_s, z_size*2] 95 | qv_weights, rv_weights = [], [] 96 | 97 | # q(v|x,z) 98 | id = 0 99 | for ins, outs in zip(qv_arch[:-1], qv_arch[1:]): 100 | cur_layer = nn.Linear(ins, outs) 101 | qv_weights.append(cur_layer) 102 | self.add_module('qz%d' % id, cur_layer) 103 | id += 1 104 | 105 | # r(v|x,z) 106 | id = 0 107 | for ins, outs in zip(rv_arch[:-1], rv_arch[1:]): 108 | cur_layer = nn.Linear(ins, outs) 109 | rv_weights.append(cur_layer) 110 | self.add_module('rv%d' % id, cur_layer) 111 | id += 1 112 | 113 | # nf 114 | params = [] 115 | for i in range(n_flows): 116 | layer_grid = [ 117 | [nn.Linear(z_size*2, h_s), 118 | nn.Linear(h_s, z_size), 119 | nn.Linear(h_s, z_size)], 120 | [nn.Linear(z_size*2, h_s), 121 | nn.Linear(h_s, z_size), 122 | nn.Linear(h_s, z_size)], 123 | ] 124 | 125 | params.append(layer_grid) 126 | 127 | id = 0 128 | for layer_list in layer_grid: 129 | for layer in layer_list: 130 | self.add_module('flow%d_layer%d' % (i, id), layer) 131 | id += 1 132 | 133 | self.qv_weights = qv_weights 134 | self.rv_weights = rv_weights 135 | self.params = params 136 | 137 | self.sanity_check_param = self.params[0][0][0]._parameters['weight'] 138 | -------------------------------------------------------------------------------- /utils/helper.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from vae import VAE 6 | from cvae import CVAE 7 | from loader import get_Larochelle_MNIST_loader, get_fashion_loader, get_cifar10_loader 8 | 9 | 10 | def get_loaders(dataset='mnist', evaluate=False, batch_size=100): 11 | if dataset == 'mnist': 12 | train_loader, test_loader = get_Larochelle_MNIST_loader( 13 | batch_size=batch_size, 14 | partial=evaluate, num=1000 15 | ) 16 | elif dataset == 'fashion': 17 | train_loader, test_loader = get_fashion_loader( 18 | batch_size=batch_size, 19 | partial=evaluate, num=1000 20 | ) 21 | elif dataset == 'cifar': 22 | train_loader, test_loader = get_cifar10_loader( 23 | batch_size=batch_size, 24 | partial=evaluate, num=100 25 | ) 26 | 27 | return train_loader, test_loader 28 | 29 | 30 | def get_model(dataset, hps): 31 | if dataset == 'mnist' or dataset == 'fashion': 32 | model = VAE(hps) 33 | elif dataset == 'cifar': # convolutional VAE for CIFAR 34 | model = CVAE(hps) 35 | 36 | return model 37 | -------------------------------------------------------------------------------- /utils/hmc.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import 
sys 6 | import os 7 | import math 8 | import torch 9 | import numpy as np 10 | 11 | import torch 12 | from torch.autograd import Variable 13 | 14 | 15 | def hmc_trajectory(current_z, current_v, U, grad_U, epsilon, L=10): 16 | """This version of HMC follows https://arxiv.org/pdf/1206.1901.pdf. 17 | 18 | Args: 19 | U: function to compute potential energy/minus log-density 20 | grad_U: function to compute gradients w.r.t. U 21 | epsilon: (adaptive) step size 22 | L: number of leap-frog steps 23 | current_z: current position 24 | """ 25 | 26 | # as of `torch-0.3.0.post4`, there still is no proper scalar support 27 | assert isinstance(epsilon, Variable) 28 | 29 | eps = epsilon.view(-1, 1) 30 | z = current_z 31 | v = current_v - grad_U(z).mul(eps).mul_(.5) 32 | 33 | for i in range(1, L+1): 34 | z = z + v.mul(eps) 35 | if i < L: 36 | v = v - grad_U(z).mul(eps) 37 | 38 | v = v - grad_U(z).mul(eps).mul_(.5) 39 | v = -v # this is not needed; only here to conform to the math 40 | 41 | return z.detach(), v.detach() 42 | 43 | 44 | def accept_reject(current_z, current_v, 45 | z, v, 46 | epsilon, 47 | accept_hist, hist_len, 48 | U, K=lambda v: torch.sum(v * v, 1)): 49 | """Accept/reject based on Hamiltonians for current and propose. 50 | 51 | Args: 52 | current_z: position BEFORE leap-frog steps 53 | current_v: speed BEFORE leap-frog steps 54 | z: position AFTER leap-frog steps 55 | v: speed AFTER leap-frog steps 56 | epsilon: step size of leap-frog. 57 | (This is only needed for adaptive update) 58 | U: function to compute potential energy (MINUS log-density) 59 | K: function to compute kinetic energy (default: kinetic energy in physics w/ mass=1) 60 | """ 61 | 62 | mdtype = type(current_z.data) 63 | 64 | current_Hamil = K(current_v) + U(current_z) 65 | propose_Hamil = K(v) + U(z) 66 | 67 | prob = torch.exp(current_Hamil - propose_Hamil) 68 | uniform_sample = torch.rand(prob.size()) 69 | uniform_sample = Variable(uniform_sample.type(mdtype)) 70 | accept = (prob > uniform_sample).type(mdtype) 71 | z = z.mul(accept.view(-1, 1)) + current_z.mul(1. - accept.view(-1, 1)) 72 | 73 | accept_hist = accept_hist.add(accept) 74 | criteria = (accept_hist / hist_len > 0.65).type(mdtype) 75 | adapt = 1.02 * criteria + 0.98 * (1. 
- criteria) 76 | epsilon = epsilon.mul(adapt).clamp(1e-4, .5) 77 | 78 | # clear previous history & save memory, similar to detach 79 | z = Variable(z.data, requires_grad=True) 80 | epsilon = Variable(epsilon.data) 81 | accept_hist = Variable(accept_hist.data) 82 | 83 | return z, epsilon, accept_hist 84 | -------------------------------------------------------------------------------- /utils/hparams.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import torch.nn.functional as F 6 | 7 | class HParams(object): 8 | 9 | def __init__(self, **kwargs): 10 | self._items = {} 11 | for k, v in kwargs.items(): 12 | self._set(k, v) 13 | 14 | def _set(self, k, v): 15 | self._items[k] = v 16 | setattr(self, k, v) 17 | 18 | def parse(self, str_value): 19 | hps = HParams(**self._items) 20 | for entry in str_value.strip().split(","): 21 | entry = entry.strip() 22 | if not entry: 23 | continue 24 | key, sep, value = entry.partition("=") 25 | if not sep: 26 | raise ValueError("Unable to parse: %s" % entry) 27 | default_value = hps._items[key] 28 | if isinstance(default_value, bool): 29 | hps._set(key, value.lower() == "true") 30 | elif isinstance(default_value, int): 31 | hps._set(key, int(value)) 32 | elif isinstance(default_value, float): 33 | hps._set(key, float(value)) 34 | else: 35 | hps._set(key, value) 36 | return hps 37 | 38 | def get_default_hparams(): 39 | return HParams( 40 | z_size=50, 41 | act_func=F.elu, 42 | has_flow=False, 43 | large_encoder=False, 44 | wide_encoder=False, 45 | cuda=True, 46 | ) 47 | -------------------------------------------------------------------------------- /utils/math_ops.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import math 6 | import numpy as np 7 | import numpy.linalg as linalg 8 | 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.nn.functional as F 12 | 13 | 14 | def log_normal(x, mean=None, logvar=None): 15 | """Implementation WITHOUT constant, since the constants in p(z) 16 | and q(z|x) cancels out. 17 | Args: 18 | x: [B,Z] 19 | mean,logvar: [B,Z] 20 | 21 | Returns: 22 | output: [B] 23 | """ 24 | if mean is None: 25 | mean = Variable(torch.zeros(x.size()).type(type(x.data))) 26 | if logvar is None: 27 | logvar = Variable(torch.zeros(x.size()).type(type(x.data))) 28 | 29 | return -0.5 * (logvar.sum(1) + ((x - mean).pow(2) / torch.exp(logvar)).sum(1)) 30 | 31 | 32 | def log_normal_full_cov(x, mean, L): 33 | """Log density of full covariance multivariate Gaussian. 
34 | Note: results are off by the constant log(), since this 35 | quantity cancels out in p(z) and q(z|x).""" 36 | 37 | def batch_diag(M): 38 | diag = [t.diag() for t in torch.functional.unbind(M)] 39 | diag = torch.functional.stack(diag) 40 | return diag 41 | 42 | def batch_inverse(M, damp=False, eps=1e-6): 43 | damp_matrix = Variable(torch.eye(M[0].size(0)).type(M.data.type())).mul_(eps) 44 | inverse = [] 45 | for t in torch.functional.unbind(M): 46 | # damping to ensure invertible due to float inaccuracy 47 | # this problem is very UNLIKELY when using double 48 | m = t if not damp else t + damp_matrix 49 | inverse.append(m.inverse()) 50 | inverse = torch.functional.stack(inverse) 51 | return inverse 52 | 53 | L_diag = batch_diag(L) 54 | term1 = -torch.log(L_diag).sum(1) 55 | 56 | L_inverse = batch_inverse(L) 57 | scaled_diff = L_inverse.matmul((x - mean).unsqueeze(2)).squeeze() 58 | term2 = -0.5 * (scaled_diff ** 2).sum(1) 59 | 60 | return term1 + term2 61 | 62 | 63 | def log_bernoulli(logit, target): 64 | """ 65 | Args: 66 | logit: [B, X, ?, ?] 67 | target: [B, X, ?, ?] 68 | 69 | Returns: 70 | output: [B] 71 | """ 72 | loss = -F.relu(logit) + torch.mul(target, logit) - torch.log(1. + torch.exp( -logit.abs() )) 73 | while len(loss.size()) > 1: 74 | loss = loss.sum(-1) 75 | 76 | return loss 77 | 78 | 79 | def mean_squared_error(prediction, target): 80 | prediction, target = flatten(prediction), flatten(target) 81 | diff = prediction - target 82 | 83 | return -torch.sum(torch.mul(diff, diff), 1) 84 | 85 | 86 | def discretized_logistic(mu, logs, x): 87 | """Probability mass follow discretized logistic. 88 | https://arxiv.org/pdf/1606.04934.pdf. Assuming pixel values scaled to be 89 | within [0,1]. Follows implementation from OpenAI. 90 | """ 91 | sigmoid = torch.nn.Sigmoid() 92 | 93 | s = torch.exp(logs).unsqueeze(-1).unsqueeze(-1) 94 | logp = torch.log(sigmoid((x + 1./256. - mu) / s) - sigmoid((x - mu) / s) + 1e-7) 95 | 96 | return logp.sum(-1).sum(-1).sum(-1) 97 | 98 | 99 | def flatten(x): 100 | return x.view(x.size(0), -1) 101 | 102 | 103 | def log_mean_exp(x): 104 | max_, _ = torch.max(x, 1, keepdim=True) 105 | return torch.log(torch.mean(torch.exp(x - max_), 1)) + torch.squeeze(max_) 106 | 107 | 108 | def numpy_nan_guard(arr): 109 | return np.all(arr == arr) 110 | 111 | 112 | def safe_repeat(x, n): 113 | return x.repeat(n, *[1 for _ in range(len(x.size()) - 1)]) 114 | 115 | 116 | def sigmoidial_schedule(T, delta=4): 117 | """From section 6 of BDMC paper.""" 118 | 119 | def sigmoid(x): 120 | return np.exp(x) / (1. 
+ np.exp(x)) 121 | 122 | def beta_tilde(t): 123 | return sigmoid(delta * (2.*t / T - 1.)) 124 | 125 | def beta(t): 126 | return (beta_tilde(t) - beta_tilde(1)) / (beta_tilde(T) - beta_tilde(1)) 127 | 128 | return [beta(t) for t in range(1, T+1)] 129 | 130 | 131 | def linear_schedule(T): 132 | return np.linspace(0., 1., T) 133 | -------------------------------------------------------------------------------- /utils/mnist_reader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import gzip 7 | import numpy as np 8 | 9 | 10 | def load_mnist(path, kind='train'): 11 | """Load MNIST data from `path`""" 12 | 13 | labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind) 14 | images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind) 15 | 16 | with gzip.open(labels_path, 'rb') as lbpath: 17 | labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8) 18 | 19 | with gzip.open(images_path, 'rb') as imgpath: 20 | images = np.frombuffer(imgpath.read(), dtype=np.uint8, 21 | offset=16).reshape(len(labels), 784) 22 | 23 | return images, labels 24 | -------------------------------------------------------------------------------- /utils/simulate.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import time 6 | import numpy as np 7 | 8 | import torch 9 | from torch.autograd import Variable 10 | from torch.distributions import Bernoulli 11 | 12 | 13 | def simulate_data(model, batch_size=10, n_batch=1): 14 | """Simulate data from the VAE model. Sample from the 15 | joint distribution p(z)p(x|z). This is equivalent to 16 | sampling from p(x)p(z|x), i.e. z is from the posterior. 17 | 18 | Bidirectional Monte Carlo only works on simulated data, 19 | where we could obtain exact posterior samples. 
20 | 21 | Args: 22 | model: VAE model for simulation 23 | batch_size: batch size for simulated data 24 | n_batch: number of batches 25 | 26 | Returns: 27 | iterator that loops over batches of torch Tensor pair x, z 28 | """ 29 | 30 | # shorter aliases 31 | z_size = model.z_size 32 | mdtype = model.dtype 33 | 34 | batches = [] 35 | for i in range(n_batch): 36 | # assume prior for VAE is unit Gaussian 37 | z = torch.randn(batch_size, z_size).type(mdtype) 38 | x_logits = model.decode(Variable(z)) 39 | if isinstance(x_logits, tuple): 40 | x_logits = x_logits[0] 41 | x_bernoulli_dist = Bernoulli(probs=x_logits.sigmoid()) 42 | x = x_bernoulli_dist.sample().data.type(mdtype) 43 | 44 | paired_batch = (x, z) 45 | batches.append(paired_batch) 46 | 47 | return iter(batches) 48 | -------------------------------------------------------------------------------- /vae.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import time 6 | import sys 7 | import argparse 8 | 9 | import torch 10 | import torch.utils.data 11 | import torch.optim as optim 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn.utils import weight_norm 15 | from torch.autograd import Variable 16 | from torch.autograd import grad as torchgrad 17 | 18 | from utils.math_ops import log_normal, log_bernoulli, log_mean_exp 19 | from utils.approx_posts import Flow 20 | 21 | 22 | class VAE(nn.Module): 23 | """Generic VAE for MNIST and Fashion datasets.""" 24 | def __init__(self, hps): 25 | super(VAE, self).__init__() 26 | 27 | self.z_size = hps.z_size 28 | self.has_flow = hps.has_flow 29 | self.use_cuda = hps.cuda 30 | self.act_func = hps.act_func 31 | self.n_flows = hps.n_flows 32 | self.hamiltonian_flow = hps.hamiltonian_flow 33 | 34 | self._init_layers(wide_encoder=hps.wide_encoder) 35 | 36 | if self.use_cuda: 37 | self.cuda() 38 | self.dtype = torch.cuda.FloatTensor 39 | else: 40 | self.dtype = torch.FloatTensor 41 | 42 | def _init_layers(self, wide_encoder=False): 43 | h_s = 500 if wide_encoder else 200 44 | 45 | self.fc1 = nn.Linear(784, h_s) # assume flattened 46 | self.fc2 = nn.Linear(h_s, h_s) 47 | self.fc3 = nn.Linear(h_s, self.z_size*2) 48 | 49 | self.fc4 = nn.Linear(self.z_size, 200) 50 | self.fc5 = nn.Linear(200, 200) 51 | self.fc6 = nn.Linear(200, 784) 52 | 53 | self.x_info_layer = nn.Linear(200, self.z_size) 54 | 55 | if self.has_flow: 56 | self.q_dist = Flow(self, n_flows=self.n_flows) 57 | if self.use_cuda: 58 | self.q_dist.cuda() 59 | 60 | def sample(self, mu, logvar, grad_fn=lambda x: 1, x_info=None): 61 | eps = Variable(torch.FloatTensor(mu.size()).normal_().type(self.dtype)) 62 | z = eps.mul(logvar.mul(0.5).exp_()).add_(mu) 63 | logqz = log_normal(z, mu, logvar) 64 | 65 | if self.has_flow: 66 | z, logprob = self.q_dist.forward(z, grad_fn, x_info) 67 | logqz += logprob 68 | 69 | zeros = Variable(torch.zeros(z.size()).type(self.dtype)) 70 | logpz = log_normal(z, zeros, zeros) 71 | 72 | return z, logpz, logqz 73 | 74 | def encode(self, net): 75 | net = self.act_func(self.fc1(net)) 76 | net = self.act_func(self.fc2(net)) 77 | x_info = self.act_func(self.x_info_layer(net)) 78 | net = self.fc3(net) 79 | 80 | mean, logvar = net[:, :self.z_size], net[:, self.z_size:] 81 | 82 | return mean, logvar, x_info 83 | 84 | def decode(self, net): 85 | net = self.act_func(self.fc4(net)) 86 | net = self.act_func(self.fc5(net)) 87 | logit = self.fc6(net) 
88 | 
89 |         return logit
90 | 
91 |     def forward(self, x, k=1, warmup_const=1.):
92 |         x = x.repeat(k, 1)
93 |         mu, logvar, x_info = self.encode(x)
94 | 
95 |         # posterior-aware inference
96 |         def U(z):
97 |             logpx = log_bernoulli(self.decode(z), x)
98 |             logpz = log_normal(z)
99 |             return -logpx - logpz  # energy as -log p(x, z)
100 | 
101 |         def grad_U(z):
102 |             grad_outputs = torch.ones(z.size(0)).type(self.dtype)
103 |             grad = torchgrad(U(z), z, grad_outputs=grad_outputs, create_graph=True)[0]
104 |             # normalize the gradient to avoid numerical issues
105 |             norm = torch.sqrt(torch.norm(grad, p=2, dim=1))
106 |             # neither gradient-clipping method consistently outperforms the other
107 |             grad = grad / norm.view(-1, 1)
108 |             # grad = torch.clamp(grad, -10000, 10000)
109 |             return grad.detach()
110 | 
111 |         if self.hamiltonian_flow:
112 |             z, logpz, logqz = self.sample(mu, logvar, grad_fn=grad_U, x_info=x_info)
113 |         else:
114 |             z, logpz, logqz = self.sample(mu, logvar, x_info=x_info)
115 | 
116 |         logit = self.decode(z)
117 |         logpx = log_bernoulli(logit, x)
118 |         elbo = logpx + logpz - warmup_const * logqz  # custom warmup
119 | 
120 |         # undo the Tensor.repeat: combine the k importance samples per datapoint
121 |         elbo = log_mean_exp(elbo.view(k, -1).transpose(0, 1))
122 |         elbo = torch.mean(elbo)
123 | 
124 |         logpx = torch.mean(logpx)
125 |         logpz = torch.mean(logpz)
126 |         logqz = torch.mean(logqz)
127 | 
128 |         return elbo, logpx, logpz, logqz
129 | 
--------------------------------------------------------------------------------
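
The utilities above are the building blocks that `bdmc.py` wires together: simulate data from the decoder with `simulate_data` (so exact posterior samples are available), run a forward AIS chain for a stochastic lower bound on log p(x), and run a reverse chain started from the exact posterior samples for a stochastic upper bound. The following is only an illustrative sketch, not part of the repository: the checkpoint path, hyper-parameter values, chain count, and schedule length are assumptions, and it uses the same `Variable`-era PyTorch API the code base targets.

```
# hypothetical sketch of the BDMC sandwich on simulated data
import torch
import torch.nn.functional as F

from utils.ais import ais_trajectory
from utils.helper import get_model
from utils.hparams import HParams
from utils.math_ops import sigmoidial_schedule
from utils.simulate import simulate_data

# assumed hyper-parameters; mirror whatever the checkpoint was trained with
hps = HParams(z_size=50, act_func=F.elu, has_flow=False, n_flows=2,
              hamiltonian_flow=False, wide_encoder=False, large_encoder=False,
              cuda=False)
model = get_model('mnist', hps)
# 'checkpoints/model.pth' is a placeholder path for a trained model
model.load_state_dict(torch.load('checkpoints/model.pth')['state_dict'])

# (x, z) pairs sampled from p(z)p(x|z); each z is an exact posterior sample for its x
loader = list(simulate_data(model, batch_size=10, n_batch=2))

schedule = sigmoidial_schedule(500)  # temperatures increasing from 0 to 1
# forward chain: stochastic lower bound on log p(x)
lower = ais_trajectory(model, loader, mode='forward',
                       schedule=schedule, n_sample=16)
# reverse chain from the exact posterior samples: stochastic upper bound
upper = ais_trajectory(model, loader, mode='backward',
                       schedule=list(reversed(schedule)), n_sample=16)

print('BDMC sandwich: %.3f <= log p(x) <= %.3f' %
      (torch.cat(lower).mean(), torch.cat(upper).mean()))
```

The width of this sandwich is the usual BDMC diagnostic: since it can only be computed on simulated data, it indicates how far the forward AIS estimate may be from the true log-likelihood before that estimate is applied to real data.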