├── LICENSE
├── README.md
├── input_data.py
├── main.py
├── ops.py
├── results
│   ├── 0.jpg
│   ├── 1.jpg
│   ├── 2.jpg
│   ├── 3.jpg
│   ├── 4.jpg
│   ├── 5.jpg
│   ├── 6.jpg
│   ├── 7.jpg
│   ├── 8.jpg
│   ├── 9.jpg
│   └── base.jpg
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
Apache License, Version 2.0, January 2004 — the standard, unmodified
Apache-2.0 license text, available at http://www.apache.org/licenses/LICENSE-2.0
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# variational-autoencoder
Generate MNIST digits with a variational autoencoder.

![](http://kvfrans.com/content/images/2016/08/mnist.jpg)
![](http://kvfrans.com/content/images/2016/08/vae.jpg)

This is the code that goes along with [my post explaining the variational autoencoder](http://kvfrans.com/variational-autoencoders-explained/).

Based on this [really helpful post](https://jmetzen.github.io/2015-11-27/vae.html).
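
To train, run `python main.py` (this assumes a Python 2 environment with a 0.x-era TensorFlow plus numpy, scipy, and matplotlib installed). MNIST is downloaded into `MNIST_data/` automatically on first run; note that checkpoints are written to a `training/` directory, which may need to be created beforehand. Each epoch writes a grid of reconstructions to `results/<epoch>.jpg`, with the original digits saved to `results/base.jpg` for comparison.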
--------------------------------------------------------------------------------
/input_data.py:
--------------------------------------------------------------------------------
"""Functions for downloading and reading MNIST data."""
import gzip
import os
import urllib
import numpy
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


def maybe_download(filename, work_directory):
    """Download the data from Yann's website, unless it's already here."""
    if not os.path.exists(work_directory):
        os.mkdir(work_directory)
    filepath = os.path.join(work_directory, filename)
    if not os.path.exists(filepath):
        filepath, _ = urllib.urlretrieve(SOURCE_URL + filename, filepath)
        statinfo = os.stat(filepath)
        print 'Successfully downloaded', filename, statinfo.st_size, 'bytes.'
    return filepath


def _read32(bytestream):
    # MNIST headers are big-endian 32-bit unsigned integers.
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return numpy.frombuffer(bytestream.read(4), dtype=dt)


def extract_images(filename):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
    print 'Extracting', filename
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)
        buf = bytestream.read(rows * cols * num_images)
        data = numpy.frombuffer(buf, dtype=numpy.uint8)
        data = data.reshape(num_images, rows, cols, 1)
        return data


def dense_to_one_hot(labels_dense, num_classes=10):
    """Convert class labels from scalars to one-hot vectors."""
    num_labels = labels_dense.shape[0]
    index_offset = numpy.arange(num_labels) * num_classes
    labels_one_hot = numpy.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot
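# A quick worked example of dense_to_one_hot: for labels_dense = [2, 0],
# index_offset is [0, 10], so flat positions 2 and 10 are set to 1, giving
#   [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
#    [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]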
def extract_labels(filename, one_hot=False):
    """Extract the labels into a 1D uint8 numpy array [index]."""
    print 'Extracting', filename
    with gzip.open(filename) as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
        labels = numpy.frombuffer(buf, dtype=numpy.uint8)
        if one_hot:
            return dense_to_one_hot(labels)
        return labels


class DataSet(object):
    def __init__(self, images, labels, fake_data=False):
        if fake_data:
            self._num_examples = 10000
        else:
            assert images.shape[0] == labels.shape[0], (
                "images.shape: %s labels.shape: %s" % (images.shape,
                                                       labels.shape))
            self._num_examples = images.shape[0]
            # Convert shape from [num examples, rows, columns, depth]
            # to [num examples, rows*columns] (assuming depth == 1).
            assert images.shape[3] == 1
            images = images.reshape(images.shape[0],
                                    images.shape[1] * images.shape[2])
            # Convert from [0, 255] -> [0.0, 1.0].
            images = images.astype(numpy.float32)
            images = numpy.multiply(images, 1.0 / 255.0)
        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        return self._images

    @property
    def labels(self):
        return self._labels

    @property
    def num_examples(self):
        return self._num_examples

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def next_batch(self, batch_size, fake_data=False):
        """Return the next `batch_size` examples from this data set."""
        if fake_data:
            fake_image = [1.0 for _ in xrange(784)]
            fake_label = 0
            return [fake_image for _ in xrange(batch_size)], [
                fake_label for _ in xrange(batch_size)]
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Finished epoch: shuffle the data and start the next epoch.
            self._epochs_completed += 1
            perm = numpy.arange(self._num_examples)
            numpy.random.shuffle(perm)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        end = self._index_in_epoch
        return self._images[start:end], self._labels[start:end]


def read_data_sets(train_dir, fake_data=False, one_hot=False):
    class DataSets(object):
        pass
    data_sets = DataSets()
    if fake_data:
        data_sets.train = DataSet([], [], fake_data=True)
        data_sets.validation = DataSet([], [], fake_data=True)
        data_sets.test = DataSet([], [], fake_data=True)
        return data_sets
    TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
    TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
    TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
    TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
    VALIDATION_SIZE = 5000
    local_file = maybe_download(TRAIN_IMAGES, train_dir)
    train_images = extract_images(local_file)
    local_file = maybe_download(TRAIN_LABELS, train_dir)
    train_labels = extract_labels(local_file, one_hot=one_hot)
    local_file = maybe_download(TEST_IMAGES, train_dir)
    test_images = extract_images(local_file)
    local_file = maybe_download(TEST_LABELS, train_dir)
    test_labels = extract_labels(local_file, one_hot=one_hot)
    validation_images = train_images[:VALIDATION_SIZE]
    validation_labels = train_labels[:VALIDATION_SIZE]
    train_images = train_images[VALIDATION_SIZE:]
    train_labels = train_labels[VALIDATION_SIZE:]
    data_sets.train = DataSet(train_images, train_labels)
    data_sets.validation = DataSet(validation_images, validation_labels)
    data_sets.test = DataSet(test_images, test_labels)
    return data_sets
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import input_data
import matplotlib.pyplot as plt
import os
from scipy.misc import imsave as ims
from utils import *
from ops import *


class LatentAttention():
    def __init__(self):
        self.mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
        self.n_samples = self.mnist.train.num_examples

        self.n_hidden = 500
        self.n_z = 20
        self.batchsize = 100
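        # The graph below follows the standard VAE objective: the encoder
        # produces z_mean and z_stddev, a latent code is sampled with the
        # reparameterization trick z = z_mean + z_stddev * eps, eps ~ N(0, I),
        # so gradients flow through the sampling step, and the cost is the
        # negative ELBO: a pixelwise cross-entropy reconstruction term plus
        # the closed-form KL(N(mu, sigma^2) || N(0, 1))
        #   = 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1).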

        self.images = tf.placeholder(tf.float32, [None, 784])
        image_matrix = tf.reshape(self.images, [-1, 28, 28, 1])
        z_mean, z_stddev = self.recognition(image_matrix)
        samples = tf.random_normal([self.batchsize, self.n_z], 0, 1, dtype=tf.float32)
        guessed_z = z_mean + (z_stddev * samples)

        self.generated_images = self.generation(guessed_z)
        generated_flat = tf.reshape(self.generated_images, [self.batchsize, 28*28])

        # Bernoulli cross-entropy; the 1e-8 terms guard against log(0).
        self.generation_loss = -tf.reduce_sum(
            self.images * tf.log(1e-8 + generated_flat)
            + (1 - self.images) * tf.log(1e-8 + 1 - generated_flat), 1)

        self.latent_loss = 0.5 * tf.reduce_sum(
            tf.square(z_mean) + tf.square(z_stddev)
            - tf.log(tf.square(z_stddev)) - 1, 1)
        self.cost = tf.reduce_mean(self.generation_loss + self.latent_loss)
        self.optimizer = tf.train.AdamOptimizer(0.001).minimize(self.cost)

    # encoder: 28x28 image -> (mean, stddev) of the approximate posterior
    def recognition(self, input_images):
        with tf.variable_scope("recognition"):
            h1 = lrelu(conv2d(input_images, 1, 16, "d_h1"))  # 28x28x1 -> 14x14x16
            h2 = lrelu(conv2d(h1, 16, 32, "d_h2"))  # 14x14x16 -> 7x7x32
            h2_flat = tf.reshape(h2, [self.batchsize, 7*7*32])

            w_mean = dense(h2_flat, 7*7*32, self.n_z, "w_mean")
            w_stddev = dense(h2_flat, 7*7*32, self.n_z, "w_stddev")

        return w_mean, w_stddev

    # decoder: latent code -> 28x28 image with sigmoid pixel intensities
    def generation(self, z):
        with tf.variable_scope("generation"):
            z_develop = dense(z, self.n_z, 7*7*32, scope='z_matrix')
            z_matrix = tf.nn.relu(tf.reshape(z_develop, [self.batchsize, 7, 7, 32]))
            h1 = tf.nn.relu(conv_transpose(z_matrix, [self.batchsize, 14, 14, 16], "g_h1"))
            h2 = conv_transpose(h1, [self.batchsize, 28, 28, 1], "g_h2")
            h2 = tf.nn.sigmoid(h2)

        return h2

    def train(self):
        visualization = self.mnist.train.next_batch(self.batchsize)[0]
        reshaped_vis = visualization.reshape(self.batchsize, 28, 28)
        ims("results/base.jpg", merge(reshaped_vis[:64], [8, 8]))
        # train
        saver = tf.train.Saver(max_to_keep=2)
        with tf.Session() as sess:
            sess.run(tf.initialize_all_variables())
            for epoch in range(10):
                for idx in range(int(self.n_samples / self.batchsize)):
                    batch = self.mnist.train.next_batch(self.batchsize)[0]
                    _, gen_loss, lat_loss = sess.run(
                        (self.optimizer, self.generation_loss, self.latent_loss),
                        feed_dict={self.images: batch})
                    # crude once-per-epoch logging: only idx == 0 satisfies this
                    if idx % (self.n_samples - 3) == 0:
                        print "epoch %d: genloss %f latloss %f" % (epoch, np.mean(gen_loss), np.mean(lat_loss))
                        saver.save(sess, os.getcwd() + "/training/train", global_step=epoch)
                        generated_test = sess.run(self.generated_images, feed_dict={self.images: visualization})
                        generated_test = generated_test.reshape(self.batchsize, 28, 28)
                        ims("results/" + str(epoch) + ".jpg", merge(generated_test[:64], [8, 8]))


model = LatentAttention()
model.train()
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf


class batch_norm(object):
    """Code modification of http://stackoverflow.com/a/33950177"""
    def __init__(self, epsilon=1e-5, momentum=0.9, name="batch_norm"):
        with tf.variable_scope(name):
            self.epsilon = epsilon
            self.momentum = momentum

            self.ema = tf.train.ExponentialMovingAverage(decay=self.momentum)
            self.name = name
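    # Note: main.py never instantiates this class. If it is used, a
    # train=True call must come before any train=False call, because
    # ema_mean and ema_var are only created on the training path.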
    def __call__(self, x, train=True):
        shape = x.get_shape().as_list()

        if train:
            with tf.variable_scope(self.name) as scope:
                self.beta = tf.get_variable("beta", [shape[-1]],
                                            initializer=tf.constant_initializer(0.))
                self.gamma = tf.get_variable("gamma", [shape[-1]],
                                             initializer=tf.random_normal_initializer(1., 0.02))

                batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
                ema_apply_op = self.ema.apply([batch_mean, batch_var])
                self.ema_mean, self.ema_var = self.ema.average(batch_mean), self.ema.average(batch_var)

                with tf.control_dependencies([ema_apply_op]):
                    mean, var = tf.identity(batch_mean), tf.identity(batch_var)
        else:
            mean, var = self.ema_mean, self.ema_var

        normed = tf.nn.batch_norm_with_global_normalization(
            x, mean, var, self.beta, self.gamma, self.epsilon, scale_after_normalization=True)

        return normed


# standard convolution layer: 5x5 kernel, stride 2, "SAME" padding (halves spatial size)
def conv2d(x, inputFeatures, outputFeatures, name):
    with tf.variable_scope(name):
        w = tf.get_variable("w", [5, 5, inputFeatures, outputFeatures], initializer=tf.truncated_normal_initializer(stddev=0.02))
        b = tf.get_variable("b", [outputFeatures], initializer=tf.constant_initializer(0.0))
        conv = tf.nn.conv2d(x, w, strides=[1, 2, 2, 1], padding="SAME") + b
        return conv


# stride-2 transposed convolution (doubles spatial size)
def conv_transpose(x, outputShape, name):
    with tf.variable_scope(name):
        # filter shape: [height, width, output_channels, in_channels]
        w = tf.get_variable("w", [5, 5, outputShape[-1], x.get_shape()[-1]], initializer=tf.truncated_normal_initializer(stddev=0.02))
        b = tf.get_variable("b", [outputShape[-1]], initializer=tf.constant_initializer(0.0))
        convt = tf.nn.conv2d_transpose(x, w, output_shape=outputShape, strides=[1, 2, 2, 1])
        return convt + b


def deconv2d(input_, output_shape,
             k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02,
             name="deconv2d"):
    with tf.variable_scope(name):
        # filter shape: [height, width, output_channels, in_channels]
        w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
                            initializer=tf.random_normal_initializer(stddev=stddev))

        deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape, strides=[1, d_h, d_w, 1])

        biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
        deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())

        return deconv


# leaky ReLU unit
def lrelu(x, leak=0.2, name="lrelu"):
    with tf.variable_scope(name):
        # equivalent to max(x, leak * x)
        f1 = 0.5 * (1 + leak)
        f2 = 0.5 * (1 - leak)
        return f1 * x + f2 * abs(x)


# fully-connected layer
def dense(x, inputFeatures, outputFeatures, scope=None, with_w=False):
    with tf.variable_scope(scope or "Linear"):
        matrix = tf.get_variable("Matrix", [inputFeatures, outputFeatures], tf.float32, tf.random_normal_initializer(stddev=0.02))
        bias = tf.get_variable("bias", [outputFeatures], initializer=tf.constant_initializer(0.0))
        if with_w:
            return tf.matmul(x, matrix) + bias, matrix, bias
        else:
            return tf.matmul(x, matrix) + bias
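
# How these ops compose in main.py's encoder (shapes for 28x28 MNIST):
#   input:                              [batch, 28, 28, 1]
#   conv2d(x, 1, 16, ...)           ->  [batch, 14, 14, 16]
#   conv2d(h1, 16, 32, ...)         ->  [batch, 7, 7, 32]
#   dense(flattened, 7*7*32, n_z)   ->  [batch, n_z] mean and stddev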
--------------------------------------------------------------------------------
/results/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/variational-autoencoder/e6c94c6d4cee06b22c710923c3a8bbed84b88794/results/0.jpg
--------------------------------------------------------------------------------
/results/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/variational-autoencoder/e6c94c6d4cee06b22c710923c3a8bbed84b88794/results/1.jpg
--------------------------------------------------------------------------------
/results/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/variational-autoencoder/e6c94c6d4cee06b22c710923c3a8bbed84b88794/results/2.jpg
--------------------------------------------------------------------------------
/results/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/variational-autoencoder/e6c94c6d4cee06b22c710923c3a8bbed84b88794/results/3.jpg
--------------------------------------------------------------------------------
/results/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/variational-autoencoder/e6c94c6d4cee06b22c710923c3a8bbed84b88794/results/4.jpg
--------------------------------------------------------------------------------
/results/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/variational-autoencoder/e6c94c6d4cee06b22c710923c3a8bbed84b88794/results/5.jpg
--------------------------------------------------------------------------------
/results/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/variational-autoencoder/e6c94c6d4cee06b22c710923c3a8bbed84b88794/results/6.jpg
--------------------------------------------------------------------------------
/results/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/variational-autoencoder/e6c94c6d4cee06b22c710923c3a8bbed84b88794/results/7.jpg
--------------------------------------------------------------------------------
/results/8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/variational-autoencoder/e6c94c6d4cee06b22c710923c3a8bbed84b88794/results/8.jpg
--------------------------------------------------------------------------------
/results/9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/variational-autoencoder/e6c94c6d4cee06b22c710923c3a8bbed84b88794/results/9.jpg
--------------------------------------------------------------------------------
/results/base.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kvfrans/variational-autoencoder/e6c94c6d4cee06b22c710923c3a8bbed84b88794/results/base.jpg
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf


def merge(images, size):
    """Tile a batch of 2-D images into a single size[0] x size[1] grid."""
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1]))

    for idx, image in enumerate(images):
        i = idx % size[1]
        j = idx / size[1]  # integer division under Python 2
        img[j*h:j*h+h, i*w:i*w+w] = image

    return img
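# For example, in main.py: merge(batch.reshape(100, 28, 28)[:64], [8, 8])
# returns a single (224, 224) array holding 64 digits in an 8x8 grid.
--------------------------------------------------------------------------------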