├── .gitignore
├── README.md
├── images
│   ├── U_dist.PNG
│   ├── data.png
│   ├── k_16
│   │   └── k_16.gif
│   ├── k_32
│   │   └── k_32.gif
│   ├── k_4
│   │   └── k_4.gif
│   └── k_8
│       └── k_8.gif
├── nf.py
├── normflow.py
├── synthetic_data.py
└── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Variational Inference with Normalizing Flows
2 | Implementation of the Section 6.1 experiments from the paper [Variational Inference with Normalizing Flows](https://arxiv.org/abs/1505.05770).
3 |
4 | This experiment visually demonstrates that **Normalizing Flows** can successfully transform a simple
5 | initial distribution q_0(z) into a much better approximation of a known non-Gaussian bivariate
6 | distribution p(z).
7 |
8 | The distributions that we want to learn are the following:
9 |
10 |
11 | We want to learn the bivariate distribution of z = (z1, z2). The equation for each distribution can be found in the paper.
12 |
13 |
14 | Our **goal** is then to transform some simple distribution into the target distribution with Normalizing Flows.
15 | We choose the standard Normal distribution as the simple starting distribution and transform it into the complex target distribution.
16 |
17 | # Experiment Results
18 | This section shows the results for different numbers of flow steps K.
19 | 1. K=4
20 | ![K=4](images/k_4/k_4.gif)
21 |
22 | 2. K=8
23 | ![K=8](images/k_8/k_8.gif)
24 |
25 | 3. K=16
26 | ![K=16](images/k_16/k_16.gif)
27 |
28 | 4. K=32
29 | ![K=32](images/k_32/k_32.gif)
30 |
31 | # Training Criteria
32 | The known target distributions are specified by energy functions U(z):
33 | p(z) = \frac{1}{Z} e^{-U(z)}, where Z is the unknown partition function (normalization constant);
34 | that is, p(z) \propto e^{-U(z)}.
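For example, here is a minimal NumPy sketch of one of these energies (U_2 from Table 1 of the paper, matching `U2` in `synthetic_data.py`) and the unnormalized density it induces; the grid bounds and resolution are illustrative choices, not values used by the repository:

```python
import numpy as np

def U2(z):
    # Table 1 of the paper: a sine-shaped ridge; z has shape (N, 2)
    w1 = np.sin(2 * np.pi * z[:, 0] / 4)
    return 0.5 * ((z[:, 1] - w1) / 0.4) ** 2

# Unnormalized density exp(-U(z)) on a grid; the partition function Z is never needed.
side = np.linspace(-4, 4, 200)
X, Y = np.meshgrid(side, side)
grid = np.stack([X.ravel(), Y.ravel()], axis=1)
p_unnorm = np.exp(-U2(grid)).reshape(X.shape)
```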
35 |
36 | ## Steps
37 | 1. Generate random samples from the initial distribution z_0 ~ q_0(z) = N(z; \mu, \sigma^2 I).
38 | Here \mu and \sigma can either be fixed to something "reasonable" (such as the standard Normal distribution), or estimated as follows.
39 | Draw an auxiliary random variable \eps from the standard Normal distribution
40 | \eps ~ N(0, I)
41 | and apply the linear normalizing flow transformation f(\eps) = \mu + \sigma \eps, re-parameterizing
42 | \sigma = e^{1/2 * log_var} to ensure \sigma > 0, then jointly optimize {\mu, log_var} together
43 | with the other normalizing flow parameters (see below).
44 | 2. Transform the initial samples z_0 through K **Normalizing Flow** transformations to obtain the
45 | transformed approximate distribution q_K(z),
46 | log q_K(z) = log q_0(z) - sum_{k=1}^K log |det J_k|
47 | where J_k is the Jacobian of the k-th (invertible) normalizing flow transformation.
48 | E.g. for planar flows,
49 | log q_K(z) = log q_0(z) - sum_{k=1}^K log |1 + u_k^T \psi_k(z_{k-1})|
50 | where each flow has parameters \lambda_k = {w_k, u_k, b_k}.
51 | 3. Jointly optimize all model parameters by minimizing the **KL-Divergence** between the approximate distribution q_K(z)
52 | and the true distribution p(z). With z_K = f_K(...f_2(f_1(z_0))...),
53 | loss = KL[q_K(z)||p(z)] = E_{z_K ~ q_K(z)} [log q_K(z_K) - log p(z_K)]
54 | = E_{z_K ~ q_K(z)} [(log q_0(z_0) - sum_k log |det J_k|) - (-U(z_K) - log(Z))]
55 | = E_{z_0 ~ q_0(z)} [log q_0(z_0) - sum_k log |det J_k| + U(z_K) + log(Z)]
56 | Since the partition function Z is independent of z_0 and of the model parameters, it can be dropped from the optimization:
57 | loss = E_{z_0 ~ q_0(z)} [log q_0(z_0) - sum_k log |det J_k| + U(z_K)]
58 |
59 | **The expectation can be approximated by Monte Carlo sampling; the mini-batch average serves as the
60 | Monte Carlo estimate of this expectation (see the sketch below).**
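Below is a minimal NumPy sketch of this mini-batch estimate for planar flows with given parameters. The helper names (`sample_z0`, `planar_step`, `mc_loss`) are illustrative and do not appear in the repository:

```python
import numpy as np

def sample_z0(mu, log_var, L):
    """Step 1: re-parameterized samples z0 = mu + sigma * eps with their log q_0(z0)."""
    eps = np.random.randn(L, mu.shape[0])
    sigma = np.exp(0.5 * log_var)
    z0 = mu + sigma * eps
    log_q0 = -0.5 * np.sum(np.log(2 * np.pi) + log_var + eps ** 2, axis=1)
    return z0, log_q0

def planar_step(z, u, w, b):
    """Step 2: one planar flow f(z) = z + u * tanh(w^T z + b); returns f(z) and log|det J|."""
    a = np.tanh(z @ w + b)                      # shape (L, 1)
    psi = (1.0 - a ** 2) * w.T                  # psi(z) = h'(w^T z + b) * w^T, shape (L, d)
    log_det = np.log(np.abs(1.0 + psi @ u) + 1e-9).ravel()
    return z + a @ u.T, log_det

def mc_loss(z0, log_q0, flows, U):
    """Step 3: mini-batch average of log q_0(z0) - sum_k log|det J_k| + U(z_K)."""
    z, log_q = z0, log_q0
    for (u, w, b) in flows:
        z, log_det = planar_step(z, u, w, b)
        log_q = log_q - log_det
    return np.mean(log_q + U(z))
```

Here `flows` would be a list of K parameter triples `(u_k, w_k, b_k)` with `u_k`, `w_k` of shape `(2, 1)`, and `U` one of the energy functions in `synthetic_data.py`; in the actual experiment these parameters are TensorFlow variables trained by gradient descent on this loss.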
61 |
62 |
63 | ## Another understanding of loss function
64 | ### Loss of the normalizing flow for the Section 6.1 synthetic data
65 | Let $$p(z)$$ be the true bivariate target distribution,
66 | and let $$z_0 \sim q_0(z)$$ be a simple distribution that we already know. We then transform this simple distribution with a Normalizing Flow to approximate the true distribution $$p(z)$$.
67 |
68 | We use the KL divergence to measure the distance between our learned distribution and the true distribution.
69 | $$
70 | KL(q_K(z_K)\|p(z)) \\
71 | = E_{z_K \sim q_K}[\log q_K(z_K) - \log p(z_K)] \\
72 | = E_{z_K \sim q_K}[\log q_0(z_0) - \sum_k \log|\det J_k| - (-U(z_K) - \log Z)] \\
73 | = E_{z_0 \sim q_0}[\log q_0(z_0) - \sum_k \log|\det J_k| + U(z_K) + \log Z]
74 | $$
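In code this objective reduces to a one-line mini-batch average. The following is a minimal TF1-style sketch of the same quantity, mirroring the `compute_loss` function in `normflow.py` (the clipping bound is the repository's choice; the function name here is illustrative):

```python
import tensorflow as tf

def flow_kl_loss(U_func, log_qk, z_k):
    """Monte Carlo estimate of KL[q_K || p] up to the constant log Z."""
    u_z = tf.clip_by_value(U_func(z_k), -10000, 10000)  # guard against extreme energies
    return tf.reduce_mean(log_qk + u_z)                  # E[log q_K(z_K) + U(z_K)]
```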
75 |
76 | # TODO
77 | * Add IAF
78 | * [https://arxiv.org/pdf/1606.04934.pdf](https://arxiv.org/pdf/1606.04934.pdf)
79 | * [https://arxiv.org/pdf/1502.03509.pdf](https://arxiv.org/pdf/1502.03509.pdf)
80 |
81 | # Reference
82 | [1] Jimenez Rezende, D., Mohamed, S., "Variational Inference with Normalizing Flows",
83 | Proceedings of the 32nd International Conference on Machine Learning, 2015.
84 |
85 | [vae-normflow](https://github.com/16lawrencel/vae-normflow)
86 |
87 | [Reproduce results from sec. 6.1 in "Variational inference using normalizing flows" ](https://github.com/casperkaae/parmesan/issues/22)
88 |
89 | [parmesan/parmesan/layers/flow.py](https://github.com/casperkaae/parmesan/blob/master/parmesan/layers/flow.py)
90 |
91 | [wuaalb/nf6_1.py](https://gist.github.com/wuaalb/c5b85d0c257d44b0d98a)
92 |
--------------------------------------------------------------------------------
/images/U_dist.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixsong/NormalizingFlow/6562f872ef589295c61123c91dc1d63d264af4cf/images/U_dist.PNG
--------------------------------------------------------------------------------
/images/data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixsong/NormalizingFlow/6562f872ef589295c61123c91dc1d63d264af4cf/images/data.png
--------------------------------------------------------------------------------
/images/k_16/k_16.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixsong/NormalizingFlow/6562f872ef589295c61123c91dc1d63d264af4cf/images/k_16/k_16.gif
--------------------------------------------------------------------------------
/images/k_32/k_32.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixsong/NormalizingFlow/6562f872ef589295c61123c91dc1d63d264af4cf/images/k_32/k_32.gif
--------------------------------------------------------------------------------
/images/k_4/k_4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixsong/NormalizingFlow/6562f872ef589295c61123c91dc1d63d264af4cf/images/k_4/k_4.gif
--------------------------------------------------------------------------------
/images/k_8/k_8.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weixsong/NormalizingFlow/6562f872ef589295c61123c91dc1d63d264af4cf/images/k_8/k_8.gif
--------------------------------------------------------------------------------
/nf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import tensorflow as tf
7 | from tensorflow.examples.tutorials.mnist import input_data
8 |
9 | # define network structure
10 |
11 |
12 | def inputs(input_dim, z_dim):
13 |     x = tf.placeholder(tf.float32, [None, input_dim], 'x')
14 |     e = tf.placeholder(tf.float32, [None, z_dim], 'e')
15 |     return x, e
16 |
17 |
18 | def encoder(x, e, input_dim, hidden_dim, z_dim, K, initializer=tf.contrib.layers.xavier_initializer):
19 |     '''
20 |     :param x: input batch, shape [batch, input_dim]
21 |     :param e: auxiliary noise for the re-parameterization trick, shape [batch, z_dim]
22 |     :param input_dim: dimensionality of the input
23 |     :param hidden_dim: number of hidden units
24 |     :param z_dim: dimensionality of the latent variable
25 |     :param K: number of normalizing flow steps
26 |     :param initializer: weight initializer
27 |     :return: mu, log_var, sampled z0 and the flow parameters (us, ws, bs)
28 |     '''
29 | with tf.variable_scope('encoder'):
30 | w_h = tf.get_variable('w_h', [input_dim, hidden_dim], initializer=initializer())
31 | b_h = tf.get_variable('b_h', [hidden_dim])
32 | w_mu = tf.get_variable('w_mu', [hidden_dim, z_dim], initializer=initializer())
33 | b_mu = tf.get_variable('b_mu', [z_dim])
34 | w_v = tf.get_variable('w_v', [hidden_dim, z_dim], initializer=initializer())
35 | b_v = tf.get_variable('b_v', [z_dim])
36 |
37 | # Weights for outputting normalizing flow parameters
38 | w_us = tf.get_variable('w_us', [hidden_dim, K*z_dim])
39 | b_us = tf.get_variable('b_us', [K*z_dim])
40 | w_ws = tf.get_variable('w_ws', [hidden_dim, K*z_dim])
41 | b_ws = tf.get_variable('b_ws', [K*z_dim])
42 | w_bs = tf.get_variable('w_bs', [hidden_dim, K])
43 | b_bs = tf.get_variable('b_bs', [K])
44 |
45 | # compute hidden state
46 | h = tf.nn.tanh(tf.matmul(x, w_h) + b_h)
47 | mu = tf.matmul(h, w_mu) + b_mu
48 | log_var = tf.matmul(h, w_v) + b_v
49 | # re-parameterization
50 | z = mu + tf.sqrt(tf.exp(log_var)) * e
51 |
52 | # Normalizing Flow parameters
53 | us = tf.matmul(h, w_us) + b_us
54 | ws = tf.matmul(h, w_ws) + b_ws
55 | bs = tf.matmul(h, w_bs) + b_bs
56 |
57 | t = (us, ws, bs)
58 |
59 | return mu, log_var, z, t
60 |
61 |
62 | def norm_flow(z, lambd, K, Z):
63 |     us, ws, bs = lambd
64 |
65 |     log_detjs = []
66 |     for k in range(K):
67 |         u, w, b = us[:, k*Z:(k+1)*Z], ws[:, k*Z:(k+1)*Z], bs[:, k]
68 |         # Eqn. (11) and (12): the log-det term uses psi evaluated at the input z of this step
69 |         temp = tf.expand_dims(dtanh(tf.reduce_sum(w*z, 1) + b), 1)
70 |         temp = tf.tile(temp, [1, w.get_shape()[1].value])
71 |         log_detj = tf.log(tf.abs(1. + tf.reduce_sum(tf.multiply(u, temp*w), 1)) + 1e-10)
72 |         log_detjs.append(log_detj)
73 |
74 |         # Eqn. (10): f(z) = z + u * h(w^T z + b)
75 |         temp = tf.expand_dims(tf.nn.tanh(tf.reduce_sum(w*z, 1) + b), 1)
76 |         temp = tf.tile(temp, [1, u.get_shape()[1].value])
77 |         z = z + tf.multiply(u, temp)
78 |
79 |     if K != 0:
80 |         log_detj = tf.add_n(log_detjs)  # sum of log|det J_k| over the K flows, per sample
81 |     else:
82 |         log_detj = 0.
83 |     return z, log_detj
84 |
85 |
86 | def dtanh(input):
87 | return 1.0 - tf.square(tf.tanh(input))
88 |
89 |
90 | def decoder(z, D, H, Z, initializer=tf.contrib.layers.xavier_initializer, out_fn=tf.sigmoid):
91 | with tf.variable_scope('decoder'):
92 | w_h = tf.get_variable('w_h', [Z, H], initializer=initializer())
93 | b_h = tf.get_variable('b_h', [H])
94 | w_mu = tf.get_variable('w_mu', [H, D], initializer=initializer())
95 | b_mu = tf.get_variable('b_mu', [D])
96 | w_v = tf.get_variable('w_v', [H, 1], initializer=initializer())
97 | b_v = tf.get_variable('b_v', [1])
98 |
99 | h = tf.nn.tanh(tf.matmul(z, w_h) + b_h)
100 | out_mu = tf.matmul(h, w_mu) + b_mu
101 | out_log_var = tf.matmul(h, w_v) + b_v
102 | out = out_fn(out_mu)
103 |
104 | return out, out_mu, out_log_var
105 |
106 |
107 | def make_loss(pred, actual, log_var, mu, log_detj, sigma=1.0):
108 |     # analytic KL(q_0(z|x) || N(0, I))
109 |     kl = -tf.reduce_mean(0.5*tf.reduce_sum(1.0 + log_var - tf.square(mu) - tf.exp(log_var), 1))
110 |     # reconstruction loss
111 |     # TODO: the reconstruction loss should be the negative log-likelihood of the decoder
112 |     # distribution (e.g. Bernoulli for binary data); here only an L2 loss is used
113 |
114 |     rec_err = 0.5*(tf.nn.l2_loss(actual - pred)) / sigma
115 |     # negative ELBO: kl and rec_err are already penalty terms, so they keep a positive
116 |     # sign; only the flow log-determinant is subtracted
117 |     loss = tf.reduce_mean(kl + rec_err - log_detj)
118 |     return loss
119 |
120 |
121 | def train_step(sess, input_data, train_op, loss_op, x_op, e_op, Z):
122 | e_ = np.random.normal(size=(input_data.shape[0], Z))
123 | _, l = sess.run([train_op, loss_op], feed_dict={x_op: input_data, e_op: e_})
124 | return l
125 |
126 |
127 | def reconstruct(sess, input_data, out_op, x_op, e_op, Z):
128 | e_ = np.random.normal(size=(input_data.shape[0], Z))
129 | x_rec = sess.run([out_op], feed_dict={x_op: input_data, e_op: e_})
130 | return x_rec
131 |
132 |
133 | def show_reconstruction(actual, recon):
134 | fig, axs = plt.subplots(1, 2)
135 | axs[0].imshow(actual.reshape(28, 28), cmap='gray')
136 | axs[1].imshow(recon.reshape(28, 28), cmap='gray')
137 | axs[0].set_title('actual')
138 | axs[1].set_title('reconstructed')
139 | plt.show()
140 |
141 |
142 | def sample_latent(sess, input_data, z_op, x_op, e_op, Z):
143 | e_ = np.random.normal(size=(input_data.shape[0], Z))
144 | zs = sess.run(z_op, feed_dict={x_op: input_data, e_op: e_})
145 | return zs
146 |
147 |
148 | if __name__ == '__main__':
149 | N = 1000
150 | xs = np.vstack((
151 | np.random.uniform(-6, -2, size=(N//3, 2)),
152 | np.random.multivariate_normal([0, 0], np.eye(2) / 2, size=N//3),
153 | np.random.multivariate_normal([5, -5], np.eye(2) / 2, size=N//3)
154 | ))
155 | ys = np.repeat(np.arange(3), N // 3)
156 |
157 |     idxs = np.random.permutation(xs.shape[0])
158 | xs, ys = xs[idxs], ys[idxs]
159 |
160 | plt.scatter(xs[:, 0], xs[:, 1], c=ys)
161 | plt.title('original data')
162 | plt.show()
163 |
164 | tf.reset_default_graph()
165 | data = xs
166 | data_dim = xs.shape[1]
167 |
168 | enc_h = 128
169 | enc_z = 2
170 | dec_h = 128
171 | max_iters = 10000
172 | batch_size = data.shape[0]
173 | learning_rate = 0.001
174 | k = 3
175 |
176 | x, e = inputs(data_dim, enc_z)
177 | mu, log_var, z0, lambd = encoder(x, e, data_dim, enc_h, enc_z, k)
178 | z_k, log_detj = norm_flow(z0, lambd, k, enc_z)
179 | out_op, out_mu, out_log_var = decoder(z_k, data_dim, dec_h, enc_z, out_fn=tf.identity)
180 |     loss_op = make_loss(out_op, x, log_var, mu, log_detj)
181 | train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss_op)
182 |
183 | sess = tf.InteractiveSession()
184 |     sess.run(tf.global_variables_initializer())
185 |
186 | idx = 0
187 | for i in range(max_iters):
188 | x_ = data[idx:idx + batch_size]
189 | l = train_step(sess, x_, train_op, loss_op, x, e, enc_z)
190 | idx += batch_size
191 |         if idx >= data.shape[0]:
192 | idx = 0
193 | if i % 1000 == 0:
194 | print('iter: %d\tloss: %.2f' % (i, l))
195 |
196 | zs = sample_latent(sess, xs, z_k, x, e, enc_z)
197 | fig = plt.figure(figsize=(8, 6))
198 | plt.scatter(zs[:, 0], zs[:, 1], c=ys)
199 | plt.title('latent z values for each point in the dataset')
200 | plt.show()
201 |
202 | k = 500
203 |
204 | # Take a data point from each class, replicate it k times
205 | x_ = np.repeat(data[[(ys == i).argmax() for i in range(3)]], k, axis=0)
206 | y_ = np.repeat(np.arange(3), k)
207 | e_ = np.random.normal(size=(x_.shape[0], enc_z))
208 | zs = sess.run(z_k, feed_dict={x: x_, e: e_})
209 |
210 | fig = plt.figure(figsize=(8, 6))
211 | plt.scatter(zs[:, 0], zs[:, 1], c=y_)
212 | plt.title("Posterior samples")
213 | plt.show()
214 |
215 |     reconstructed = reconstruct(sess, data, out_op, x, e, enc_z)[0]
216 | plt.scatter(reconstructed[:, 0], reconstructed[:, 1])
217 | plt.title('reconstructed data')
218 | plt.show()
219 |
--------------------------------------------------------------------------------
/normflow.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import os
5 | import numpy as np
6 | import tensorflow as tf
7 | import synthetic_data
8 | import visualize
9 | import matplotlib.pyplot as plt
10 |
11 | plt.ioff()
12 |
13 |
14 | class PlanarFlow(object):
15 | """
16 | Planar normalizing flow
17 | equation 10-12, 21-23 in paper https://arxiv.org/pdf/1505.05770.pdf
18 | """
19 |
20 | def __init__(self, z_dim=2, var_scope='planarflow'):
21 | self.z_dim = z_dim
22 | self.h = tf.tanh
23 | self.var_scope = var_scope
24 |
25 | with tf.variable_scope(var_scope):
26 | initializer = tf.contrib.layers.xavier_initializer_conv2d()
27 | self.u = tf.get_variable('u', initializer=initializer(shape=(z_dim, 1)))
28 | self.w = tf.get_variable('w', initializer=initializer(shape=(z_dim, 1)))
29 | self.b = tf.get_variable('b', initializer=initializer(shape=(1, 1)))
30 |
31 |     def __call__(self, z, logp, name='flow'):
32 |         """
33 |         :param z: B*z_dim samples from the previous flow step
34 |         :param logp: B, running log-density log q_k(z)
35 |         :return: transformed z and the updated log-density (Eq. 10-12, 21-23)
36 |         """
37 |         with tf.name_scope(name):
38 |             a = self.h(tf.matmul(z, self.w) + self.b)          # h(w^T z + b)
39 |             psi = tf.matmul(1 - a ** 2, tf.transpose(self.w))  # psi(z) = h'(w^T z + b) w^T
40 |
41 |             # Appendix A.1: constrain w^T u >= -1 so the planar flow stays invertible
42 |             x = tf.matmul(tf.transpose(self.w), self.u)
43 |             m = -1 + tf.nn.softplus(x)
44 |             u_h = self.u + (m - x) * self.w / (tf.matmul(tf.transpose(self.w), self.w))
45 |
46 |             # log q_k(z_k) = log q_{k-1}(z) - log(1 + u_hat^T psi(z)), positive by the A.1 constraint
47 |             logp = logp - tf.squeeze(tf.log(1 + tf.matmul(psi, u_h)))
48 |             z = z + tf.matmul(a, tf.transpose(u_h))
49 |         return z, logp
50 |
51 |
52 | class NormalizingFlow(object):
53 | """
54 | Normalizing flow
55 | """
56 | def __init__(self, z_dim, K=3, name='normalizingflow'):
57 | self.z_dim = z_dim
58 | self.K = K
59 | self.planar_flows = []
60 | with tf.variable_scope(name):
61 | for i in range(K):
62 | flow = PlanarFlow(z_dim, var_scope='planarflow_' + str(i+1))
63 | self.planar_flows.append(flow)
64 |
65 | def __call__(self, z, logp, name='normflow'):
66 | with tf.name_scope(name):
67 | for flow in self.planar_flows:
68 | z, logp = flow(z, logp)
69 |
70 | return z, logp
71 |
72 |
73 | def build_network(input_z0_placeholder, log_q0_placehoder, K=32, z_dim=2, name='func_U'):
74 | with tf.variable_scope(name):
75 | normFlow = NormalizingFlow(z_dim=z_dim, K=K)
76 | zk, logqk = normFlow(input_z0_placeholder, log_q0_placehoder)
77 | return zk, logqk
78 |
79 |
80 | def compute_loss(U_func, log_qk, z_k):
81 |     # KL[q_K || p] up to the constant log Z: E_{z_K ~ q_K}[ log q_K(z_K) + U(z_K) ]
82 |     U_z = U_func(z_k)
83 |     U_z = tf.clip_by_value(U_z, -10000, 10000)
84 |     kld = tf.reduce_mean(log_qk + U_z)
85 |     return kld
86 |
87 |
88 | def save(saver, sess, logdir, step, write_meta=False):
89 | model_name = 'model.ckpt'
90 | checkpoint_path = os.path.join(logdir, model_name)
91 | print('Storing checkpoint to {} ...'.format(logdir))
92 |
93 |     # create the log directory if it does not already exist
94 | if not tf.gfile.Exists(logdir):
95 | tf.gfile.MakeDirs(logdir)
96 |
97 | saver.save(sess, checkpoint_path, global_step=step, write_meta_graph=write_meta)
98 | print('Save Model Done.')
99 |
100 |
101 | def save_image(sess, zk_arr, logqk_arr, input_z0_placeholder, log_q0_placehoder, sampler, path):
102 | fig, axes = plt.subplots(2, 2)
103 | axes = axes.flatten()
104 |
105 | for u_idx, (zk, logqk) in enumerate(zip(zk_arr, logqk_arr)):
106 | ax = axes[u_idx]
107 |
108 | side = np.linspace(-5, 5, 500)
109 | X, Y = np.meshgrid(side, side)
110 | counts = np.zeros(X.shape)
111 | p = np.zeros(X.shape)
112 |
113 | size = [-5, 5]
114 | num_side = 500
115 |
116 | L = 100
117 | print("Sampling", end='')
118 | for i in range(1000):
119 | z, logq = sampler(L)
120 | z_k, logq_k = sess.run([zk, logqk], feed_dict={input_z0_placeholder: z, log_q0_placehoder: logq})
121 | # check nan
122 | if np.any(np.isnan(z_k)):
123 | print("NaN detected")
124 | continue
125 |
126 | q_k = np.exp(logq_k)
127 | z_k = (z_k - size[0]) * num_side / (size[1] - size[0])
128 | for l in range(L):
129 | x, y = int(z_k[l, 1]), int(z_k[l, 0])
130 | if 0 <= x < num_side and 0 <= y < num_side:
131 | counts[x, y] += 1
132 | p[x, y] += q_k[l]
133 |
134 | counts = np.maximum(counts, np.ones(counts.shape))
135 | p /= counts
136 | p /= np.sum(p)
137 | Y = -Y
138 | ax.pcolormesh(X, Y, p)
139 |
140 | fig.tight_layout()
141 | plt.savefig(path)
142 | plt.close()
143 |
144 |
145 | if __name__ == '__main__':
146 | # show data
147 | print("show synethtic data, close the data image and continue")
148 | visualize.plot_density()
149 |
150 | K = 32
151 | z_dim = 2
152 | L = 256
153 | steps = 4000000
154 | is_training = True
155 | learning_rate = 0.001
156 | save_model_every_steps = 1000
157 | print_loss_every_steps = 100
158 | logdir = './log/'
159 | logdir = os.path.join(logdir, 'K=' + str(K))
160 | logdir_image = os.path.join(logdir, 'images')
161 | checkpoint = r'model.ckpt-3980000'
162 |
163 | if not tf.gfile.Exists(logdir_image):
164 | tf.gfile.MakeDirs(logdir_image)
165 |
166 | U1 = getattr(synthetic_data, 'U1_tf')
167 | U2 = getattr(synthetic_data, 'U2_tf')
168 | U3 = getattr(synthetic_data, 'U3_tf')
169 | U4 = getattr(synthetic_data, 'U4_tf')
170 | U_arr = [U1, U2, U3, U4]
171 | input_z0_placeholder = tf.placeholder(tf.float32, [None, 2])
172 | log_q0_placehoder = tf.placeholder(tf.float32, [None])
173 |
174 | zk_arr = []
175 | logqk_arr = []
176 | loss_arr = []
177 | train_op_arr = []
178 | for i, U in enumerate(U_arr):
179 | zk, logqk = build_network(input_z0_placeholder, log_q0_placehoder, K=K, z_dim=z_dim, name="dist/" + U.__name__)
180 | loss = compute_loss(U, logqk, zk)
181 | train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)
182 | zk_arr.append(zk)
183 | logqk_arr.append(logqk)
184 | loss_arr.append(loss)
185 | train_op_arr.append(train_op)
186 |
187 | sess = tf.InteractiveSession()
188 | init = tf.global_variables_initializer()
189 | sess.run(init)
190 |
191 | saver = tf.train.Saver(var_list=tf.trainable_variables())
192 |
193 |     # restore from a saved checkpoint when not training
194 | if not is_training:
195 | # restore model from checkpoint
196 | saver.restore(sess, checkpoint)
197 |         print('Model restored successfully!')
198 |
199 | sampler = synthetic_data.normal_sampler()
200 | if is_training:
201 | for step in range(steps):
202 | z0, log_q0 = sampler(L)
203 | for i, U in enumerate(U_arr):
204 | loss = loss_arr[i]
205 | train_op = train_op_arr[i]
206 | zk = zk_arr[i]
207 | logqk = logqk_arr[i]
208 |
209 | l, _ = sess.run([loss, train_op], feed_dict={input_z0_placeholder: z0, log_q0_placehoder: log_q0})
210 | if step % print_loss_every_steps == 0:
211 | print("Training {}, step {}, loss={}".format(U.__name__, step, l))
212 |
213 | if step % save_model_every_steps == 0:
214 | save(saver, sess, logdir, step, write_meta=False)
215 | path = os.path.join(logdir_image, str(step) + '.png')
216 | save_image(sess, zk_arr, logqk_arr, input_z0_placeholder, log_q0_placehoder, sampler, path)
217 |
218 | save(saver, sess, logdir, steps, write_meta=False)
219 |
220 | print("done!")
221 | path = os.path.join(logdir_image, 'final.png')
222 | save_image(sess, zk_arr, logqk_arr, input_z0_placeholder, log_q0_placehoder, sampler, path)
223 |
--------------------------------------------------------------------------------
/synthetic_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | # synthetic data in table 1 of paper: https://arxiv.org/abs/1505.05770
8 |
9 |
10 | def w1_tf(z):
11 | return tf.sin(2 * np.pi * z[:, 0] / 4)
12 |
13 |
14 | def w1(z):
15 | return np.sin(2 * np.pi * z[:, 0] / 4)
16 |
17 |
18 | def w2_tf(z):
19 | return 3 * tf.exp(-0.5 * ((z[:, 0] - 1) / 0.6) ** 2)
20 |
21 |
22 | def w2(z):
23 | return 3 * np.exp(-0.5 * ((z[:, 0] - 1) / 0.6) ** 2)
24 |
25 |
26 | def w3_tf(z):
27 | return 3 * tf.sigmoid((z[:, 0] - 1) / 0.3)
28 |
29 |
30 | def w3(z):
31 | return 3 * (1.0 / (1 + np.exp(-(z[:, 0] - 1) / 0.3)))
32 |
33 |
34 | def U1_tf(z):
35 | z_norm = tf.norm(z, 2, 1)
36 | add1 = 0.5 * ((z_norm - 2) / 0.4) ** 2
37 | add2 = -tf.log(tf.exp(-0.5 * ((z[:, 0] - 2) / 0.6) ** 2) + tf.exp(-0.5 * ((z[:, 0] + 2) / 0.6) ** 2) + 1e-9)
38 | return add1 + add2
39 |
40 |
41 | def U1(z):
42 | add1 = 0.5 * ((np.linalg.norm(z, 2, 1) - 2) / 0.4) ** 2
43 | add2 = -np.log(np.exp(-0.5 * ((z[:, 0] - 2) / 0.6) ** 2) + np.exp(-0.5 * ((z[:, 0] + 2) / 0.6) ** 2))
44 | return add1 + add2
45 |
46 |
47 | def U2(z):
48 | return 0.5 * ((z[:, 1] - w1(z)) / 0.4) ** 2
49 |
50 |
51 | def U2_tf(z):
52 | return 0.5 * ((z[:, 1] - w1_tf(z)) / 0.4) ** 2
53 |
54 |
55 | def U3(z):
56 | in1 = np.exp(-0.5 * ((z[:, 1] - w1(z)) / 0.35) ** 2)
57 | in2 = np.exp(-0.5 * ((z[:, 1] - w1(z) + w2(z)) / 0.35) ** 2)
58 | return -np.log(in1 + in2 + 1e-9)
59 |
60 |
61 | def U3_tf(z):
62 | in1 = tf.exp(-0.5 * ((z[:, 1] - w1_tf(z)) / 0.35) ** 2)
63 | in2 = tf.exp(-0.5 * ((z[:, 1] - w1_tf(z) + w2_tf(z)) / 0.35) ** 2)
64 | return -tf.log(in1 + in2 + 1e-9)
65 |
66 |
67 | def U4(z):
68 | in1 = np.exp(-0.5 * ((z[:, 1] - w1(z)) / 0.4) ** 2)
69 | in2 = np.exp(-0.5 * ((z[:, 1] - w1(z) + w3(z)) / 0.35) ** 2)
70 | return -np.log(in1 + in2)
71 |
72 |
73 | def U4_tf(z):
74 | in1 = tf.exp(-0.5 * ((z[:, 1] - w1_tf(z)) / 0.4) ** 2)
75 | in2 = tf.exp(-0.5 * ((z[:, 1] - w1_tf(z) + w3_tf(z)) / 0.35) ** 2)
76 | return -tf.log(in1 + in2 + 1e-9)
77 |
78 |
79 | def normal_sampler(mean=np.zeros(2), sigma=np.ones(2)):
80 | dim = mean.shape[0]
81 |
82 | def sampler(N):
83 | z = mean + np.random.randn(N, dim) * sigma
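        # log q_0(z) of the diagonal Gaussian N(mean, diag(sigma^2)), summed over dimensions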
84 | logq = -0.5 * np.sum(2 * np.log(sigma) + np.log(2 * np.pi) + ((z - mean) / sigma) ** 2, 1)
85 | return z, logq
86 |
87 | return sampler
88 |
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | import synthetic_data
7 |
8 |
9 | def compute_density(U_func, Z):
10 | neg_logp = U_func(Z)
11 | p = np.exp(-neg_logp)
12 | p /= np.sum(p)
13 | return p
14 |
15 |
16 | def plot_density():
17 | fig, axes = plt.subplots(2, 2)
18 | U_list = [synthetic_data.U1, synthetic_data.U2, synthetic_data.U3, synthetic_data.U4]
19 |
20 | space = np.linspace(-5, 5, 500)
21 | X, Y = np.meshgrid(space, space)
22 | shape = X.shape
23 | X_flatten, Y_flatten = np.reshape(X, (-1, 1)), np.reshape(Y, (-1, 1))
24 | Z = np.concatenate([X_flatten, Y_flatten], 1)
25 |
26 | # ISSUE
27 | Y = -Y # not sure why, but my plots are upside down compared to paper
28 |
29 | for U, ax in zip(U_list, axes.flatten()):
30 | density = compute_density(U, Z)
31 | density = np.reshape(density, shape)
32 | ax.pcolormesh(X, Y, density)
33 | ax.set_title(U.__name__)
34 | ax.axis('off')
35 |
36 | fig.tight_layout()
37 |     plt.savefig('./data.png')
38 |     plt.show()
39 |
40 | if __name__ == '__main__':
41 | plot_density()
42 |
--------------------------------------------------------------------------------