├── .gitignore ├── LICENSE ├── README.md ├── data └── mnist.pkl.gz └── src ├── __init__.py ├── gmvae.py └── vi_GMM_2d.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### Variational Inference / Variational Autoencoder + Gaussian Mixtures, implemented in ZhuSuan 2 | * `vi_GMM_2d.py`: a toy example of Variational Inference + Gaussian Mixture on 2D data 3 | * `gmvae.py`: Variational Autoencoder + Gaussian Mixture, trained on the MNIST dataset 4 | Both scripts add a local ZhuSuan checkout (`../zhusuan/` relative to `src/`) to `sys.path` and are meant to be run from the `src/` directory. -------------------------------------------------------------------------------- /data/mnist.pkl.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangg12/vi_vae_gmm/49a6e848c13f9bdcd9c3858dd3ef90e987b0f966/data/mnist.pkl.gz -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wangg12/vi_vae_gmm/49a6e848c13f9bdcd9c3858dd3ef90e987b0f966/src/__init__.py -------------------------------------------------------------------------------- /src/gmvae.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function, division 2 | import os 3 | import sys 4 | THIS_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.insert(1, os.path.join(THIS_DIR, '../zhusuan/')) 6 | 7 | import tensorflow as tf 8 | from tensorflow.contrib import layers 9 | from six.moves import range 10 | import six.moves.cPickle as pickle 11 | import time 12 | import numpy as np 13 | import matplotlib as mpl 14 | mpl.use('Agg') # TkAgg to show 15 | import matplotlib.pyplot as plt 16 | import matplotlib.gridspec as gridspec 17 | import shutil 18 | import zhusuan as zs 19 | import random 20 | from tqdm import tqdm 21 | from skimage import io, img_as_ubyte 22 | from skimage.exposure import rescale_intensity 23 | 24 | # from examples import conf 25 | from examples.utils import dataset, save_image_collections 26 | 27 | def main(): 28 | # manual seed 29 | #seed = random.randint(0, 10000) # fix seed 30 | seed = 1234 # N=100, K=3 31 | print("Random Seed: ", seed) 32 | random.seed(seed) 33 | np.random.seed(seed) 34 | tf.set_random_seed(seed) 35 | 36 | 37 | # load MNIST data --------------------------------------------------------- 38 | data_path = os.path.join('../data/', 'mnist.pkl.gz') 39 | x_train, t_train, x_valid, t_valid, x_test, t_test = \ 40 | dataset.load_mnist_realval(data_path) 41 | x_train = np.vstack([x_train, x_valid]).astype('float32') 42 | 43 | # model parameters -------------------------------------------------------- 44 | K = 10 45 | D = 40 46 | dim_z = K 47 | dim_h = D 48 | dim_x = x_train.shape[1] # 784 49 | N = x_train.shape[0] 50 | 51 | # Define training/evaluation parameters --------------------------------------------- 52 | resume = False 53 | epoches = 50 # 2000 54 | save_freq = 5 55 | batch_size = 100 56 | train_iters = int(np.ceil(N / batch_size)) 57 | 58 | learning_rate = 0.001 59 | anneal_lr_freq = 10 60 | anneal_lr_rate = 0.9 61 | n_particles = 20 62 | 63 | n_gen = 100 64 | 65 | result_path = "./results/3_gmvae" 66 | 67 | @zs.reuse(scope='decoder') 68 | def vae(observed, n, n_particles, is_training, dim_h=40, dim_z=10, dim_x=784): 69 | '''decoder: z-->h-->x 70 | n: batch_size 71 | dim_z: K = 10 72 | dim_x: 784 73 | dim_h: D = 40 74 | ''' 75 | with zs.BayesianNet(observed=observed) as model: 76 | normalizer_params = {'is_training': is_training, 77 | 
'updates_collections': None} 78 | pai = tf.get_variable('pai', shape=[dim_z], 79 | dtype=tf.float32, 80 | trainable=True, 81 | initializer=tf.constant_initializer(1.0) 82 | ) 83 | n_pai = tf.tile(tf.expand_dims(pai, 0), [n, 1]) 84 | z = zs.OnehotCategorical('z', logits=n_pai, 85 | dtype=tf.float32, 86 | n_samples=n_particles 87 | ) 88 | mu = tf.get_variable('mu', shape=[dim_z, dim_h], 89 | dtype=tf.float32, 90 | initializer=tf.random_uniform_initializer(-1, 1)) 91 | log_sigma = tf.get_variable('log_sigma', shape=[dim_z, dim_h], 92 | dtype=tf.float32, 93 | initializer=tf.random_uniform_initializer(-3, -2) 94 | ) 95 | h_mean = tf.reshape(tf.matmul(tf.reshape(z, [-1, dim_z]), mu), [n_particles, -1, dim_h]) # [n_particles, None, dim_x] 96 | h_logstd = tf.reshape(tf.matmul(tf.reshape(z, [-1, dim_z]), log_sigma), [n_particles, -1, dim_h]) 97 | 98 | h = zs.Normal('h', mean=h_mean, logstd=h_logstd, 99 | #n_samples=n_particles, 100 | group_event_ndims=1 101 | ) 102 | lx_h = layers.fully_connected( 103 | h, 512, 104 | # normalizer_fn=layers.batch_norm, 105 | # normalizer_params=normalizer_params 106 | ) 107 | lx_h = layers.fully_connected( 108 | lx_h, 512, 109 | # normalizer_fn=layers.batch_norm, 110 | # normalizer_params=normalizer_params 111 | ) 112 | x_logits = layers.fully_connected(lx_h, dim_x, activation_fn=None) # the log odds of being 1 113 | x = zs.Bernoulli('x', x_logits, 114 | #n_samples=n_particles, 115 | group_event_ndims=1) 116 | return model, x_logits, h, z.tensor 117 | 118 | 119 | @zs.reuse(scope='encoder') 120 | def q_net(x, dim_h, n_particles, is_training): 121 | '''encoder: x-->h''' 122 | with zs.BayesianNet() as variational: 123 | normalizer_params = {'is_training': is_training, 124 | # 'updates_collections': None 125 | } 126 | lh_x = layers.fully_connected(tf.to_float(x), 512, 127 | # normalizer_fn=layers.batch_norm, 128 | # normalizer_params=normalizer_params, 129 | weights_initializer=tf.contrib.layers.xavier_initializer()) 130 | lh_x = tf.contrib.layers.dropout(lh_x, keep_prob=0.9, is_training=is_training) 131 | lh_x = layers.fully_connected(lh_x, 512, 132 | # normalizer_fn=layers.batch_norm, 133 | # normalizer_params=normalizer_params, 134 | weights_initializer=tf.contrib.layers.xavier_initializer()) 135 | lh_x = tf.contrib.layers.dropout(lh_x, keep_prob=0.9, is_training=is_training) 136 | h_mean = layers.fully_connected(lh_x, dim_h, activation_fn=None, 137 | weights_initializer=tf.contrib.layers.xavier_initializer()) 138 | h_logstd = layers.fully_connected(lh_x, dim_h, activation_fn=None, 139 | weights_initializer=tf.contrib.layers.xavier_initializer()) 140 | h = zs.Normal('h', mean=h_mean, logstd=h_logstd, 141 | n_samples=n_particles, 142 | group_event_ndims=1 143 | ) 144 | return variational 145 | 146 | 147 | x_ph = tf.placeholder(tf.int32, shape=[None, dim_x], name='x_ph') 148 | x_orig_ph = tf.placeholder(tf.float32, shape=[None, dim_x], name='x_orig_ph') 149 | x_bin = tf.cast(tf.less(tf.random_uniform(tf.shape(x_orig_ph), 0, 1), x_orig_ph), tf.int32) 150 | is_training_ph = tf.placeholder(tf.bool, shape=[], name='is_training_ph') 151 | 152 | n = tf.shape(x_ph)[0] 153 | 154 | 155 | def log_joint(observed): 156 | z_obs = tf.eye(dim_z, batch_shape=[n_particles, n]) 157 | z_obs = tf.transpose(z_obs, [2, 0, 1, 3]) # [K, n_p, bs, K] 158 | log_pz_list = [] 159 | log_ph_z_list = [] 160 | log_px_h = None 161 | for i in range(dim_z): 162 | observed['z'] = z_obs[i,:] # the i-th dimension is 1 163 | model, _, _, _ = vae(observed, n, n_particles, is_training_ph, dim_h=dim_h, 
dim_z=dim_z, dim_x=dim_x) 164 | log_pz_i, log_ph_z_i, log_px_h = model.local_log_prob(['z', 'h', 'x']) 165 | log_pz_list.append(log_pz_i) 166 | log_ph_z_list.append(log_ph_z_i) 167 | log_pz = tf.stack(log_pz_list, axis=0) 168 | log_ph_z = tf.stack(log_ph_z_list, axis=0) 169 | # p(X, H) = p(X|H) sum_Z(p(Z) * p(H|Z)) 170 | # log p(X, H) = log p(X|H) + log sum_Z exp(log p(Z) + log p(H|Z)) 171 | log_p_xh = log_px_h + tf.reduce_logsumexp(log_pz + log_ph_z, axis=0) # log p(X, H) 172 | return log_p_xh 173 | 174 | variational = q_net(x_ph, dim_h, n_particles, is_training_ph) 175 | qh_samples, log_qh = variational.query('h', outputs=True, 176 | local_log_prob=True) 177 | 178 | x_obs = tf.tile(tf.expand_dims(x_ph, 0), [n_particles, 1, 1]) 179 | 180 | lower_bound = zs.sgvb(log_joint, 181 | observed={'x': x_obs}, 182 | latent={'h': [qh_samples, log_qh]}, 183 | axis=0) 184 | 185 | mean_lower_bound = tf.reduce_mean(lower_bound) 186 | with tf.name_scope('neg_lower_bound'): 187 | neg_lower_bound = tf.reduce_mean(- mean_lower_bound) 188 | 189 | train_vars = tf.trainable_variables() 190 | with tf.variable_scope('decoder', reuse=True): 191 | pai = tf.get_variable('pai') 192 | mu = tf.get_variable('mu') 193 | log_sigma = tf.get_variable('log_sigma') 194 | 195 | clip_pai = pai.assign(tf.clip_by_value(pai, 0.7, 1.3)) 196 | 197 | # _, pai_var = tf.nn.moments(pai, axes=[-1]) 198 | # _, mu_var = tf.nn.moments(mu, axes=[0, 1], keep_dims=False) 199 | # regularizer = tf.add_n([tf.nn.l2_loss(v) for v in train_vars 200 | # if not 'pai' in v.name and not 'mu' in v.name]) 201 | # loss = neg_lower_bound + pai_var - mu_var # + 1e-4 * regularizer # loss ------------- 202 | loss = neg_lower_bound #+ 0.001 * tf.nn.l2_loss(mu-1) 203 | 204 | learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='lr') 205 | 206 | optimizer = tf.train.AdamOptimizer(learning_rate_ph, epsilon=1e-4) 207 | grads_and_vars = optimizer.compute_gradients(loss) 208 | clipped_gvs = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in grads_and_vars] 209 | infer = optimizer.apply_gradients(clipped_gvs) 210 | 211 | # Generate images ----------------------------------------------------- 212 | z_manual_feed = tf.eye(dim_z, batch_shape=[10]) # [10, K, K] 213 | z_manual_feed = tf.transpose(z_manual_feed, [1, 0, 2]) # [K, 10, K] 214 | _, x_logits, _, z_onehot = vae({'z': z_manual_feed}, 10, n_particles=1, is_training=False, 215 | dim_h=dim_h, dim_z=dim_z, dim_x=dim_x) # n and n_particles do not matter, since we have manually fed z 216 | print('x_logits:', x_logits.shape.as_list()) # [1, 100, 784] 217 | x_gen = tf.reshape(tf.sigmoid(x_logits), [-1, 28, 28, 1]) 218 | z_gen = tf.argmax(tf.reshape(z_onehot, [-1, dim_z]), axis=1) 219 | 220 | 221 | # tensorboard summary --------------------------------------------------- 222 | image_for_summ = [] 223 | for i in range(n_gen//10): 224 | tmp = [x_gen[j+i*10,:] for j in range(10)] 225 | tmp = tf.concat(tmp, 1) 226 | image_for_summ.append(tmp) 227 | image_for_summ = tf.expand_dims(tf.concat(image_for_summ, 0), 0) 228 | print('image_for_summ:', image_for_summ.shape.as_list()) 229 | gen_image_summ = tf.summary.image('gen_images', image_for_summ, max_outputs=100) 230 | lb_summ = tf.summary.scalar("lower_bound", mean_lower_bound) 231 | lr_summ = tf.summary.scalar("learning_rate", learning_rate_ph) 232 | loss_summ = tf.summary.scalar('loss', loss) 233 | 234 | for var in train_vars: 235 | tf.summary.histogram(var.name, var) 236 | for grad, _ in grads_and_vars: 237 | tf.summary.histogram(grad.name, grad) 238 | 239 | 
for i in train_vars: 240 | print(i.name, i.get_shape()) 241 | # Merge all summaries into a single op 242 | merged_summary_op = tf.summary.merge_all() 243 | 244 | saver = tf.train.Saver(max_to_keep=10) 245 | 246 | config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False) 247 | config.gpu_options.allow_growth = True 248 | config.gpu_options.per_process_gpu_memory_fraction = 0.3 249 | 250 | with tf.Session(config=config) as sess: 251 | sess.run(tf.global_variables_initializer()) 252 | 253 | # Restore from the latest checkpoint 254 | ckpt_file = tf.train.latest_checkpoint(result_path) 255 | begin_epoch = 1 256 | if ckpt_file is not None and resume: # resume --------------------------------------- 257 | print('Restoring model from {}...'.format(ckpt_file)) 258 | begin_epoch = int(ckpt_file.split('.')[-2]) + 1 259 | saver.restore(sess, ckpt_file) 260 | 261 | x_train_normed = x_train # no normalization 262 | x_train_normed_no_shuffle = x_train_normed 263 | 264 | 265 | log_dir = './log/3_gmvae/' 266 | if os.path.exists(log_dir): 267 | shutil.rmtree(log_dir) 268 | summary_writer = tf.summary.FileWriter(log_dir, graph=tf.get_default_graph()) 269 | 270 | global mu_res, log_sigma_res, pai_res 271 | global gen_images, z_gen_res, epoch 272 | print('training...') # ---------------------------------------------------------------- 273 | pai_res_0, mu_res_0, log_sigma_res_0 = sess.run([pai, mu, log_sigma]) 274 | global_step = 0 275 | for epoch in tqdm(range(begin_epoch, epoches + 1)): 276 | time_epoch = -time.time() 277 | if epoch % anneal_lr_freq == 0: 278 | learning_rate *= anneal_lr_rate 279 | np.random.shuffle(x_train_normed) # shuffle training data 280 | lbs = [] 281 | 282 | for t in tqdm(range(train_iters)): 283 | global_step += 1 284 | x_batch = x_train_normed[t * batch_size : (t + 1) * batch_size] # get batched data 285 | x_batch_bin = sess.run(x_bin, feed_dict={x_orig_ph: x_batch}) 286 | # sess.run(clip_pai) 287 | _, lb, merge_all = sess.run([infer, mean_lower_bound, merged_summary_op], 288 | feed_dict={x_ph: x_batch_bin, 289 | learning_rate_ph: learning_rate, 290 | is_training_ph: True}) 291 | lbs.append(lb) 292 | time_epoch += time.time() 293 | print('Epoch {} ({:.1f}s): Lower bound = {}'.format( 294 | epoch, time_epoch, np.mean(lbs))) 295 | # print(grad_var_res[-3:]) 296 | 297 | 298 | summary_writer.add_summary(merge_all, global_step=epoch) 299 | 300 | if epoch % save_freq == 0: # save --------------------------------------------------- 301 | print('Saving model...') 302 | save_path = os.path.join(result_path, "gmvae.epoch.{}.ckpt".format(epoch)) 303 | if not os.path.exists(os.path.dirname(save_path)): 304 | os.makedirs(os.path.dirname(save_path)) 305 | saver.save(sess, save_path) 306 | 307 | gen_images, z_gen_res = sess.run([x_gen, z_gen]) #, feed_dict={is_training_ph: False}) 308 | 309 | # dump data 310 | pai_res, mu_res, log_sigma_res = sess.run([pai, mu, log_sigma]) 311 | data_dump = {'epoch':epoch, 312 | 'images': gen_images, 'clusters': z_gen_res, 313 | 'pai_0': pai_res_0, 'mu_0': mu_res_0, 'log_sigma_0': log_sigma_res_0, 314 | 'pai_res': pai_res, 'mu_res': mu_res, 'log_sigma_res': log_sigma_res 315 | } 316 | pickle.dump(data_dump, open(os.path.join(result_path, 'gmvae_results_epoch_{}.pkl'.format(epoch)), 'w'), protocol=2) 317 | save_image_with_clusters(gen_images, z_gen_res, filename="results/3_gmvae/gmvae_epoch_{}.png".format(epoch)) 318 | print('Done') 319 | 320 | 321 | pai_res, mu_res, log_sigma_res = sess.run([pai, mu, log_sigma]) 322 | print("Random Seed: ", 
seed) 323 | data_dump = {'epoch':epoch, 324 | 'images': gen_images, 'clusters': z_gen_res, 325 | 'pai_0': pai_res_0, 'mu_0': mu_res_0, 'log_sigma_0': log_sigma_res_0, 326 | 'pai_res': pai_res, 'mu_res': mu_res, 'log_sigma_res': log_sigma_res 327 | } 328 | pickle.dump(data_dump, open(os.path.join(result_path, 'gmvae_results_epoch_{}.pkl'.format(epoch)), 'w'), protocol=2) 329 | plot_images_and_clusters(gen_images, z_gen_res, epoch, save_path=result_path, ncol=10) 330 | 331 | 332 | def save_images_and_clusters(images, clusters, epoch, shape=(10,10)): 333 | for i in range(10): 334 | name_i = "results/3_gmvae/epoch_{}/cluster_{}.png".format(epoch, i) 335 | images_i = images[clusters==i, :] 336 | if images_i.shape[0] == 0: 337 | continue 338 | save_image_collections(images_i, name_i, shape=shape) 339 | 340 | 341 | def makedirs(filename): 342 | if not os.path.exists(os.path.dirname(filename)): 343 | os.makedirs(os.path.dirname(filename)) 344 | 345 | 346 | def save_image_with_clusters(x, clusters, filename, shape=(10, 10), scale_each=False, 347 | transpose=False): 348 | '''single image, each row is a cluster''' 349 | makedirs(filename) 350 | n = x.shape[0] 351 | 352 | images = np.zeros_like(x) 353 | curr_len = 0 354 | for i in range(10): 355 | images_i = x[clusters==i, :] 356 | n_i = images_i.shape[0] 357 | images[curr_len : curr_len+n_i, :] = images_i 358 | curr_len += n_i 359 | 360 | x = images 361 | 362 | if transpose: 363 | x = x.transpose(0, 2, 3, 1) 364 | if scale_each is True: 365 | for i in range(n): 366 | x[i] = rescale_intensity(x[i], out_range=(0, 1)) 367 | 368 | n_channels = x.shape[3] 369 | x = img_as_ubyte(x) 370 | r, c = shape 371 | if r * c < n: 372 | print('Shape too small to contain all images') 373 | h, w = x.shape[1:3] 374 | ret = np.zeros((h * r, w * c, n_channels), dtype='uint8') 375 | for i in range(r): 376 | for j in range(c): 377 | if i * c + j < n: 378 | ret[i * h:(i + 1) * h, j * w:(j + 1) * w, :] = x[i * c + j] 379 | ret = ret.squeeze() 380 | io.imsave(filename, ret) 381 | 382 | 383 | def plot_images_and_clusters(images, clusters, epoch, save_path, ncol=10): 384 | '''use multiple images''' 385 | fig = plt.figure()#facecolor='black') 386 | images = np.squeeze(images, -1) 387 | 388 | nrow = int(np.ceil(images.shape[0] / float(ncol))) 389 | gs = gridspec.GridSpec(nrow, ncol, 390 | width_ratios=[1]*ncol, height_ratios=[1]*nrow, 391 | # wspace=0.01, hspace=0.001, 392 | # top=0.95, bottom=0.05, 393 | # left=0.05, right=0.95 394 | ) 395 | gs.update(wspace=0, hspace=0) 396 | n = 0 397 | for i in range(10): 398 | images_i = images[clusters==i, :, :] 399 | if images_i.shape[0] == 0: 400 | continue 401 | 402 | for j in range(images_i.shape[0]): 403 | ax = plt.subplot(gs[n]) 404 | n += 1 405 | plt.imshow(images_i[j,:], cmap='gray') 406 | plt.axis('off') 407 | ax.set_aspect('auto') 408 | plt.savefig(os.path.join(save_path, 'plot_gmvae_epoch_{}.png'.format(epoch)), dpi=fig.dpi) 409 | 410 | if __name__ == "__main__": 411 | main() 412 | 413 | 414 | 415 | 416 | 417 | -------------------------------------------------------------------------------- /src/vi_GMM_2d.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function, division 2 | import os 3 | import sys 4 | THIS_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.insert(1, os.path.join(THIS_DIR, '../zhusuan/')) 6 | 7 | import tensorflow as tf 8 | from tensorflow.contrib import layers 9 | from six.moves import range 10 | import numpy as np 
11 | import matplotlib as mpl 12 | mpl.use('Agg') # TkAgg to show 13 | import matplotlib.pyplot as plt 14 | from matplotlib.patches import Ellipse 15 | import hickle 16 | import io 17 | import shutil 18 | import zhusuan as zs 19 | import random 20 | 21 | # from examples import conf 22 | from examples.utils import dataset, save_image_collections 23 | 24 | 25 | def main(N=100, K=3): 26 | # manual seed 27 | seed = random.randint(0, 10000) # fix seed 28 | # seed = 3899 # N=100, K=3 29 | print("Random Seed: ", seed) 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | tf.set_random_seed(seed) 33 | 34 | # tf.set_random_seed(2333) 35 | # np.random.seed(4567) 36 | 37 | D = 2 38 | n_gen = N # the number of generated samples x 39 | dim_z = K 40 | dim_x = D 41 | 42 | # Define training parameters --------------------------------------------- 43 | epoches = 2000 44 | batch_size = min(100, N) 45 | iters_per_batch = N // batch_size 46 | save_freq = 10 47 | plot_freq = 200 48 | learning_rate = 0.001 49 | anneal_lr_freq = 200 50 | anneal_lr_rate = 0.9 51 | n_particles = 100 52 | 53 | @zs.reuse(scope='decoder') 54 | def vae(observed, n, dim_x, dim_z, n_particles): 55 | '''decoder: z-->x''' 56 | with zs.BayesianNet(observed=observed) as model: 57 | pai = tf.get_variable('pai', shape=[dim_z], 58 | dtype=tf.float32, 59 | trainable=True, 60 | initializer=tf.constant_initializer(1.0), #tf.random_uniform_initializer(), #tf.ones([dim_z]), 61 | ) 62 | n_pai = tf.tile(tf.expand_dims(pai, 0), [n, 1]) 63 | z = zs.OnehotCategorical('z', logits=n_pai, 64 | dtype=tf.float32, 65 | n_samples=n_particles 66 | #group_event_ndims=1 67 | ) # zhusuan.model.stochastic.OnehotCategorical 68 | print('-'*10, 'z:', z.tensor.get_shape().as_list()) # [n_particles, None, dim_z] 69 | mu = tf.get_variable('mu', shape=[dim_z, dim_x], 70 | dtype=tf.float32, 71 | initializer=tf.random_uniform_initializer(0, 1)) 72 | log_sigma = tf.get_variable('log_sigma', shape=[dim_z, dim_x], 73 | dtype=tf.float32, 74 | initializer=tf.random_uniform_initializer(-3, -2) 75 | ) # tf.random_normal_initializer(-3, 0.5)) #tf.contrib.layers.xavier_initializer()) 76 | x_mean = tf.reshape(tf.matmul(tf.reshape(z, [-1, dim_z]), mu), [n_particles, n, dim_x]) # [n_particles, None, dim_x] 77 | x_logstd = tf.reshape(tf.matmul(tf.reshape(z, [-1, dim_z]), log_sigma), [n_particles, n, dim_x]) 78 | 79 | # print('x_mean:', x_mean.get_shape().as_list()) 80 | # print('x_logstd:', x_logstd.get_shape().as_list()) 81 | x = zs.Normal('x', mean=x_mean, logstd=x_logstd, group_event_ndims=1) 82 | # print('x:', x.tensor.get_shape().as_list()) 83 | return model, x.tensor, z.tensor 84 | 85 | 86 | @zs.reuse(scope='encoder') 87 | def q_net(x, dim_z, n_particles): 88 | '''encoder: x-->z''' 89 | with zs.BayesianNet() as variational: 90 | lz_x = layers.fully_connected(tf.to_float(x), 256, 91 | weights_initializer=tf.contrib.layers.xavier_initializer()) 92 | # lz_x = layers.fully_connected(lz_x, 256, 93 | # weights_initializer=tf.contrib.layers.xavier_initializer()) 94 | z_logits = layers.fully_connected(lz_x, dim_z, activation_fn=None, 95 | weights_initializer=tf.contrib.layers.xavier_initializer()) 96 | z = zs.OnehotCategorical('z', logits=z_logits, dtype=tf.float32, 97 | n_samples=n_particles, 98 | #group_event_ndims=1 99 | ) 100 | return variational, z_logits 101 | 102 | # def baseline_net(x): 103 | # with tf.variable_scope('baseline_net'): 104 | # lc_x = layers.fully_connected(tf.to_float(x), 100, 105 | # weights_initializer=tf.contrib.layers.xavier_initializer()) 106 | # lc_x = 
layers.fully_connected(lc_x, 1, activation_fn=None, 107 | # weights_initializer=tf.contrib.layers.xavier_initializer()) 108 | # lc_x = tf.squeeze(lc_x, -1) 109 | # return lc_x 110 | 111 | x_ph = tf.placeholder(tf.float32, shape=[None, dim_x], name='x') 112 | # is_training = tf.placeholder(tf.bool, shape=[], name='is_training') 113 | n = tf.shape(x_ph)[0] 114 | 115 | 116 | def log_joint(observed): 117 | model, _, _ = vae(observed, n, dim_x, dim_z, n_particles) 118 | log_pz, log_px_z = model.local_log_prob(['z', 'x']) 119 | return log_pz + log_px_z 120 | 121 | variational, q_cluster = q_net(x_ph, dim_z, n_particles) 122 | qz_samples, log_qz = variational.query('z', outputs=True, 123 | local_log_prob=True) 124 | 125 | # cx = tf.expand_dims(baseline_net(x_ph), 0) 126 | x_obs = tf.tile(tf.expand_dims(x_ph, 0), [n_particles, 1, 1]) 127 | surrogate_cost, lower_bound = zs.nvil(log_joint, 128 | observed={'x': x_obs}, 129 | latent={'z': [qz_samples, log_qz]}, 130 | #baseline=cx, 131 | axis=0) 132 | # print('-'*10) 133 | mean_lower_bound = tf.reduce_mean(lower_bound) 134 | with tf.name_scope('model_loss'): 135 | loss = tf.reduce_mean(surrogate_cost) 136 | 137 | train_vars = tf.trainable_variables() 138 | learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='lr') 139 | optimizer = tf.train.AdamOptimizer(learning_rate_ph, epsilon=1e-4) 140 | grads_and_vars = optimizer.compute_gradients(loss) 141 | infer = optimizer.apply_gradients(grads_and_vars) 142 | 143 | # Generate x samples 144 | _, x_gen, z_gen = vae({}, 1, dim_x, dim_z, n_particles=n_gen) 145 | x_gen = tf.squeeze(x_gen, 1) 146 | z_gen = tf.squeeze(z_gen, 1) 147 | 148 | # tensorboard summary --------------------------------------------------- 149 | lb_summ = tf.summary.scalar("lower_bound", mean_lower_bound) 150 | loss_summ = tf.summary.scalar("loss", loss) 151 | lr_ = tf.reduce_mean(learning_rate_ph, name='lr_') 152 | lr_summ = tf.summary.scalar("learning_rate", lr_) 153 | 154 | for var in train_vars: 155 | tf.summary.histogram(var.name, var) 156 | 157 | for i in train_vars: 158 | print(i.name, i.get_shape()) 159 | # Merge all summaries into a single op 160 | merged_summary_op = tf.summary.merge_all() 161 | 162 | plot_buf_ph = tf.placeholder(tf.string) 163 | image = tf.image.decode_png(plot_buf_ph, channels=4) 164 | image = tf.expand_dims(image, 0) # make it batched 165 | plot_image_summary = tf.summary.image('clusters', image, max_outputs=10) 166 | 167 | # initialization for train data generation---------------------------- 168 | global pai_init, mu_init, log_sigma_init 169 | pai_init = np.random.uniform(0, 2, (dim_z)).astype(np.float32) #np.ones(K) 170 | # mu_init = np.array([[3, 5],[-3, -4], [-5, 5]], dtype=np.float32) 171 | # sigma_init = np.array([[0, 0],[0, 0], [0, 0]], dtype=np.float32) 172 | mu_init = np.random.uniform(0, 1, (dim_z, dim_x)).astype(np.float32) 173 | log_sigma_init = np.random.normal(-3, 0.5, (dim_z, dim_x)).astype(np.float32) 174 | print('pai init for generating train data: ', pai_init) 175 | print('mu init for generating train data: \n', mu_init) 176 | print('sigma init for generating train data: \n', np.exp(log_sigma_init)) 177 | with tf.variable_scope('decoder', reuse=True): 178 | pai = tf.get_variable('pai') 179 | mu = tf.get_variable('mu') 180 | log_sigma = tf.get_variable('log_sigma') 181 | pai_assign = pai.assign(pai_init) 182 | mu_assign = mu.assign(mu_init) 183 | log_sigma_assign = log_sigma.assign(log_sigma_init) 184 | 185 | with tf.Session() as sess: 186 | sess.run(tf.global_variables_initializer()) 187 
| # generate train data ------------------------------------------------- 188 | train_filename = '../data/N_{}_K_{}_2d_gaussian_gzip.hkl'.format(N, K) 189 | global x_train, z_train 190 | # if not os.path.exists(train_filename): 191 | sess.run([pai_assign, mu_assign, log_sigma_assign]) 192 | x_train, z_train = sess.run([x_gen, z_gen]) 193 | print('x_train shape:', x_train.shape) 194 | hickle.dump(x_train, train_filename, mode='w', compression='gzip') 195 | # x_train_mean = np.mean(x_train, 0) 196 | # x_train_std = np.std(x_train, 0) 197 | # x_train_normed = (x_train - x_train_mean)/x_train_std 198 | x_train_normed = x_train # no normalization 199 | x_train_normed_no_shuffle = x_train_normed 200 | # print(x_train_mean) 201 | # print(x_train_std) 202 | # x_train_min = x_train.min() 203 | # x_train_max = x_train.max() 204 | # x_train_normed = (x_train - x_train_min)/(x_train_max - x_train_min) 205 | print(x_train_normed.max(), x_train_normed.min()) 206 | # else: # load existing file 207 | # print('load existing file: {}'.format(train_filename)) 208 | # x_train = hickle.load(train_filename) 209 | # print(x_train.shape) 210 | # plt.plot(x_train[:,0], x_train[:,1], '+') 211 | # plt.show() 212 | log_dir = './log/N_{}_K_{}_2d_gaussian/'.format(N, K) 213 | if os.path.exists(log_dir): 214 | shutil.rmtree(log_dir) 215 | summary_writer = tf.summary.FileWriter(log_dir, 216 | graph=tf.get_default_graph()) 217 | 218 | global x_gen_list, clusters_list # , q_res_list 219 | global mu_res, log_sigma_res, pai_res 220 | x_gen_list = [] 221 | # q_res_list = [] 222 | clusters_list = [] 223 | 224 | print('training...') # --------------------------------------------------- 225 | sess.run(tf.global_variables_initializer()) 226 | pai_res_0, mu_res_0, log_sigma_res_0 = sess.run([pai, mu, log_sigma]) 227 | print('random initializing...') 228 | print('pai_res_0: ', pai_res_0) 229 | print('mu_res_0: \n', mu_res_0) 230 | print('sigma_res_0: \n', np.exp(log_sigma_res_0)) 231 | global_step = 0 232 | for epoch in range(1, epoches + 1): 233 | if epoch % anneal_lr_freq == 0: 234 | learning_rate *= anneal_lr_rate 235 | np.random.shuffle(x_train_normed) # shuffle training data 236 | lbs = [] 237 | for t in range(iters_per_batch): 238 | global_step += 1 239 | x_batch = x_train_normed[t * batch_size : (t + 1) * batch_size] # get batched data 240 | # print('x_batch shape:', x_batch.shape) 241 | _, lb, q_cluster_res = sess.run([infer, mean_lower_bound, 242 | q_cluster], 243 | feed_dict={x_ph: x_batch, 244 | learning_rate_ph: learning_rate}) 245 | lbs.append(lb) 246 | # print(grad_var_res[-3:]) 247 | 248 | if epoch % save_freq == 0: # results ------------------------------------------------- 249 | # x_train_gen = sess.run(x_gen) 250 | # print(x_train == x_train) 251 | # x_gen_list.append(x_train_gen) 252 | print('Epoch {}: average Lower bound = {}'.format( 253 | epoch, np.mean(lbs))) 254 | q_cluster_res_save, lb_summ_res, merge_all = sess.run([q_cluster, lb_summ, merged_summary_op], 255 | feed_dict={x_ph: x_train_normed_no_shuffle, learning_rate_ph: learning_rate}) 256 | # summary_writer.add_summary(lb_summ_res, global_step=epoch) 257 | summary_writer.add_summary(merge_all, global_step=epoch) 258 | clusters = np.argmax(q_cluster_res_save, axis=1) 259 | # print(clusters.shape) 260 | # print(qz_samples_res_save.shape) 261 | clusters_list.append(clusters) 262 | 263 | if epoch % plot_freq == 0: # plot scatter ------------------------------------ 264 | # plot_buf = get_plot_buf(x_train_normed_no_shuffle, clusters) 265 | pai_res, mu_res, 
log_sigma_res = sess.run([pai, mu, log_sigma]) 266 | plot_buf = get_plot_buf(x_train, clusters, mu_res, log_sigma_res, mu_init, log_sigma_init) 267 | plot_image_summary_ = sess.run( 268 | plot_image_summary, 269 | feed_dict={plot_buf_ph: plot_buf.getvalue()}) 270 | summary_writer.add_summary(plot_image_summary_, global_step=epoch) 271 | # q_res_list.append(qz_samples_res) 272 | # print(qz_samples_res) 273 | # name = "results/vae/vae.epoch.{}.png".format(epoch) 274 | 275 | pai_res, mu_res, log_sigma_res = sess.run([pai, mu, log_sigma]) 276 | print("Random Seed: ", seed) 277 | print('pai init for generating train data: ', pai_init) 278 | print('mu init for generating train data: \n', mu_init) 279 | print('sigma init for generating train data: \n', np.exp(log_sigma_init)) 280 | print('*'*10) 281 | print('pai_res: ', pai_res) 282 | print('mu_res: \n', mu_res) 283 | print('sigma_res: \n', np.exp(log_sigma_res)) 284 | 285 | 286 | 287 | def get_plot_buf(x, clusters, mu, logstd, true_mu, true_logstd): 288 | N = x.shape[0] 289 | K = mu.shape[0] 290 | fig = plt.figure() 291 | # print(clusters.shape) 292 | # print(x.shape) 293 | ax = fig.add_subplot(111, aspect='auto') 294 | plt.scatter(x[:, 0], x[:, 1], c=clusters, s=50) 295 | # print(mu, logstd) 296 | ells = [Ellipse(xy=mean_, width=6*np.exp(logstd_[0]), height=6*np.exp(logstd_[1]), 297 | angle=0, facecolor='none', zorder=10, edgecolor='g', label='predict' if i==0 else None) 298 | for i, (mean_, logstd_) in enumerate(zip(mu, logstd))] 299 | true_ells = [Ellipse(xy=mean_, width=6*np.exp(logstd_[0]), height=6*np.exp(logstd_[1]), 300 | angle=0, facecolor='none', zorder=10, edgecolor='r', label='true' if i==0 else None) 301 | for i,(mean_, logstd_) in enumerate(zip(true_mu, true_logstd))] 302 | # print(ells[0]) 303 | [ax.add_patch(ell) for ell in ells] 304 | [ax.add_patch(true_ell) for true_ell in true_ells] 305 | ax.legend(loc='best') 306 | ax.set_title('N={},K={}'.format(N, K)) 307 | plt.autoscale(True) 308 | buf = io.BytesIO() 309 | fig.savefig(buf, format='png') 310 | plt.close() 311 | buf.seek(0) 312 | return buf 313 | 314 | if __name__ == "__main__": 315 | N = 100 316 | K = 3 317 | main(N=N, K=K) 318 | clusters = clusters_list[-1] 319 | # print(clusters) 320 | # for clusters in clusters_list: 321 | fig = plt.figure() 322 | ax = fig.add_subplot(111, aspect='auto') 323 | plt.scatter(x_train[:, 0], x_train[:, 1], c=clusters, s=50) 324 | 325 | ells = [Ellipse(xy=mean_, width=6*np.exp(logstd_[0]), height=6*np.exp(logstd_[1]), 326 | angle=0, facecolor='none', zorder=10, edgecolor='g', label='predict' if i==0 else None) 327 | for i,(mean_, logstd_) in enumerate(zip(mu_res, log_sigma_res))] 328 | true_ells = [Ellipse(xy=mean_, width=6*np.exp(logstd_[0]), height=6*np.exp(logstd_[1]), 329 | angle=0, facecolor='none', zorder=10, edgecolor='r', label='true' if i==0 else None) 330 | for i,(mean_, logstd_) in enumerate(zip(mu_init, log_sigma_init))] 331 | 332 | [ax.add_patch(ell) for ell in ells] 333 | [ax.add_patch(true_ell) for true_ell in true_ells] 334 | ax.legend(loc='best') 335 | ax.set_title('N={},K={}'.format(N, K)) 336 | plt.autoscale(True) 337 | fig.savefig('./results/result_N_{}_K_{}.png'.format(N, K), dpi=fig.dpi) 338 | 339 | # plt.show() 340 | 341 | 342 | 343 | --------------------------------------------------------------------------------
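For reference, the key step in `src/gmvae.py` is the marginalisation over the discrete cluster variable inside `log_joint` (the `log p(X, H) = log p(X|H) + log sum_Z exp(log p(Z) + log p(H|Z))` comment). Below is a minimal NumPy sketch of that computation for a single example, assuming `h` has already been sampled from the encoder; `log_p_xh` and `decode_logits` are hypothetical names introduced here for illustration (the latter stands in for the decoder network), not functions from this repository.

```python
import numpy as np
from scipy.special import logsumexp

def log_p_xh(x, h, pai_logits, mu, log_sigma, decode_logits):
    """Sketch of log p(x, h) under the Gaussian-mixture prior.

    x: [dim_x] binary vector, h: [dim_h] latent sample,
    pai_logits: [K] unnormalised mixture logits (the 'pai' variable),
    mu, log_sigma: [K, dim_h] per-cluster Gaussian parameters,
    decode_logits: callable mapping h to [dim_x] Bernoulli logits (hypothetical).
    """
    # log p(z = k): normalise the mixture logits over the K clusters
    log_pz = pai_logits - logsumexp(pai_logits)                      # [K]
    # log N(h | mu_k, diag(exp(log_sigma_k))^2) for every cluster k
    log_ph_z = -0.5 * np.sum(
        ((h - mu) / np.exp(log_sigma)) ** 2
        + 2.0 * log_sigma + np.log(2.0 * np.pi), axis=1)             # [K]
    # log Bernoulli(x | sigmoid(decoder(h))), computed directly from the logits
    logits = decode_logits(h)                                        # [dim_x]
    log_px_h = np.sum(x * logits - np.logaddexp(0.0, logits))
    # log p(x, h) = log p(x | h) + logsumexp_k [log p(z = k) + log p(h | z = k)]
    return log_px_h + logsumexp(log_pz + log_ph_z)
```

In the TensorFlow graph the same sum over clusters is obtained by calling the decoder once per cluster with `z` observed as the k-th one-hot vector, stacking the resulting `local_log_prob` terms for `z` and `h`, and reducing them with `tf.reduce_logsumexp`.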