├── .gitignore
├── README.md
├── images
│   ├── U_dist.PNG
│   ├── data.png
│   ├── k_16
│   │   └── k_16.gif
│   ├── k_32
│   │   └── k_32.gif
│   ├── k_4
│   │   └── k_4.gif
│   └── k_8
│       └── k_8.gif
├── nf.py
├── normflow.py
├── synthetic_data.py
└── visualize.py


/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Variational Inference with Normalizing Flows
Implementation of the section 6.1 experiments from the paper [Variational Inference with Normalizing Flows](https://arxiv.org/abs/1505.05770).

This experiment visually demonstrates that **Normalizing Flows** can transform a simple
initial distribution q_0(z) into one that much better approximates a known, non-Gaussian bivariate
distribution p(z).

The distributions that we want to learn are as follows:


We want to learn the bivariate distribution of Z = (Z1, Z2). The equation for each distribution can be found in the paper.


Our **goal** is then to transform a simple distribution into the target distribution using Normalizing Flows.
We choose the standard Normal distribution as the simple distribution and transform it into the complex target distribution.

# Experiment Results
This section shows the experiment results.
1. K=4


2. K=8


3. K=16


4. K=32


# Training Criteria
The known target distributions are specified using energy functions U(z):
p(z) = \frac{1}{Z} e^{-U(z)}, where Z is the unknown partition function (normalization constant);
that is, p(z) \propto e^{-U(z)}.

## Steps
1. Generate random samples from the initial distribution z_0 ~ q_0(z) = N(z; \mu, \sigma^2 I).
   Here \mu and \sigma can either be fixed to something "reasonable" (such as the standard Normal distribution), or estimated as follows.
   Draw an auxiliary random variable \eps from the standard normal distribution
       \eps ~ N(0, I)
   and apply the linear normalizing flow transformation f(\eps) = \mu + \sigma \eps, re-parameterizing
   \sigma = e^{1/2*log_var} to ensure \sigma > 0; then jointly optimize {\mu, log_var} together
   with the other normalizing flow parameters (see below).
2. Transform the initial samples z_0 through K **Normalizing Flow** transforms, which yields the
   transformed approximate distribution q_K(z):
       log q_K(z_K) = log q_0(z_0) - sum_{k=1}^K log |det J_k|
   where J_k is the Jacobian of the k-th (invertible) normalizing flow transform.
   For example, for planar flows,
       log q_K(z_K) = log q_0(z_0) - sum_{k=1}^K log |1 + u_k^T \psi_k(z_{k-1})|
   where \psi_k(z) = h'(w_k^T z + b_k) w_k and each flow has its own parameters \lambda_k = {w_k, u_k, b_k}.
3. Jointly optimize all model parameters by minimizing the **KL divergence** between the approximate distribution q_K(z)
   and the true distribution p(z):
       loss = KL[q_K(z)||p(z)] = E_{z_K ~ q_K(z)} [log q_K(z_K) - log p(z_K)]
            = E_{z_K ~ q_K(z)} [(log q_0(z_0) - sum_k log |det J_k|) - (-U(z_K) - log(Z))]
            = E_{z_0 ~ q_0(z)} [log q_0(z_0) - sum_k log |det J_k| + U(f_K(...f_2(f_1(z_0)))) + log(Z)]
   The partition function Z depends on neither z_0 nor the model parameters, so it can be ignored during optimization:
       loss = E_{z_0 ~ q_0(z)} [log q_0(z_0) - sum_k log |det J_k| + U(z_K)]

**The expectation is approximated by Monte Carlo sampling: the mini-batch average serves as the Monte Carlo
estimate of the expectation.**


## Another understanding of the loss function
### Loss of the Normalizing Flow for the section 6.1 synthetic data
Let $p(z)$ be the true bivariate distribution.
Let $z_0 \sim q_0(z)$ be a simple distribution that we already know; we transform it with a Normalizing Flow to approximate the true distribution $p(z)$ that we want to obtain.

We use the KL divergence to measure how far the learned distribution is from the true distribution:
$$
\begin{aligned}
KL(q_K(z_K) \,\|\, p(z)) &= E_{z_K \sim q_K(z_K)}\big[\log q_K(z_K) - \log p(z_K)\big] \\
&= E_{z_K \sim q_K(z_K)}\Big[\log q_0(z_0) - \sum_k \log |\det J_k| - \big(-U(z_K) - \log Z\big)\Big] \\
&= E_{z_0 \sim q_0(z)}\Big[\log q_0(z_0) - \sum_k \log |\det J_k| + U(z_K) + \log Z\Big]
\end{aligned}
$$

# TODO
* Add IAF
  * [https://arxiv.org/pdf/1606.04934.pdf](https://arxiv.org/pdf/1606.04934.pdf)
  * [https://arxiv.org/pdf/1502.03509.pdf](https://arxiv.org/pdf/1502.03509.pdf)

# Reference
[1] Jimenez Rezende, D., Mohamed, S., "Variational Inference with Normalizing Flows",
Proceedings of the 32nd International Conference on Machine Learning, 2015.

[vae-normflow](https://github.com/16lawrencel/vae-normflow)

[Reproduce results from sec. 6.1 in "Variational inference using normalizing flows"](https://github.com/casperkaae/parmesan/issues/22)

[parmesan/parmesan/layers/flow.py](https://github.com/casperkaae/parmesan/blob/master/parmesan/layers/flow.py)

[wuaalb/nf6_1.py](https://gist.github.com/wuaalb/c5b85d0c257d44b0d98a)


--------------------------------------------------------------------------------
/images/U_dist.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixsong/NormalizingFlow/6562f872ef589295c61123c91dc1d63d264af4cf/images/U_dist.PNG
--------------------------------------------------------------------------------
/images/data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixsong/NormalizingFlow/6562f872ef589295c61123c91dc1d63d264af4cf/images/data.png
--------------------------------------------------------------------------------
/images/k_16/k_16.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixsong/NormalizingFlow/6562f872ef589295c61123c91dc1d63d264af4cf/images/k_16/k_16.gif
--------------------------------------------------------------------------------
/images/k_32/k_32.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixsong/NormalizingFlow/6562f872ef589295c61123c91dc1d63d264af4cf/images/k_32/k_32.gif
--------------------------------------------------------------------------------
/images/k_4/k_4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixsong/NormalizingFlow/6562f872ef589295c61123c91dc1d63d264af4cf/images/k_4/k_4.gif
--------------------------------------------------------------------------------
/images/k_8/k_8.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixsong/NormalizingFlow/6562f872ef589295c61123c91dc1d63d264af4cf/images/k_8/k_8.gif
--------------------------------------------------------------------------------
/nf.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

# define network structure


def inputs(input_dim, z_dim):
    x = tf.placeholder(tf.float32, [None, input_dim], 'x')
    # e: standard normal noise for the re-parameterization trick
    e = tf.placeholder(tf.float32, [None, z_dim], 'e')
    return x, e


def encoder(x, e, input_dim, hidden_dim, z_dim, K, initializer=tf.contrib.layers.xavier_initializer):
    '''
    :param x: input data, shape [batch, input_dim]
    :param e: standard normal noise, shape [batch, z_dim]
    :param input_dim: dimension of the input data
    :param hidden_dim: number of hidden units
    :param z_dim: dimension of the latent variable z
    :param K: number of normalizing flows
    :param initializer: weight initializer factory
    :return: mu, log_var, z0 and the flow parameters (us, ws, bs)
    '''
    with tf.variable_scope('encoder'):
        w_h = tf.get_variable('w_h', [input_dim, hidden_dim], initializer=initializer())
        b_h = tf.get_variable('b_h', [hidden_dim])
        w_mu = tf.get_variable('w_mu', [hidden_dim, z_dim], initializer=initializer())
        b_mu = tf.get_variable('b_mu', [z_dim])
        w_v = tf.get_variable('w_v', [hidden_dim, z_dim], initializer=initializer())
        b_v = tf.get_variable('b_v', [z_dim])

        # Weights for outputting normalizing flow parameters
        w_us = tf.get_variable('w_us', [hidden_dim, K*z_dim])
        b_us = tf.get_variable('b_us', [K*z_dim])
        w_ws = tf.get_variable('w_ws', [hidden_dim, K*z_dim])
        b_ws = tf.get_variable('b_ws', [K*z_dim])
        w_bs = tf.get_variable('w_bs', [hidden_dim, K])
        b_bs = tf.get_variable('b_bs', [K])

        # compute hidden state
        h = tf.nn.tanh(tf.matmul(x, w_h) + b_h)
        mu = tf.matmul(h, w_mu) + b_mu
        log_var = tf.matmul(h, w_v) + b_v
        # re-parameterization: z0 = mu + sigma * e
        z = mu + tf.sqrt(tf.exp(log_var)) * e

        # Normalizing Flow parameters
        us = tf.matmul(h, w_us) + b_us
        ws = tf.matmul(h, w_ws) + b_ws
        bs = tf.matmul(h, w_bs) + b_bs

        t = (us, ws, bs)

    return mu, log_var, z, t


def norm_flow(z, lambd, K, Z):
    '''
    Apply K planar flows, Eqn. (10)-(12) of the paper.
    Note: u is not constrained as in Appendix A.1, so invertibility is not guaranteed.
    :return: z_K and the per-sample sum of log |det J_k|, shape [batch]
    '''
    us, ws, bs = lambd

    log_detjs = []
    for k in range(K):
        u, w, b = us[:, k*Z:(k+1)*Z], ws[:, k*Z:(k+1)*Z], bs[:, k]
        # pre-activation w^T z + b, evaluated at the flow input z_{k-1}
        wzb = tf.reduce_sum(w*z, 1) + b

        # Eqn. (11) and (12): log |det J_k| = log |1 + u^T psi(z_{k-1})|, psi(z) = h'(w^T z + b) w
        psi = tf.expand_dims(dtanh(wzb), 1) * w
        log_detj = tf.log(tf.abs(1. + tf.reduce_sum(u * psi, 1)) + 1e-8)
        log_detjs.append(log_detj)

        # Eqn. (10): f(z) = z + u h(w^T z + b)
        z = z + u * tf.expand_dims(tf.nn.tanh(wzb), 1)

    if K != 0:
        log_detj = tf.add_n(log_detjs)
    else:
        log_detj = tf.zeros_like(z[:, 0])

    return z, log_detj


def dtanh(input):
    return 1.0 - tf.square(tf.tanh(input))


def decoder(z, D, H, Z, initializer=tf.contrib.layers.xavier_initializer, out_fn=tf.sigmoid):
    with tf.variable_scope('decoder'):
        w_h = tf.get_variable('w_h', [Z, H], initializer=initializer())
        b_h = tf.get_variable('b_h', [H])
        w_mu = tf.get_variable('w_mu', [H, D], initializer=initializer())
        b_mu = tf.get_variable('b_mu', [D])
        w_v = tf.get_variable('w_v', [H, 1], initializer=initializer())
        b_v = tf.get_variable('b_v', [1])

        h = tf.nn.tanh(tf.matmul(z, w_h) + b_h)
        out_mu = tf.matmul(h, w_mu) + b_mu
        out_log_var = tf.matmul(h, w_v) + b_v
        out = out_fn(out_mu)

    return out, out_mu, out_log_var


def make_loss(pred, actual, log_var, mu, log_detj, sigma=1.0):
    # closed-form KL(q_0(z|x) || N(0, I)), averaged over the batch
    kl = -tf.reduce_mean(0.5*tf.reduce_sum(1.0 + log_var - tf.square(mu) - tf.exp(log_var), 1))
    # reconstruction loss: negative log-likelihood of a Gaussian with fixed std `sigma`
    # (up to a constant), i.e. a scaled L2 loss, averaged over the batch
    rec_err = tf.reduce_mean(0.5*tf.reduce_sum(tf.square(actual - pred), 1) / sigma**2)
    # negative ELBO: KL + reconstruction error - sum_k log |det J_k|
    # (the closed-form KL uses log p(z_0) in place of log p(z_K), an approximation)
    loss = kl + rec_err - tf.reduce_mean(log_detj)
    return loss


def train_step(sess, input_data, train_op, loss_op, x_op, e_op, Z):
    e_ = np.random.normal(size=(input_data.shape[0], Z))
    _, l = sess.run([train_op, loss_op], feed_dict={x_op: input_data, e_op: e_})
    return l


def reconstruct(sess, data, out_op, x_op, e_op, Z):
    e_ = np.random.normal(size=(data.shape[0], Z))
    x_rec = sess.run(out_op, feed_dict={x_op: data, e_op: e_})
    return x_rec


def show_reconstruction(actual, recon):
    fig, axs = plt.subplots(1, 2)
    axs[0].imshow(actual.reshape(28, 28), cmap='gray')
    axs[1].imshow(recon.reshape(28, 28), cmap='gray')
    axs[0].set_title('actual')
    axs[1].set_title('reconstructed')
    plt.show()


def sample_latent(sess, input_data, z_op, x_op, e_op, Z):
    e_ = np.random.normal(size=(input_data.shape[0], Z))
    zs = sess.run(z_op, feed_dict={x_op: input_data, e_op: e_})
    return zs


if __name__ == '__main__':
    N = 1000
    xs = np.vstack((
        np.random.uniform(-6, -2, size=(N//3, 2)),
        np.random.multivariate_normal([0, 0], np.eye(2) / 2, size=N//3),
        np.random.multivariate_normal([5, -5], np.eye(2) / 2, size=N//3)
    ))
    ys = np.repeat(np.arange(3), N // 3)

    # shuffle the data (a permutation, i.e. no repeated points)
    idxs = np.random.permutation(xs.shape[0])
    xs, ys = xs[idxs], ys[idxs]

    plt.scatter(xs[:, 0], xs[:, 1], c=ys)
    plt.title('original data')
    plt.show()

    tf.reset_default_graph()
    data = xs
    data_dim = xs.shape[1]

    enc_h = 128
    enc_z = 2
    dec_h = 128
    max_iters = 10000
    batch_size = data.shape[0]
    learning_rate = 0.001
    k = 3

    x, e = inputs(data_dim, enc_z)
    mu, log_var, z0, lambd = encoder(x, e, data_dim, enc_h, enc_z, k)
    z_k, log_detj = norm_flow(z0, lambd, k, enc_z)
    out_op, out_mu, out_log_var = decoder(z_k, data_dim, dec_h, enc_z, out_fn=tf.identity)
    loss_op = make_loss(out_op, x, log_var, mu, log_detj)
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss_op)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    idx = 0
    for i in range(max_iters):
        x_ = data[idx:idx + batch_size]
        l = train_step(sess, x_, train_op, loss_op, x, e, enc_z)
        idx += batch_size
        if idx >= data.shape[0]:
            idx = 0
        if i % 1000 == 0:
            print('iter: %d\tloss: %.2f' % (i, l))

    zs = sample_latent(sess, xs, z_k, x, e, enc_z)
    fig = plt.figure(figsize=(8, 6))
    plt.scatter(zs[:, 0], zs[:, 1], c=ys)
    plt.title('latent z values for each point in the dataset')
    plt.show()

    n_rep = 500

    # Take a data point from each class, replicate it n_rep times
    x_ = np.repeat(data[[(ys == i).argmax() for i in range(3)]], n_rep, axis=0)
    y_ = np.repeat(np.arange(3), n_rep)
    e_ = np.random.normal(size=(x_.shape[0], enc_z))
    zs = sess.run(z_k, feed_dict={x: x_, e: e_})

    fig = plt.figure(figsize=(8, 6))
    plt.scatter(zs[:, 0], zs[:, 1], c=y_)
    plt.title("Posterior samples")
    plt.show()

    reconstructed = reconstruct(sess, xs, out_op, x, e, enc_z)
    plt.scatter(reconstructed[:, 0], reconstructed[:, 1])
    plt.title('reconstructed data')
    plt.show()


--------------------------------------------------------------------------------
/normflow.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import numpy as np
import tensorflow as tf
import synthetic_data
import visualize
import matplotlib.pyplot as plt

plt.ioff()


class PlanarFlow(object):
    """
    Planar normalizing flow
    equation 10-12, 21-23 in paper https://arxiv.org/pdf/1505.05770.pdf
    """

    def __init__(self, z_dim=2, var_scope='planarflow'):
        self.z_dim = z_dim
        self.h = tf.tanh
        self.var_scope = var_scope

        with tf.variable_scope(var_scope):
            initializer = tf.contrib.layers.xavier_initializer_conv2d()
            self.u = tf.get_variable('u', initializer=initializer(shape=(z_dim, 1)))
            self.w = tf.get_variable('w', initializer=initializer(shape=(z_dim, 1)))
            self.b = tf.get_variable('b', initializer=initializer(shape=(1, 1)))

    def __call__(self, z, logp, name='flow'):
        """
        :param z: B*z_dim
        :param logp: B, log-density of z before this flow
        :param name: name scope of the flow
        :return: transformed z and its log-density
        """
        with tf.name_scope(name):
            a = self.h(tf.matmul(z, self.w) + self.b)
            psi = tf.matmul(1 - a ** 2, tf.transpose(self.w))

            # Section A.1: constrain u so that w^T u_h > -1, which makes the transformation invertible
            # and keeps 1 + u_h^T psi(z) positive
            x = tf.matmul(tf.transpose(self.w), self.u)
            m = -1 + tf.nn.softplus(x)
            u_h = self.u + (m - x) * self.w / (tf.matmul(tf.transpose(self.w), self.w))

            # log q_k = log q_{k-1} - log(1 + u_h^T psi(z)); squeeze only axis 1 so a batch of size 1 keeps its shape
            logp = logp - tf.squeeze(tf.log(1 + tf.matmul(psi, u_h)), axis=1)
            z = z + tf.matmul(a, tf.transpose(u_h))

        return z, logp


class NormalizingFlow(object):
    """
    Normalizing flow: a stack of K planar flows
    """
    def __init__(self, z_dim, K=3, name='normalizingflow'):
        self.z_dim = z_dim
        self.K = K
        self.planar_flows = []
        with tf.variable_scope(name):
            for i in range(K):
                flow = PlanarFlow(z_dim, var_scope='planarflow_' + str(i+1))
                self.planar_flows.append(flow)

    def __call__(self, z, logp, name='normflow'):
        with tf.name_scope(name):
            for flow in self.planar_flows:
                z, logp = flow(z, logp)

        return z, logp


def build_network(input_z0_placeholder, log_q0_placehoder, K=32, z_dim=2, name='func_U'):
    with tf.variable_scope(name):
        normFlow = NormalizingFlow(z_dim=z_dim, K=K)
        zk, logqk = normFlow(input_z0_placeholder, log_q0_placehoder)
    return zk, logqk


def compute_loss(U_func, log_qk, z_k):
    """
    Monte Carlo estimate of KL(q_K || p) up to the constant log Z:
    loss = E[log q_K(z_K) + U(z_K)], where log q_K(z_K) = log q_0(z_0) - sum_k log |det J_k|
    """
    U_z = U_func(z_k)
    U_z = tf.clip_by_value(U_z, -10000, 10000)
    kld = log_qk + U_z
    kld = tf.reduce_mean(kld)
    return kld


def save(saver, sess, logdir, step, write_meta=False):
    model_name = 'model.ckpt'
    checkpoint_path = os.path.join(logdir, model_name)
    print('Storing checkpoint to {} ...'.format(logdir))

    # create the log directory if it does not exist
    if not tf.gfile.Exists(logdir):
        tf.gfile.MakeDirs(logdir)

    saver.save(sess, checkpoint_path, global_step=step, write_meta_graph=write_meta)
    print('Save Model Done.')


def save_image(sess, zk_arr, logqk_arr, input_z0_placeholder, log_q0_placehoder, sampler, path):
    fig, axes = plt.subplots(2, 2)
    axes = axes.flatten()

    for u_idx, (zk, logqk) in enumerate(zip(zk_arr, logqk_arr)):
        ax = axes[u_idx]

        side = np.linspace(-5, 5, 500)
        X, Y = np.meshgrid(side, side)
        counts = np.zeros(X.shape)
        p = np.zeros(X.shape)

        size = [-5, 5]
        num_side = 500

        L = 100
        print("Sampling", end='')
        for i in range(1000):
            z, logq = sampler(L)
            z_k, logq_k = sess.run([zk, logqk], feed_dict={input_z0_placeholder: z, log_q0_placehoder: logq})
            # skip batches containing NaN
            if np.any(np.isnan(z_k)):
                print("NaN detected")
                continue

            # map samples to grid cells and accumulate q_K per cell
            q_k = np.exp(logq_k)
            z_k = (z_k - size[0]) * num_side / (size[1] - size[0])
            for l in range(L):
                x, y = int(z_k[l, 1]), int(z_k[l, 0])
                if 0 <= x < num_side and 0 <= y < num_side:
                    counts[x, y] += 1
                    p[x, y] += q_k[l]

        # average q_K per cell, then normalize over the grid
        counts = np.maximum(counts, np.ones(counts.shape))
        p /= counts
        p /= np.sum(p)
        Y = -Y
        ax.pcolormesh(X, Y, p)

    fig.tight_layout()
    plt.savefig(path)
    plt.close()


if __name__ == '__main__':
    # plot the target densities; visualize.plot_density() saves the figure to ./data.png
    print("Plotting the synthetic target densities to ./data.png")
    visualize.plot_density()

    K = 32
    z_dim = 2
    L = 256
    steps = 4000000
    is_training = True
    learning_rate = 0.001
    save_model_every_steps = 1000
    print_loss_every_steps = 100
    logdir = './log/'
    logdir = os.path.join(logdir, 'K=' + str(K))
    logdir_image = os.path.join(logdir, 'images')
    checkpoint = r'model.ckpt-3980000'

    if not tf.gfile.Exists(logdir_image):
        tf.gfile.MakeDirs(logdir_image)

    U1 = getattr(synthetic_data, 'U1_tf')
    U2 = getattr(synthetic_data, 'U2_tf')
    U3 = getattr(synthetic_data, 'U3_tf')
    U4 = getattr(synthetic_data, 'U4_tf')
    U_arr = [U1, U2, U3, U4]
    input_z0_placeholder = tf.placeholder(tf.float32, [None, 2])
    log_q0_placehoder = tf.placeholder(tf.float32, [None])

    zk_arr = []
    logqk_arr = []
    loss_arr = []
    train_op_arr = []
    for i, U in enumerate(U_arr):
        zk, logqk = build_network(input_z0_placeholder, log_q0_placehoder, K=K, z_dim=z_dim, name="dist/" + U.__name__)
        loss = compute_loss(U, logqk, zk)
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)
        zk_arr.append(zk)
        logqk_arr.append(logqk)
        loss_arr.append(loss)
        train_op_arr.append(train_op)

    sess = tf.InteractiveSession()
    init = tf.global_variables_initializer()
    sess.run(init)

    saver = tf.train.Saver(var_list=tf.trainable_variables())

    if not is_training:
        # restore the model from a checkpoint written by save() into logdir
        saver.restore(sess, os.path.join(logdir, checkpoint))
        print('Model restored successfully!')

    sampler = synthetic_data.normal_sampler()
    if is_training:
        for step in range(steps):
            z0, log_q0 = sampler(L)
            for i, U in enumerate(U_arr):
                loss = loss_arr[i]
                train_op = train_op_arr[i]
                zk = zk_arr[i]
                logqk = logqk_arr[i]

                l, _ = sess.run([loss, train_op], feed_dict={input_z0_placeholder: z0, log_q0_placehoder: log_q0})
                if step % print_loss_every_steps == 0:
                    print("Training {}, step {}, loss={}".format(U.__name__, step, l))

            if step % save_model_every_steps == 0:
                save(saver, sess, logdir, step, write_meta=False)
                path = os.path.join(logdir_image, str(step) + '.png')
                save_image(sess, zk_arr, logqk_arr, input_z0_placeholder, log_q0_placehoder, sampler, path)

        save(saver, sess, logdir, steps, write_meta=False)

    print("done!")
    path = os.path.join(logdir_image, 'final.png')
    save_image(sess, zk_arr, logqk_arr, input_z0_placeholder, log_q0_placehoder, sampler, path)


--------------------------------------------------------------------------------
/synthetic_data.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf

# synthetic data in table 1 of paper: https://arxiv.org/abs/1505.05770


def w1_tf(z):
    return tf.sin(2 * np.pi * z[:, 0] / 4)


def w1(z):
    return np.sin(2 * np.pi * z[:, 0] / 4)


def w2_tf(z):
    return 3 * tf.exp(-0.5 * ((z[:, 0] - 1) / 0.6) ** 2)


def w2(z):
    return 3 * np.exp(-0.5 * ((z[:, 0] - 1) / 0.6) ** 2)


def w3_tf(z):
    return 3 * tf.sigmoid((z[:, 0] - 1) / 0.3)


def w3(z):
    return 3 * (1.0 / (1 + np.exp(-(z[:, 0] - 1) / 0.3)))


def U1_tf(z):
    z_norm = tf.norm(z, 2, 1)
    add1 = 0.5 * ((z_norm - 2) / 0.4) ** 2
    add2 = -tf.log(tf.exp(-0.5 * ((z[:, 0] - 2) / 0.6) ** 2) + tf.exp(-0.5 * ((z[:, 0] + 2) / 0.6) ** 2) + 1e-9)
    return add1 + add2


def U1(z):
    add1 = 0.5 * ((np.linalg.norm(z, 2, 1) - 2) / 0.4) ** 2
    # small epsilon avoids log(0), matching U1_tf
    add2 = -np.log(np.exp(-0.5 * ((z[:, 0] - 2) / 0.6) ** 2) + np.exp(-0.5 * ((z[:, 0] + 2) / 0.6) ** 2) + 1e-9)
    return add1 + add2


def U2(z):
    return 0.5 * ((z[:, 1] - w1(z)) / 0.4) ** 2


def U2_tf(z):
    return 0.5 * ((z[:, 1] - w1_tf(z)) / 0.4) ** 2


def U3(z):
    in1 = np.exp(-0.5 * ((z[:, 1] - w1(z)) / 0.35) ** 2)
    in2 = np.exp(-0.5 * ((z[:, 1] - w1(z) + w2(z)) / 0.35) ** 2)
    return -np.log(in1 + in2 + 1e-9)


def U3_tf(z):
    in1 = tf.exp(-0.5 * ((z[:, 1] - w1_tf(z)) / 0.35) ** 2)
    in2 = tf.exp(-0.5 * ((z[:, 1] - w1_tf(z) + w2_tf(z)) / 0.35) ** 2)
    return -tf.log(in1 + in2 + 1e-9)


def U4(z):
    in1 = np.exp(-0.5 * ((z[:, 1] - w1(z)) / 0.4) ** 2)
    in2 = np.exp(-0.5 * ((z[:, 1] - w1(z) + w3(z)) / 0.35) ** 2)
    # small epsilon avoids log(0), matching U4_tf
    return -np.log(in1 + in2 + 1e-9)


def U4_tf(z):
    in1 = tf.exp(-0.5 * ((z[:, 1] - w1_tf(z)) / 0.4) ** 2)
    in2 = tf.exp(-0.5 * ((z[:, 1] - w1_tf(z) + w3_tf(z)) / 0.35) ** 2)
    return -tf.log(in1 + in2 + 1e-9)


def normal_sampler(mean=np.zeros(2), sigma=np.ones(2)):
    dim = mean.shape[0]

    def sampler(N):
        # draw N samples from N(mean, diag(sigma^2)) and return their log-densities
        z = mean + np.random.randn(N, dim) * sigma
        logq = -0.5 * np.sum(2 * np.log(sigma) + np.log(2 * np.pi) + ((z - mean) / sigma) ** 2, 1)
        return z, logq

    return sampler


--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import synthetic_data


def compute_density(U_func, Z):
    # unnormalized density exp(-U(z)), normalized over the grid points
    neg_logp = U_func(Z)
    p = np.exp(-neg_logp)
    p /= np.sum(p)
    return p


def plot_density():
    fig, axes = plt.subplots(2, 2)
    U_list = [synthetic_data.U1, synthetic_data.U2, synthetic_data.U3, synthetic_data.U4]

    space = np.linspace(-5, 5, 500)
    X, Y = np.meshgrid(space, space)
    shape = X.shape
    X_flatten, Y_flatten = np.reshape(X, (-1, 1)), np.reshape(Y, (-1, 1))
    Z = np.concatenate([X_flatten, Y_flatten], 1)

    # ISSUE
    Y = -Y  # not sure why, but my plots are upside down compared to paper

    for U, ax in zip(U_list, axes.flatten()):
        density = compute_density(U, Z)
        density = np.reshape(density, shape)
        ax.pcolormesh(X, Y, density)
        ax.set_title(U.__name__)
        ax.axis('off')

    fig.tight_layout()
    plt.savefig('./data.png')


if __name__ == '__main__':
    plot_density()

--------------------------------------------------------------------------------