├── .floo
├── .flooignore
├── .gitignore
└── variation.py

/.floo:
--------------------------------------------------------------------------------
{
    "url": "https://floobits.com/berleon/tensorflow_vae"
}
--------------------------------------------------------------------------------
/.flooignore:
--------------------------------------------------------------------------------
extern
node_modules
tmp
vendor
.idea/workspace.xml
.idea/misc.xml
.idea/
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover

# Translations
*.mo
*.pot

# Django stuff:
*.log

# Sphinx documentation
docs/_build/

# PyBuilder
target/

### IPython Notebook template
# Temporary data
.ipynb_checkpoints/

# Created by .ignore support plugin (hsz.mobi)
--------------------------------------------------------------------------------
/variation.py:
--------------------------------------------------------------------------------
from tempfile import mkdtemp

import numpy as np
import tensorflow as tf

import keras.models
import keras.optimizers
from keras.backend import tensorflow_backend
from keras.datasets import mnist
from keras.layers import Dense, Activation
from keras.optimizers import Adam


def batch(iterable, n=1):
    """Yield successive slices of size n; the trailing partial batch is
    dropped because the graph is built for a fixed batch size."""
    l = len(iterable)
    for ndx in range(0, l, n):
        it = iterable[ndx:min(ndx + n, l)]
        if len(it) == n:
            yield it


class VAE:
    def __init__(self, encoder, decoder):
        self.x = tf.placeholder(tf.float32, name='input')
        # The encoder emits 2 * z_dim units per sample: one half for the
        # latent mean, the other half for the latent log-variance.
        self.latent_shape = (encoder.output_shape[0], encoder.output_shape[1] // 2)
        self.encoder = encoder
        self.decoder = decoder
        self.batch_size = self.latent_shape[0]

        assert None not in self.latent_shape, "All dimensions must be known"
        encoded = tf.reshape(encoder(self.x), (self.batch_size, 2, self.latent_shape[1]))
        self.mu, self.log_sigma = encoded[:, 0, :], encoded[:, 1, :]

        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I).
        # log_sigma holds the log-variance (consistent with the KL term
        # below), so the standard deviation is exp(0.5 * log_sigma).
        self.eps = tf.random_normal(self.latent_shape,
                                    mean=0.0, stddev=1.0, name="eps")
        self.z = self.mu + tf.exp(0.5 * self.log_sigma) * self.eps

        decoded = decoder(self.z)
        decoder_shape = decoder.output_shape
        if len(decoder_shape) == 2:
            decoded = tf.reshape(decoded, (self.batch_size, decoder_shape[1] // 2, 1, 2))
        else:
            assert decoder_shape[-1] == 2
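        # The decoder likewise emits two values per output pixel; they are
        # split below into the mean and log-variance of a Gaussian over the
        # reconstruction, mirroring the (mu, log_sigma) split of the latent
        # code above.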
        self.x_hat_mu, self.x_hat_log_sigma = decoded[:, :, :, 0], decoded[:, :, :, 1]
        self.x_hat_mu = tf.reshape(self.x_hat_mu, (self.batch_size, decoder_shape[1] // 2))
        self.x_hat_log_sigma = tf.reshape(self.x_hat_log_sigma, (self.batch_size, decoder_shape[1] // 2))

        self.params = encoder.trainable_weights + decoder.trainable_weights

        # Closed-form KL divergence between N(mu, sigma^2) and the standard
        # normal prior: -0.5 * (1 + log sigma^2 - mu^2 - sigma^2).
        self.latent_loss = -0.5 * tf.reduce_mean(1 + self.log_sigma - self.mu**2 - tf.exp(self.log_sigma))
        # Gaussian negative log-likelihood of the input under the decoder,
        # up to an additive constant: (x - mu)^2 / (2 sigma^2) + 0.5 * log sigma^2.
        self.reconstruction_loss = tf.reduce_mean(
            (self.x_hat_mu - self.x)**2 / (2 * tf.exp(self.x_hat_log_sigma))
            + 0.5 * self.x_hat_log_sigma)

        self.loss = self.latent_loss + self.reconstruction_loss

    def compile(self, optimizer):
        optimizer = keras.optimizers.get(optimizer)
        params = self.encoder.trainable_weights + self.decoder.trainable_weights
        regularizers = self.encoder.regularizers + self.decoder.regularizers
        constraints = self.encoder.constraints + self.decoder.constraints
        updates = self.encoder.updates + self.decoder.updates

        updates += optimizer.get_updates(params, constraints, self.loss)
        loss = self.loss
        for r in regularizers:
            loss = r(loss)
        self.train_loss = loss

        # Apply the parameter updates only after the loss has been computed.
        with tf.control_dependencies([self.train_loss]):
            self.train_updates = [tf.assign(p, new_p) for (p, new_p) in updates]

    def fit_batch(self, X, session):
        updated = session.run([self.train_loss] + self.train_updates, feed_dict={self.x: X})
        return updated[0]

    def fit(self, X, num_epochs=1):
        session = tensorflow_backend._get_session()
        writer_file = mkdtemp()
        print(writer_file)
        writer = tf.train.SummaryWriter(writer_file, session.graph_def)
        for epoch in range(num_epochs):
            errors = []
            for x in batch(X, self.batch_size):
                errors.append(self.fit_batch(x, session))

            print('({}) Epoch error: {}'.format(epoch, np.mean(errors)))

    def reconstruct(self, X):
        session = tensorflow_backend._get_session()
        return session.run([self.x_hat_mu, self.x_hat_log_sigma], feed_dict={self.x: X})

    def encode(self, X):
        session = tensorflow_backend._get_session()
        return session.run([self.mu, self.log_sigma], feed_dict={self.x: X})

    def generate(self, Z=None):
        if Z is None:
            Z = np.random.normal(0, 1, self.latent_shape)

        session = tensorflow_backend._get_session()
        return session.run([self.x_hat_mu, self.x_hat_log_sigma], feed_dict={self.z: Z})


if __name__ == '__main__':
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    print(X_train.shape)

    z_dim = 10
    batch_size = 64

    encoder = keras.models.Sequential()
    encoder.add(Dense(50, batch_input_shape=(batch_size, 28 * 28)))
    encoder.add(Activation('tanh'))
    encoder.add(Dense(z_dim * 2, init='uniform'))
    encoder.add(Activation('tanh'))

    decoder = keras.models.Sequential()
    decoder.add(Dense(60, batch_input_shape=(batch_size, z_dim)))
    decoder.add(Activation('tanh'))
    decoder.add(Dense(28 * 28 * 2))
    decoder.add(Activation('tanh'))

    optimizer = Adam()

    vae = VAE(encoder, decoder)
    vae.compile(optimizer)

    # Scale pixels to [-1, 1] to match the decoder's tanh output range.
    X_in = X_train.reshape((-1, 28 * 28)).astype(np.float32) / 255.
    print(X_in.shape)

    vae.fit(X_in * 2 - 1, num_epochs=10)
--------------------------------------------------------------------------------
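A minimal usage sketch (not part of the repo) of drawing new digits from the trained model, assuming the script above has just finished training in the same process; matplotlib is an extra dependency introduced here only for display:

import matplotlib.pyplot as plt

# Decode a full batch of latent codes sampled from the N(0, I) prior.
x_mu, x_log_sigma = vae.generate()

# Undo the [-1, 1] training-time scaling and arrange the means as images.
images = ((x_mu + 1) / 2).reshape(-1, 28, 28)
plt.imshow(images[0], cmap='gray')
plt.savefig('generated_digit.png')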