├── LICENSE
├── README.md
├── distributions.py
├── test_gamma.py
├── train_gamma_vae.py
└── train_normal_vae.py
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Jaan Altosaar

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## gamma-variational-autoencoder

This is an example implementation of reparameterized rejection sampling (Naesseth et al., 2017, https://arxiv.org/abs/1610.05683).

There are two examples here, implemented in MXNet:

1. A standard VAE with Gaussian latent variables (`train_normal_vae.py`).

2. A VAE with Gamma-distributed latent variables (`train_gamma_vae.py`).

A key to the fast implementation is to sample Gamma variables directly, then apply the inverse of the standardization function `h` to recover the epsilon that the rejection sampler would have accepted; gradients with respect to the shape then flow through `h(epsilon, shape + B)`. Thanks to Christian Naesseth for pointing out this nice trick!
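
As a rough illustration of the trick (this NumPy sketch is not part of the repository; `h` and `h_inverse` mirror `compute_h` and `compute_h_inverse` in `distributions.py`, and the value of `alpha` is arbitrary):

```python
import numpy as np

def h(eps, alpha):
  # standardization map: for an accepted proposal eps, h(eps, alpha) ~ Gamma(alpha, 1)
  return (alpha - 1. / 3.) * (1. + eps / np.sqrt(9. * alpha - 3.)) ** 3

def h_inverse(z, alpha):
  # invert h to recover the epsilon that would have produced z
  return np.sqrt(9. * alpha - 3.) * ((z / (alpha - 1. / 3.)) ** (1. / 3.) - 1.)

alpha = 2.5                              # illustrative shape parameter
z = np.random.gamma(alpha, 1., size=5)   # sample Gamma(alpha, 1) directly
eps = h_inverse(z, alpha)                # treat epsilon as the accepted noise
print(np.allclose(h(eps, alpha), z))     # the round trip recovers z exactly
```

In `distributions.py`, the shape parameter is detached (`stop_gradient`) when computing epsilon, so gradients with respect to the shape flow only through `h(epsilon, shape + B)`.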
--------------------------------------------------------------------------------
/distributions.py:
--------------------------------------------------------------------------------
import mxnet as mx
from mxnet import gluon


class ReparameterizedGamma(gluon.HybridBlock):
  """Returns a reparameterized sample from a Gamma(shape, scale) distribution.

  shape, scale are the shape and scale of the distribution. We use Algorithm 1
  of [1], but sample \tilde z from a Gamma(shape + B, 1) directly, so the
  corresponding epsilon is guaranteed to be accepted. We also use shape
  augmentation as in Section 5 of [1].

  References:
  1. Naesseth et al. (2017). https://arxiv.org/abs/1610.05683
  """

  def __init__(self, B):
    """B is the number of times to augment the shape, useful for sparsity."""
    super().__init__()
    self.B = B

  def hybrid_forward(self, F, shape, scale):
    # sample \tilde z ~ Gamma(shape + B, 1) to guarantee acceptance
    one = F.ones_like(shape)
    z_tilde = F.sample_gamma(shape + self.B, one)
    # compute the epsilon corresponding to \tilde z; this epsilon is 'accepted'
    # \epsilon = h_inverse(z_tilde; shape + B)
    eps = self.compute_h_inverse(F, z_tilde, F.stop_gradient(shape) + self.B)
    # now compute z_tilde = h(epsilon, shape + B), so that gradients with
    # respect to the shape flow through h
    z_tilde = self.compute_h(F, eps, shape + self.B)
    # shape augmentation (Section 5 of [1]):
    # E_{u_1,...,u_B, \tilde z}[f(\tilde z \prod_i u_i^{1 / (shape + i - 1)})]
    #   = E_{u_1,...,u_B, \pi(eps)}[f(h(eps, shape + B) \prod_i ...)]
    B_range = F.arange(start=1, stop=self.B + 1)
    # expand dims to broadcast against shape
    B_range = F.expand_dims(F.expand_dims(B_range, -1), -1)
    zero = F.zeros_like(shape)
    unif_sample = F.sample_uniform(zero, one, shape=self.B)
    # transpose so that the shape-augmentation dimension is the leading axis
    unif_sample = F.transpose(unif_sample, axes=(2, 0, 1))
    unif_prod = F.prod(
        unif_sample**(1. / (F.broadcast_add(shape, B_range) - 1.)), axis=0)
    # this reparameterized sample is distributed as Gamma(shape, 1):
    # z = h(eps, shape + B) \prod_i u_i^{1 / (shape + i - 1)}
    z = z_tilde * unif_prod
    # multiply by the scale to get a sample distributed as Gamma(shape, scale)
    return z * scale

  def compute_h(self, F, eps, shape):
    return (shape - 1. / 3.) * (1. + eps / F.sqrt(9. * shape - 3.))**3

  def compute_h_inverse(self, F, z, shape):
    return F.sqrt(9. * shape - 3.) * ((z / (shape - 1. / 3.))**(1. / 3.) - 1.)
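

# Example usage (illustrative values; mirrors the way test_gamma.py and
# train_gamma_vae.py call this block):
#   gamma = ReparameterizedGamma(B=8)
#   shape = mx.nd.full((64, 32), 0.3)   # hypothetical batch of shape parameters
#   scale = mx.nd.full((64, 32), 3.0)   # hypothetical batch of scale parameters
#   z = gamma(shape, scale)             # differentiable w.r.t. shape and scale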
--------------------------------------------------------------------------------
/test_gamma.py:
--------------------------------------------------------------------------------
import distributions
import numpy as np
import mxnet as mx
from mxnet import nd
import scipy.special
import scipy.stats
mx.random.seed(24232)
np.random.seed(2423242)


def sample_gamma(shape, scale, n_samples):
  reparam_gamma = distributions.ReparameterizedGamma(B=8)
  if not isinstance(shape, np.ndarray):
    shape = np.array([[shape]])
    scale = np.array([[scale]])
  shape = np.repeat(shape, n_samples, axis=0)
  scale = np.repeat(scale, n_samples, axis=0)
  sample = reparam_gamma(nd.array(shape), nd.array(scale))
  return sample


def check_gamma_mean(shape, scale, n_samples):
  sample = sample_gamma(shape, scale, n_samples)
  mean = sample.asnumpy().mean(axis=0)
  print('actual, computed:')
  true_mean = np.squeeze(shape * scale)
  print(true_mean, mean)
  np.testing.assert_allclose(true_mean, mean, rtol=1e-1)


def test_gamma_sampling_mean():
  """Check that reparameterized samples recover the correct mean."""
  check_gamma_mean(np.array([[1., 1., 1.]]), np.array([[1., 1., 1.]]), 1000)
  check_gamma_mean(10., 1., 1000)
  check_gamma_mean(1., 1., 1000)
  check_gamma_mean(0.1, 1., 10000)
  check_gamma_mean(0.01, 1., 100000)
  check_gamma_mean(1., 30., 10000)
  check_gamma_mean(5., 30., 10000)
  check_gamma_mean(0.3, 3., 10000)


def check_gamma_grads(np_shape, np_scale):
  """Check reparameterization gradients against the score-function estimator."""
  if not isinstance(np_shape, np.ndarray):
    np_shape = np.array([[np_shape]])
    np_scale = np.array([[np_scale]])
  shape = nd.array(np_shape)
  scale = nd.array(np_scale)
  shape.attach_grad()
  scale.attach_grad()

  def function(F, z):
    return F.square(z) - 3.5

  reparam_gamma = distributions.ReparameterizedGamma(B=8)
  # Monte Carlo estimate of the reparameterization gradient of E[f(z)],
  # with f(z) = z^2 - 3.5
  g_shape_list = []
  g_scale_list = []
  for _ in range(1000):
    with mx.autograd.record():
      z_sample = reparam_gamma(shape, scale)
      f = function(nd, z_sample)
    f.backward()
    g_shape_list.append(shape.grad.asnumpy())
    g_scale_list.append(scale.grad.asnumpy())
  g_shape = np.mean(g_shape_list, axis=0)
  g_scale = np.mean(g_scale_list, axis=0)
  # score-function (REINFORCE) estimate of the same gradient, using exact
  # Gamma samples from SciPy: E[f(z) * d/dtheta log p(z; theta)]
  np_z = scipy.stats.gamma.rvs(
      np_shape, scale=np_scale, size=(100000, np_shape.shape[-1]))
  score_shape, score_scale = gamma_score(np_z, np_shape, np_scale)
  f_z = function(np, np_z)
  np_g_shape = np.mean(score_shape * f_z, axis=0)
  np_g_scale = np.mean(score_scale * f_z, axis=0)
  print('shape score, reparam')
  print(np_g_shape, g_shape)
  np.testing.assert_allclose(np_g_shape, np.squeeze(g_shape), rtol=2e-1)
  print('scale score, reparam')
  print(np_g_scale, g_scale)
  np.testing.assert_allclose(np_g_scale, np.squeeze(g_scale), rtol=2e-1)


def test_gamma_grads():
  check_gamma_grads(1., 1.)
  check_gamma_grads(1., 3.)
  check_gamma_grads(0.3, 3.)
  check_gamma_grads(np.array([[0.3, 0.3, 0.3]]), np.array([[3., 3., 3.]]))


def gamma_score(z, shape, scale):
  """Score function (gradient of the log-density) of Gamma(shape, scale)."""
  score_shape = -scipy.special.psi(shape) - np.log(scale) + np.log(z)
  score_scale = -shape / scale + z / scale / scale
  return score_shape, score_scale


if __name__ == '__main__':
  # test_gamma_sampling_mean()
  test_gamma_grads()
--------------------------------------------------------------------------------
/train_gamma_vae.py:
--------------------------------------------------------------------------------
import os
import pathlib
import time
import scipy.misc
from mxnet import nd
import mxnet as mx
import numpy as np
from mxnet import gluon

import distributions


class DeepLatentGammaModel(gluon.HybridBlock):
  def __init__(self):
    super().__init__()
    with self.name_scope():
      self.log_prior = GammaLogProb()
      # self.log_prior = GaussianLogProb()
      # generative network parameterizes the likelihood
      self.net = gluon.nn.HybridSequential()
      with self.net.name_scope():
        self.net.add(gluon.nn.Dense(
            200, 'relu', weight_initializer=mx.init.Xavier()))
        self.net.add(gluon.nn.Dense(
            200, 'relu', weight_initializer=mx.init.Xavier()))
        self.net.add(gluon.nn.Dense(784, weight_initializer=mx.init.Xavier()))
      self.log_lik = BernoulliLogLik()

  def hybrid_forward(self, F, z, x):
    # use a sparse Gamma(shape=0.3, scale=3) prior
    log_prior = self.log_prior(z, 0.3 * F.ones_like(z), 3. * F.ones_like(z))
    # log_prior = self.log_prior(z, F.zeros_like(z), F.ones_like(z))
    logits = self.net(z)
    log_lik = self.log_lik(x, logits)
    return F.sum(log_lik, -1), F.sum(log_prior, -1)


class ELBO(gluon.HybridBlock):
  def __init__(self, model, variational):
    super().__init__()
    with self.name_scope():
      self.variational = variational
      self.model = model

  def hybrid_forward(self, F, x):
    # single-sample Monte Carlo estimate of the ELBO:
    # log p(x | z) + log p(z) - log q(z | x) with z ~ q(z | x)
    z, log_q_z = self.variational(x)
    log_lik, log_prior = self.model(z, x)
    return log_lik + log_prior - log_q_z


class BernoulliLogLik(gluon.HybridBlock):
  """Calculate the log-probability of a Bernoulli parameterized by logits."""

  def __init__(self):
    super().__init__()

  def hybrid_forward(self, F, x, logits):
    """With logits z, the Bernoulli log-probability is
      x * log sigmoid(z) + (1 - x) * log(1 - sigmoid(z))
      = x * z - log(1 + exp(z))
      = x * z - max(0, z) - log(1 + exp(-|z|))
    The last line is the numerically stable form: it never exponentiates a
    large positive number, so it avoids overflow.
    """
    return x * logits - F.relu(logits) - F.Activation(-F.abs(logits), 'softrelu')


class AmortizedGammaVariational(gluon.HybridBlock):
  def __init__(self, latent_size, batch_size):
    super().__init__()
    self.net = gluon.nn.HybridSequential()
    with self.name_scope():
      self.reparam_gamma = distributions.ReparameterizedGamma(B=8)
      self.gamma_log_prob = GammaLogProb()
      with self.net.name_scope():
        self.net.add(gluon.nn.Dense(200, activation='relu', flatten=True))
        self.net.add(gluon.nn.Dense(200, activation='relu'))
        self.net.add(gluon.nn.Dense(latent_size * 2))

  def hybrid_forward(self, F, x):
    shape_scale_arg = self.net(x)
    shape_arg, scale_arg = F.split(shape_scale_arg, num_outputs=2, axis=-1)
    # softplus ensures the shape and scale parameters are positive
    shape = F.Activation(shape_arg, 'softrelu')
    scale = F.Activation(scale_arg, 'softrelu')
    z = self.reparam_gamma(shape, scale)
    log_prob = self.gamma_log_prob(z, shape, scale)
    return z, F.sum(log_prob, -1)


class GammaLogProb(gluon.HybridBlock):
  def __init__(self):
    super().__init__()

  def hybrid_forward(self, F, x, shape, scale):
    """Log-density of Gamma(shape, scale) evaluated at x."""
    return -F.gammaln(shape) - shape * F.log(scale) \
        + (shape - 1.) * F.log(x) - x / scale
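

# Train a VAE with Gamma latent variables on MNIST binarized by rounding:
# the encoder outputs softplus-transformed shape and scale parameters for an
# amortized Gamma variational posterior, the prior is a sparse Gamma(0.3, 3),
# and the single-sample ELBO is maximized with centered RMSProp.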
if __name__ == '__main__':
  np.random.seed(24232)
  mx.random.seed(2423232)

  USE_GPU = False
  LATENT_SIZE = 32
  BATCH_SIZE = 64
  PRINT_EVERY = 100
  MAX_ITERATIONS = 1000000
  OUT_DIR = pathlib.Path(os.environ['LOG']) / 'debug'
  OUT_DIR.mkdir(parents=True, exist_ok=True)

  # binarize MNIST by rounding pixel intensities to {0, 1}
  dataset = mx.gluon.data.vision.MNIST(
      train=True,
      transform=lambda data, label: (
          np.round(data.astype(np.float32) / 255), label))
  train_data = mx.gluon.data.DataLoader(dataset,
                                        batch_size=BATCH_SIZE, shuffle=True)

  ctx = [mx.gpu(0)] if USE_GPU else [mx.cpu()]
  with mx.Context(ctx[0]):
    variational = AmortizedGammaVariational(LATENT_SIZE, BATCH_SIZE)
    model = DeepLatentGammaModel()
    elbo = ELBO(model, variational)

    variational.hybridize()
    model.hybridize()
    elbo.hybridize()

    variational.initialize(mx.init.Xavier())
    model.initialize(mx.init.Xavier())

    # optimize the variational and model parameters jointly
    params = model.collect_params()
    params.update(variational.collect_params())
    trainer = gluon.Trainer(
        params, 'rmsprop', {'learning_rate': 0.00001, 'centered': True})

    def get_posterior_predictive(data, step):
      z, _ = variational(data)
      logits = model.net(z)
      probs = nd.sigmoid(logits)
      np_probs = probs.asnumpy()
      for i, prob in enumerate(np_probs):
        prob = prob.reshape((28, 28))
        scipy.misc.imsave(OUT_DIR / f'step_{step}_test_{i}.jpg', prob)

    step = 0
    t0 = time.time()
    while step < MAX_ITERATIONS:
      for data, _ in train_data:
        data = data.reshape(-1, 784)
        with mx.autograd.record():
          elbo_batch = elbo(data)
          loss = -elbo_batch
        loss.backward()
        if step % PRINT_EVERY == 0:
          get_posterior_predictive(data, step)
          np_elbo = np.mean(elbo_batch.asnumpy())
          t1 = time.time()
          speed = (t1 - t0) / PRINT_EVERY
          t0 = t1
          print(f'Iter {step}\tELBO: {np_elbo:.1f}\tspeed: {speed:.3e} s/iter')
        trainer.step(BATCH_SIZE)
        step += 1
--------------------------------------------------------------------------------
/train_normal_vae.py:
--------------------------------------------------------------------------------
import os
import pathlib
import time
import scipy.misc
from mxnet import nd
import mxnet as mx
import h5py
import numpy as np
from mxnet import gluon


class DeepLatentGaussianModel(gluon.HybridBlock):
  def __init__(self):
    super().__init__()
    with self.name_scope():
      self.log_prior = GaussianLogProb()
      # generative network parameterizes the likelihood
      self.net = gluon.nn.HybridSequential()
      with self.net.name_scope():
        self.net.add(gluon.nn.Dense(
            200, 'relu', weight_initializer=mx.init.Xavier()))
        self.net.add(gluon.nn.Dense(
            200, 'relu', weight_initializer=mx.init.Xavier()))
        self.net.add(gluon.nn.Dense(784, weight_initializer=mx.init.Xavier()))
      self.log_lik = BernoulliLogLik()

  def hybrid_forward(self, F, z, x):
    # standard Gaussian prior on the latent variables
    log_prior = self.log_prior(z, F.zeros_like(z), F.ones_like(z))
    logits = self.net(z)
    log_lik = self.log_lik(x, logits)
    return F.sum(log_lik, -1), F.sum(log_prior, -1)


class ELBO(gluon.HybridBlock):
  def __init__(self, model, variational):
    super().__init__()
    with self.name_scope():
      self.variational = variational
      self.model = model
      self.kl = GaussianKL()

  def hybrid_forward(self, F, x):
    z, mu_z, sigma_z, log_q_z = self.variational(x)
    log_lik, log_prior = self.model(z, x)
    # single-sample alternative: return log_lik + log_prior - log_q_z
    # instead, use the analytic KL between q(z | x) and the standard normal prior
    kl = F.sum(self.kl(mu_z, sigma_z), -1)
    return log_lik - kl


class GaussianKL(gluon.HybridBlock):
  """KL divergence KL(N(mu, sigma^2) || N(0, 1)), elementwise."""

  def __init__(self):
    super().__init__()

  def hybrid_forward(self, F, mu, sigma):
    return -F.log(sigma) + (F.square(sigma) + F.square(mu)) / 2. - 0.5


class BernoulliLogLik(gluon.HybridBlock):
  """Calculate the log-probability of a Bernoulli parameterized by logits."""

  def __init__(self):
    super().__init__()

  def hybrid_forward(self, F, x, logits):
    """With logits z, the Bernoulli log-probability is
      x * log sigmoid(z) + (1 - x) * log(1 - sigmoid(z))
      = x * z - log(1 + exp(z))
      = x * z - max(0, z) - log(1 + exp(-|z|))
    The last line is the numerically stable form: it never exponentiates a
    large positive number, so it avoids overflow.
    """
    return x * logits - F.relu(logits) - F.Activation(-F.abs(logits), 'softrelu')


class GaussianLogProb(gluon.HybridBlock):
  def __init__(self):
    super().__init__()

  def hybrid_forward(self, F, x, mean, sigma):
    """Log-density of N(mean, sigma^2) evaluated at x."""
    variance = F.square(sigma)
    return -0.5 * F.log(2. * np.pi * variance) \
        - 0.5 * F.square(x - mean) / variance


class AmortizedGaussianVariational(gluon.HybridBlock):
  def __init__(self, latent_size, batch_size):
    super().__init__()
    self.net = gluon.nn.HybridSequential()
    with self.name_scope():
      self.gaussian_log_prob = GaussianLogProb()
      init = mx.init.Xavier()
      with self.net.name_scope():
        self.net.add(gluon.nn.Dense(200, activation='relu', flatten=True,
                                    weight_initializer=init))
        self.net.add(gluon.nn.Dense(200, activation='relu',
                                    weight_initializer=init))
        self.net.add(gluon.nn.Dense(latent_size * 2,
                                    weight_initializer=init))

  def hybrid_forward(self, F, x):
    mean_sigma_arg = self.net(x)
    mu, sigma_arg = F.split(mean_sigma_arg, num_outputs=2, axis=-1)
    # softplus ensures the standard deviation is positive
    sigma = F.Activation(sigma_arg, 'softrelu')
    eps = F.sample_normal(F.zeros_like(mu), F.ones_like(mu))
    z = mu + eps * sigma
    log_prob = self.gaussian_log_prob(z, mu, sigma)
    return z, mu, sigma, F.sum(log_prob, -1)


if __name__ == '__main__':
  np.random.seed(24232)
  mx.random.seed(2423232)

  USE_GPU = False
  LATENT_SIZE = 100
  BATCH_SIZE = 64
  PRINT_EVERY = 1000
  MAX_ITERATIONS = 1000000
  OUT_DIR = pathlib.Path(os.environ['LOG']) / 'debug'
  OUT_DIR.mkdir(parents=True, exist_ok=True)

  # hdf5 file from:
  # https://github.com/altosaar/proximity_vi/blob/master/get_binary_mnist.py
  data_path = pathlib.Path(os.environ['DAT']) / 'binarized_mnist.hdf5'
  f = h5py.File(data_path, 'r')
  raw_data = f['train'][:]
  f.close()

  def get_data():
    return mx.io.NDArrayIter(
        data={'data': nd.array(raw_data)},
        label={'label': np.arange(len(raw_data))},
        batch_size=BATCH_SIZE,
        last_batch_handle='discard',
        shuffle=True)

  ctx = [mx.gpu(0)] if USE_GPU else [mx.cpu()]
  with mx.Context(ctx[0]):
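    # Build the amortized Gaussian variational approximation (encoder) and the
    # deep latent Gaussian model (decoder), then optimize both sets of
    # parameters jointly by maximizing the ELBO with RMSProp.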
    variational = AmortizedGaussianVariational(LATENT_SIZE, BATCH_SIZE)
    model = DeepLatentGaussianModel()
    elbo = ELBO(model, variational)

    variational.hybridize()
    model.hybridize()
    elbo.hybridize()

    variational.initialize(mx.init.Xavier())
    model.initialize(mx.init.Xavier())

    params = model.collect_params()
    params.update(variational.collect_params())
    trainer = gluon.Trainer(
        params, 'rmsprop', {'learning_rate': 0.001})  # optionally 'centered': True

    def get_posterior_predictive(batch, step):
      z, _, _, _ = variational(batch.data[0])
      logits = model.net(z)
      probs = nd.sigmoid(logits)
      np_probs = probs.asnumpy()
      for i, prob in enumerate(np_probs):
        prob = prob.reshape((28, 28))
        scipy.misc.imsave(OUT_DIR / f'step_{step}_test_{i}.jpg', prob)

    step = 0
    t0 = time.time()
    train_data = get_data()
    while step < MAX_ITERATIONS:
      # re-create the iterator at the start of every epoch
      if step % (train_data.num_data // BATCH_SIZE) == 0:
        train_data = get_data()
      batch = next(train_data)
      with mx.autograd.record():
        elbo_batch = elbo(batch.data[0])
      (-elbo_batch).backward()
      if step % PRINT_EVERY == 0:
        get_posterior_predictive(batch, step)
        np_elbo = np.mean(elbo_batch.asnumpy())
        t1 = time.time()
        speed = (t1 - t0) / PRINT_EVERY
        t0 = t1
        print(f'Iter {step}\tELBO: {np_elbo:.1f}\tspeed: {speed:.3e} s/iter')
      trainer.step(BATCH_SIZE)
      step += 1
--------------------------------------------------------------------------------