├── .gitignore
├── README.md
├── common
│   ├── __init__.py
│   ├── deconv.py
│   ├── io_utils.py
│   ├── tensorflow_context.py
│   └── vae_args.py
├── model
│   ├── __init__.py
│   ├── network_base.py
│   ├── network_implementations.py
│   └── vaemodel.py
└── vae.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.pyc
generated
outputmodels
data
orig
samples
results
images_to_hdf5.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# VAE implementation in TensorFlow for face expression reconstruction

The main motivation of this work is to use a Variational Autoencoder model to embed unseen faces into the latent space of pre-trained, single-actor-centric face expressions. The datasets used in the experiments described here are based on YouTube videos passed through the OpenFace feature-extraction utility.

## Requirements:

* python v2.7.6
* numpy v1.11.1
* scipy v0.13.3
* h5py v2.6.0
* Pillow v2.3.0
* progressbar v2.3
* argparse v1.2.1
* tensorflow 0.9.0
* prettytensor 0.6.2

For help, try:
```bash
$ python vae.py -h

usage: vae.py [-h] {train,sample,reconstruct} ...

positional arguments:
  {train,sample,reconstruct}
    train               train VAE model [vae.py train -h]
    sample              sample from existing model [vae.py sample -h]
    reconstruct         reconstruct images based on existing model [vae.py
                        reconstruct -h]

optional arguments:
  -h, --help            show this help message and exit
```

The tool implements three high-level commands:

* train

```bash
$ python vae.py train -h
usage: vae.py train [-h] [--hdf5-dataset-name] [--batch-size] [--epochs]
                    [--learning-rate] [--latent-dim] [--input-width]
                    [--input-height]
                    input output

positional arguments:
  input                Path of HDF5 data file
  output               Output tf model dir

optional arguments:
  -h, --help           show this help message and exit
  --hdf5-dataset-name  Name of dataset in hdf5
  --batch-size         Batch size
  --epochs             Number of epochs to run
  --learning-rate      Learning rate
  --latent-dim         latent variable dimensionality
  --input-width        Width of input images
  --input-height       Height of input images
```

The most important detail here is the format of the input data.
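Input is expected to be an HDF5 file containing a dataset that is a 4-dimensional tensor of size:
```
nr_of_objects x width x height x 3
```

Such a file can be produced from a directory of images with the `images_to_hdf5` helper in `common/io_utils.py` — a minimal sketch, assuming a directory of same-sized RGB face crops and that it is run from the repository root (the paths here are hypothetical):

```python
from common.io_utils import images_to_hdf5

# Reads every image in the directory and writes a single uint8 dataset
# named 'data' (the --hdf5-dataset-name default), resized to 64x64
# (the --input-width / --input-height defaults).
images_to_hdf5('path/to/face_crops', 'faces.hdf5', resize_to=(64, 64))
```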
The output directory will by default be populated with the TensorFlow saved model plus metadata files.

* sample

When your model is fully trained, you can use the sample command to draw points from the latent space and walk from one to another randomly:

```bash
$ python vae.py sample -h
usage: vae.py sample [-h] [--latent-dim] [--input-width] [--input-height]
                     [--output-dir]
                     input

positional arguments:
  input           Path of tensorflow model dir

optional arguments:
  -h, --help      show this help message and exit
  --latent-dim    latent variable dimensionality
  --input-width   Width of input images
  --input-height  Height of input images
  --output-dir    Output dir where png files are stored
```

* reconstruct

If you want to use an existing model to reconstruct input images, use the reconstruct command:

```bash
$ python vae.py reconstruct -h
usage: vae.py reconstruct [-h] [--latent-dim] [--input-width] [--input-height]
                          input inputdir output

positional arguments:
  input           Path of tensorflow model dir
  inputdir        Path of directory where input images are stored
  output          Output directory where reconstructions will be stored

optional arguments:
  -h, --help      show this help message and exit
  --latent-dim    latent variable dimensionality
  --input-width   Width of model input images
  --input-height  Height of model input images
```
--------------------------------------------------------------------------------
/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/int8/VAE_tensorflow/942be29519cb57cb685e0b0ce533de9a25162285/common/__init__.py
--------------------------------------------------------------------------------
/common/deconv.py:
--------------------------------------------------------------------------------
# Copyright 2015 Google Inc. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
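# This module registers a `deconv2d` method with prettytensor (via
# prettytensor.Register) that appears to mirror the built-in conv2d layer,
# but emits tf.nn.conv2d_transpose operations instead.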
11 | """A quick hack to try deconv out.""" 12 | 13 | import collections 14 | 15 | import tensorflow as tf 16 | from tensorflow.python.framework import tensor_shape 17 | 18 | from prettytensor import layers 19 | from prettytensor import pretty_tensor_class as prettytensor 20 | from prettytensor.pretty_tensor_class import PAD_SAME 21 | from prettytensor.pretty_tensor_class import Phase 22 | from prettytensor.pretty_tensor_class import PROVIDED 23 | 24 | # pylint: disable=redefined-outer-name,invalid-name 25 | @prettytensor.Register( 26 | assign_defaults=('activation_fn', 'l2loss', 'stddev', 'batch_normalize')) 27 | class deconv2d(prettytensor.VarStoreMethod): 28 | 29 | def __call__(self, 30 | input_layer, 31 | kernel, 32 | depth, 33 | name=PROVIDED, 34 | stride=None, 35 | activation_fn=None, 36 | l2loss=None, 37 | init=None, 38 | stddev=None, 39 | bias=True, 40 | edges=PAD_SAME, 41 | batch_normalize=False): 42 | """Adds a convolution to the stack of operations. 43 | 44 | The current head must be a rank 4 Tensor. 45 | 46 | Args: 47 | input_layer: The chainable object, supplied. 48 | kernel: The size of the patch for the pool, either an int or a length 1 or 49 | 2 sequence (if length 1 or int, it is expanded). 50 | depth: The depth of the new Tensor. 51 | name: The name for this operation is also used to create/find the 52 | parameter variables. 53 | stride: The strides as a length 1, 2 or 4 sequence or an integer. If an 54 | int, length 1 or 2, the stride in the first and last dimensions are 1. 55 | activation_fn: A tuple of (activation_function, extra_parameters). Any 56 | function that takes a tensor as its first argument can be used. More 57 | common functions will have summaries added (e.g. relu). 58 | l2loss: Set to a value greater than 0 to use L2 regularization to decay 59 | the weights. 60 | init: An optional initialization. If not specified, uses Xavier 61 | initialization. 62 | stddev: A standard deviation to use in parameter initialization. 63 | bias: Set to False to not have a bias. 64 | edges: Either SAME to use 0s for the out of bounds area or VALID to shrink 65 | the output size and only uses valid input pixels. 66 | batch_normalize: Set to True to batch_normalize this layer. 67 | Returns: 68 | Handle to the generated layer. 69 | Raises: 70 | ValueError: If head is not a rank 4 tensor or the depth of the input 71 | (4th dim) is not known. 
72 | """ 73 | if len(input_layer.shape) != 4: 74 | raise ValueError( 75 | 'Cannot perform conv2d on tensor with shape %s' % input_layer.shape) 76 | if input_layer.shape[3] is None: 77 | raise ValueError('Input depth must be known') 78 | kernel = _kernel(kernel) 79 | stride = _stride(stride) 80 | size = [kernel[0], kernel[1], depth, input_layer.shape[3]] 81 | 82 | books = input_layer.bookkeeper 83 | if init is None: 84 | if stddev is None: 85 | patch_size = size[0] * size[1] 86 | init = layers.xavier_init(size[2] * patch_size, size[3] * patch_size) 87 | elif stddev: 88 | init = tf.truncated_normal_initializer(stddev=stddev) 89 | else: 90 | init = tf.zeros_initializer 91 | elif stddev is not None: 92 | raise ValueError('Do not set both init and stddev.') 93 | dtype = input_layer.tensor.dtype 94 | params = self.variable('weights', size, init, dt=dtype) 95 | 96 | input_height = input_layer.shape[1] 97 | input_width = input_layer.shape[2] 98 | 99 | filter_height = kernel[0] 100 | filter_width = kernel[1] 101 | 102 | row_stride = stride[1] 103 | col_stride = stride[2] 104 | 105 | out_rows, out_cols = get2d_deconv_output_size(input_height, input_width, filter_height, 106 | filter_width, row_stride, col_stride, edges) 107 | 108 | output_shape = [input_layer.shape[0], out_rows, out_cols, depth] 109 | y = tf.nn.conv2d_transpose(input_layer, params, output_shape, stride, edges) 110 | layers.add_l2loss(books, params, l2loss) 111 | if bias: 112 | y += self.variable( 113 | 'bias', 114 | [size[-2]], 115 | tf.zeros_initializer, 116 | dt=dtype) 117 | books.add_scalar_summary( 118 | tf.reduce_mean( 119 | layers.spatial_slice_zeros(y)), '%s/zeros_spatial' % y.op.name) 120 | if batch_normalize: 121 | y = input_layer.with_tensor(y).batch_normalize() 122 | if activation_fn is not None: 123 | if not isinstance(activation_fn, collections.Sequence): 124 | activation_fn = (activation_fn,) 125 | y = layers.apply_activation( 126 | books, 127 | y, 128 | activation_fn[0], 129 | activation_args=activation_fn[1:]) 130 | return input_layer.with_tensor(y) 131 | # pylint: enable=redefined-outer-name,invalid-name 132 | 133 | # Helper methods 134 | 135 | def get2d_deconv_output_size(input_height, input_width, filter_height, 136 | filter_width, row_stride, col_stride, padding_type): 137 | """Returns the number of rows and columns in a convolution/pooling output.""" 138 | input_height = tensor_shape.as_dimension(input_height) 139 | input_width = tensor_shape.as_dimension(input_width) 140 | filter_height = tensor_shape.as_dimension(filter_height) 141 | filter_width = tensor_shape.as_dimension(filter_width) 142 | row_stride = int(row_stride) 143 | col_stride = int(col_stride) 144 | 145 | # Compute number of rows in the output, based on the padding. 146 | if input_height.value is None or filter_height.value is None: 147 | out_rows = None 148 | elif padding_type == "VALID": 149 | out_rows = (input_height.value - 1) * row_stride + filter_height.value 150 | elif padding_type == "SAME": 151 | out_rows = input_height.value * row_stride 152 | else: 153 | raise ValueError("Invalid value for padding: %r" % padding_type) 154 | 155 | # Compute number of columns in the output, based on the padding. 
  if input_width.value is None or filter_width.value is None:
    out_cols = None
  elif padding_type == "VALID":
    out_cols = (input_width.value - 1) * col_stride + filter_width.value
  elif padding_type == "SAME":
    out_cols = input_width.value * col_stride

  return out_rows, out_cols

def _kernel(kernel_spec):
  """Expands the kernel spec into a length 2 list.

  Args:
    kernel_spec: An integer or a length 1 or 2 sequence that is expanded to a
      list.
  Returns:
    A length 2 list.
  """
  if isinstance(kernel_spec, int):
    return [kernel_spec, kernel_spec]
  elif len(kernel_spec) == 1:
    return [kernel_spec[0], kernel_spec[0]]
  else:
    assert len(kernel_spec) == 2
    return kernel_spec


def _stride(stride_spec):
  """Expands the stride spec into a length 4 list.

  Args:
    stride_spec: None, an integer or a length 1, 2, or 4 sequence.
  Returns:
    A length 4 list.
  """
  if stride_spec is None:
    return [1, 1, 1, 1]
  elif isinstance(stride_spec, int):
    return [1, stride_spec, stride_spec, 1]
  elif len(stride_spec) == 1:
    return [1, stride_spec[0], stride_spec[0], 1]
  elif len(stride_spec) == 2:
    return [1, stride_spec[0], stride_spec[1], 1]
  else:
    assert len(stride_spec) == 4
    return stride_spec
--------------------------------------------------------------------------------
/common/io_utils.py:
--------------------------------------------------------------------------------
import os

import h5py
import numpy as np
from PIL import Image
from scipy.misc import imread
from progressbar import ProgressBar, Percentage, Bar

class HDF5Reader:
    def __init__(self, filename, dataset_name, b_size):
        with h5py.File(filename, 'r') as hf:
            self.data = np.array(hf.get(dataset_name))
        self.i = 0
        self.b_size = b_size
        self.n = self.data.shape[0]
        self.cycles = 0

    def next(self):
        k_left = self.i
        k_right = self.i + self.b_size
        if k_right >= self.n:
            # Wrap around: stitch the tail of the dataset to its head,
            # then reshuffle for the next pass.
            batch = np.append(self.data[range(k_left, self.n), :, :, :],
                              self.data[range(0, k_right - self.n), :, :, :],
                              axis=0).astype(np.float32) / 255.
            self.data = self.data[np.random.permutation(self.n), :]
            self.cycles = self.cycles + 1
        else:
            batch = self.data[k_left:k_right, :, :, :].astype(np.float32) / 255.
        self.i = k_right % self.n
        return batch

    def get_cycles(self):
        return self.cycles


def hdf5_generator(filename, dataset_name, batch_size):

    with h5py.File(filename, 'r') as hf:
        data = np.array(hf.get(dataset_name))
    i = 0
    nr_of_points = data.shape[0]
    while True:
        key_left = i
        key_right = i + batch_size
        if key_right >= nr_of_points:
            yield np.append(data[range(key_left, nr_of_points), :, :, :],
                            data[range(0, key_right - nr_of_points), :, :, :],
                            axis=0).astype(np.float32) / 255.
            data = data[np.random.permutation(data.shape[0]), :]
        else:
            yield data[key_left:key_right, :, :, :].astype(np.float32) / 255.
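        # Advance the cursor, wrapping around the end of the dataset.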
        i = key_right % nr_of_points

def resize_hdf5(input_filename, output_filename, dataset_name, batch_size, new_width=64, new_height=64):

    with h5py.File(input_filename, 'r') as hf:
        data = np.array(hf.get(dataset_name))
    nr_of_points = data.shape[0]
    depth = data.shape[3]
    newdata = np.empty(shape=(nr_of_points, new_width, new_height, depth), dtype=np.uint8)

    for i in xrange(nr_of_points):
        # Datasets are stored as uint8, so each frame can be handed to PIL directly.
        datum = data[i, :, :, :]
        resized_datum = np.asarray(Image.fromarray(datum, 'RGB').resize((new_width, new_height), Image.ANTIALIAS))
        newdata[i, :, :, :] = resized_datum

    with h5py.File(output_filename, 'w') as hf:
        hf.create_dataset(dataset_name, data=newdata)


def images_to_hdf5(dir_path, output_hdf5, size=(112, 112), channels=3, resize_to=None):
    files = sorted(os.listdir(dir_path))
    nr_of_images = len(files)
    if resize_to:
        size = resize_to
    i = 0
    pbar = ProgressBar(widgets=[Percentage(), Bar()], maxval=nr_of_images).start()
    data = np.empty(shape=(nr_of_images, size[0], size[1], channels), dtype=np.uint8)
    for f in files:
        datum = imread(dir_path + '/' + f)
        if resize_to:
            datum = np.asarray(Image.fromarray(datum, 'RGB').resize((size[0], size[1]), Image.ANTIALIAS))
        data[i, :, :, :] = datum
        i = i + 1
        pbar.update(i)
    pbar.finish()
    with h5py.File(output_hdf5, 'w') as hf:
        hf.create_dataset('data', data=data)

def read_data_from_dir(dir_path, resize_to):
    files = os.listdir(dir_path)
    nr_of_images = len(files)
    data = {
        'files': ['' for _ in xrange(nr_of_images)],
        'tensors': np.empty(shape=(nr_of_images, resize_to[0], resize_to[1], resize_to[2]), dtype=np.uint8)
    }
    i = 0
    for f in files:
        datum = imread(dir_path + '/' + f)
        datum = np.asarray(Image.fromarray(datum, 'RGB').resize((resize_to[0], resize_to[1]), Image.ANTIALIAS))
        data['tensors'][i, :, :, :] = datum
        data['files'][i] = f
        i = i + 1
    return data
--------------------------------------------------------------------------------
/common/tensorflow_context.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import prettytensor as pt

class TensorflowContext:

    def __init__(self):
        self.sess = tf.Session()
        self.prettytensor_scope = pt.defaults_scope(
            activation_fn=tf.nn.relu,
            batch_normalize=True,
            learned_moments_update_rate=0.0003,
            variance_epsilon=0.001,
            scale_after_normalization=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.sess.close()

    def load(self, model_path):
        saver = tf.train.Saver()
        saver.restore(self.sess, tf.train.latest_checkpoint(model_path))

    def save(self, model_path):
        saver = tf.train.Saver()
        saver.save(self.sess, model_path + '/output')
--------------------------------------------------------------------------------
/common/vae_args.py:
--------------------------------------------------------------------------------
import argparse

parser = argparse.ArgumentParser(prog='vae.py')

subparsers = parser.add_subparsers(dest="command")

train_parser = subparsers.add_parser("train", help="train VAE model [vae.py train -h]")
sample_parser = subparsers.add_parser("sample", help="sample from existing model [vae.py sample -h]")
-h]") 9 | recon_parser = subparsers.add_parser("reconstruct", help="reconstruct images based on existing model [vae.py reconstruct -h]") 10 | 11 | train_parser.add_argument('input', help='Path of HDF5 data file') 12 | train_parser.add_argument('output', help='Output tf model dir', type=str) 13 | train_parser.add_argument('--hdf5-dataset-name', help='Name of dataset in hdf5', default = 'data', metavar='') 14 | train_parser.add_argument('--batch-size', help='Batch size', default = 128, type=int, metavar='') 15 | train_parser.add_argument('--epochs', help='Number of epochs to run', default = 100, type=int, metavar='') 16 | train_parser.add_argument('--learning-rate', help='Learning rate', default = 1e-2, type=float, metavar='') 17 | train_parser.add_argument('--latent-dim', help='latent variable dimensionality', default = 30, type=int, metavar='') 18 | train_parser.add_argument('--input-width', help='Width of input images', default = 64, type=int, metavar='') 19 | train_parser.add_argument('--input-height', help='Height of input images', default = 64, type=int, metavar='') 20 | 21 | sample_parser.add_argument('input', help='Path of tensorflow model dir', type=str) 22 | sample_parser.add_argument('--latent-dim', help='latent variable dimensionality', default = 30, type=int, metavar='') 23 | sample_parser.add_argument('--input-width', help='Width of input images', default = 64, type=int, metavar='') 24 | sample_parser.add_argument('--input-height', help='Height of input images', default = 64, type=int, metavar='') 25 | sample_parser.add_argument('--output--dir', help='Output dir where png files are stored ', type=str, default='samples/', metavar='') 26 | 27 | recon_parser.add_argument('input', help='Path of tensorflow model dir', type=str) 28 | recon_parser.add_argument('inputdir', help='Path directory where input images are stored', type=str) 29 | recon_parser.add_argument('output', help='Output directory where reconstructions will be stored', type=str) 30 | recon_parser.add_argument('--latent-dim', help='latent variable dimensionality', default = 30, type=int, metavar='') 31 | recon_parser.add_argument('--input-width', help='Width of model input images', default = 64, type=int, metavar='') 32 | recon_parser.add_argument('--input-height', help='Height of model input images', default = 64, type=int, metavar='') 33 | 34 | 35 | args = parser.parse_args() 36 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/int8/VAE_tensorflow/942be29519cb57cb685e0b0ce533de9a25162285/model/__init__.py -------------------------------------------------------------------------------- /model/network_base.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | class VAEEncoderBase: 4 | 5 | def __init__(self, input_tensor_size, representation_size, batch_size): 6 | 7 | self.batch_size = batch_size 8 | self.input_tensor_size = input_tensor_size 9 | self.representation_size = representation_size 10 | 11 | self.input_data = tf.placeholder(tf.float32, [batch_size] + input_tensor_size) 12 | 13 | def guts(self, batch_size = None): 14 | raise("Implement guts() function in your Encoder Implementation") 15 | 16 | def get_network_output(self): 17 | return self.guts(); 18 | 19 | 20 | class VAEDecoderBase: 21 | 22 | def __init__(self, representation_size, batch_size): 23 | self.representation_size = 
        self.batch_size = batch_size

    def guts(self, batch_size=None):
        raise NotImplementedError("Implement guts() in your decoder implementation")

    def generate_network_output(self, input_tensor, batch_size=None):

        batch_size = self._determine_batch_size(batch_size)
        epsilon = tf.random_normal([batch_size, self.representation_size])
        self.mu = input_tensor[0]
        # The encoder output is supposed to model log(sigma**2), whose domain is (-inf, inf).
        self.stddev = tf.sqrt(tf.exp(input_tensor[1]))
        # Reparameterization trick: z = mu + sigma * epsilon, with epsilon ~ N(0, I).
        self.latent_var = self.mu + epsilon * self.stddev
        return self.guts(), self.mu, self.stddev


    def generate_network_output_without_noise(self, input_tensor, batch_size=None):
        self.latent_var = input_tensor[0]
        return self.guts()


    def generate_network_random_sample(self):
        self.epsilon = tf.placeholder(tf.float32, [1, self.representation_size])
        self.latent_var = self.epsilon
        return self.guts(1)

    def _determine_batch_size(self, batch_size):
        batch_size = batch_size if batch_size else self.batch_size
        return batch_size
--------------------------------------------------------------------------------
/model/network_implementations.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import numpy as np
import prettytensor as pt
from common.deconv import deconv2d
from model.network_base import VAEEncoderBase, VAEDecoderBase

class ConvolutionalEncoder(VAEEncoderBase):

    def __init__(self, input_tensor_size, representation_size, batch_size):
        VAEEncoderBase.__init__(self, input_tensor_size, representation_size, batch_size)

    def guts(self):
        conv_layers = (pt.wrap(self.input_data).
                       conv2d(4, 32, stride=2, name="enc_conv1").
                       conv2d(4, 64, stride=2, name="enc_conv2").
                       conv2d(4, 128, stride=2, name="enc_conv3").
                       conv2d(4, 256, stride=2, name="enc_conv4").
                       flatten())

        mu = conv_layers.fully_connected(self.representation_size, activation_fn=None, name="mu")
        stddev_log_sq = conv_layers.fully_connected(self.representation_size, activation_fn=None, name="stddev_log_sq")
        return mu, stddev_log_sq


class DeconvolutionalDecoder(VAEDecoderBase):

    def __init__(self, representation_size, batch_size):
        VAEDecoderBase.__init__(self, representation_size, batch_size)

    def guts(self, batch_size=None):
        batch_size = self._determine_batch_size(batch_size)

        return (pt.wrap(self.latent_var).
                fully_connected(4*256, activation_fn=None, name="dec_fc1").
                reshape([batch_size, 1, 1, 4*256]).
                deconv2d(5, 128, stride=2, edges='VALID', name="dec_deconv2").
                deconv2d(5, 64, stride=2, edges='VALID', name="dec_deconv3").
                deconv2d(6, 32, stride=2, edges='VALID', name="dec_deconv4").
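                # With VALID padding each step grows the spatial size as
                # (in - 1) * stride + kernel, so 1 -> 5 -> 13 -> 30 -> 64;
                # the final sigmoid layer maps back to 3 channels in [0, 1].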
                deconv2d(6, 3, stride=2, edges="VALID", activation_fn=tf.nn.sigmoid, name="dec_deconv5")).tensor
--------------------------------------------------------------------------------
/model/vaemodel.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import prettytensor as pt
from common.tensorflow_context import TensorflowContext
import numpy as np
from progressbar import ProgressBar, Percentage, Bar

class VaeAutoencoderSampler(TensorflowContext):

    def __init__(self, decoder, model_path):
        TensorflowContext.__init__(self)
        self.decoder = decoder

        with self.prettytensor_scope:
            with tf.variable_scope("vae"):
                self.sample = self.decoder.generate_network_random_sample()
        self.load(model_path)

    def generate_random_sample(self, epsilon):
        sample = self.sess.run(self.sample, {self.decoder.epsilon: epsilon})
        return sample

    def walk_between_points(self, a, b, number_of_steps):
        # Linear interpolation between the latent points a and b.
        points = [a] + [np.add(a * (1 - 1. * i / number_of_steps), b * (1. * i / number_of_steps)) for i in range(1, number_of_steps)] + [b]
        samples = []
        for point in points:
            sample = self.sess.run(self.sample, {self.decoder.epsilon: point})
            samples.append(sample)
        return samples

class VaeAutoencoderReconstructor(TensorflowContext):

    def __init__(self, encoder, decoder, model_path):
        TensorflowContext.__init__(self)

        self.encoder = encoder
        self.decoder = decoder

        with self.prettytensor_scope:
            with tf.variable_scope("vae"):
                self.reconstruction = self.decoder.generate_network_output_without_noise(self.encoder.get_network_output())
        self.load(model_path)

    def reconstruct(self, input_data):
        reconstructions = self.sess.run(self.reconstruction, {self.encoder.input_data: input_data})
        return reconstructions

class VaeAutoencoderTrainer(TensorflowContext):

    def __init__(self, encoder, decoder, hdf5reader):
        TensorflowContext.__init__(self)
        self.encoder = encoder
        self.decoder = decoder
        self.hdf5reader = hdf5reader

        # compose autoencoder + error function + sampling routine
        with self.prettytensor_scope:
            with tf.variable_scope("vae"):
                self.reconstruction, self.mu, self.stddev = self.decoder.generate_network_output(self.encoder.get_network_output())
                self.build_error_function()

    def build_error_function(self, eps=1e-9):

        # Bernoulli (cross-entropy) reconstruction error, summed over pixels and channels.
        self.prob_reconstruction_error = -tf.reduce_sum(
            self.encoder.input_data * tf.log(self.reconstruction + eps) + (1.0 - self.encoder.input_data) * tf.log(1.0 - self.reconstruction + eps),
            reduction_indices=[1, 2, 3]
        )

        # alternatively one could use the following reconstruction error component
        # self.euc_reconstruction_error = tf.reduce_sum(tf.pow(self.encoder.input_data - self.reconstruction + eps, 2))

        # KL divergence between N(mu, stddev**2) and the standard normal prior.
        self.vae_error = -0.5 * tf.reduce_sum(
            1 + tf.log(tf.square(self.stddev + eps)) - tf.square(self.mu) - tf.square(self.stddev),
            reduction_indices=1
        )

        self.error_function = tf.reduce_mean(self.vae_error + self.prob_reconstruction_error)

    def train(self, epochs, batch_size, learning_rate, save_to=None):

        self.train_step = pt.apply_optimizer(tf.train.AdamOptimizer(learning_rate, epsilon=1), losses=[self.error_function])
        init = tf.initialize_all_variables()
        self.sess.run(init)
        pbar = ProgressBar(widgets=[Percentage(), Bar()], maxval=epochs).start()
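        # One "epoch" here is one full pass (cycle) over the HDF5 dataset,
        # as counted by the reader behind get_epoch().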
        while self.get_epoch() < epochs:
            input_data = self.hdf5reader.next()
            _, loss_value = self.sess.run(
                [self.train_step, self.error_function],
                {
                    self.encoder.input_data: input_data
                }
            )
            pbar.update(self.get_epoch())
        pbar.finish()

    def get_epoch(self):
        return self.hdf5reader.get_cycles()
--------------------------------------------------------------------------------
/vae.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from scipy.misc import imsave
from common.vae_args import args
from common.io_utils import HDF5Reader, read_data_from_dir
from model.network_implementations import ConvolutionalEncoder, DeconvolutionalDecoder
from model.vaemodel import VaeAutoencoderTrainer, VaeAutoencoderSampler, VaeAutoencoderReconstructor

if __name__ == "__main__":

    if args.command == "train":
        input_size = [args.input_width, args.input_height, 3]
        hdf5reader = HDF5Reader(args.input, args.hdf5_dataset_name, args.batch_size)
        # build encoder graph
        encoder = ConvolutionalEncoder(input_size, args.latent_dim, args.batch_size)
        # build decoder graph
        decoder = DeconvolutionalDecoder(args.latent_dim, args.batch_size)

        with VaeAutoencoderTrainer(encoder, decoder, hdf5reader) as trainer:
            trainer.train(args.epochs, args.batch_size, args.learning_rate)
            trainer.save(args.output)

    elif args.command == "reconstruct":

        input_size = [args.input_width, args.input_height, 3]
        # data from directory
        data = read_data_from_dir(args.inputdir, input_size)
        # build encoder graph
        encoder = ConvolutionalEncoder(input_size, args.latent_dim, len(data['files']))
        # build decoder graph
        decoder = DeconvolutionalDecoder(args.latent_dim, len(data['files']))

        with VaeAutoencoderReconstructor(encoder, decoder, args.input) as reconstructor:
            # Normalize to [0, 1], matching the scaling applied during training.
            r = reconstructor.reconstruct(data['tensors'].astype(np.float32) / 255.)
            for i in xrange(r.shape[0]):
                imsave(args.output + '/reconstruction_' + str(data['files'][i]), r[i, :, :, :])

    elif args.command == "sample":
        # build decoder graph
        decoder = DeconvolutionalDecoder(args.latent_dim, 1)
        with VaeAutoencoderSampler(decoder, args.input) as sampler:
            a = np.random.randn(1, args.latent_dim)
            b = np.random.randn(1, args.latent_dim)

            # Walk through 30 random latent segments, 10 interpolation steps each.
            for k in xrange(30):
                a = b.copy()
                b = np.random.randn(1, args.latent_dim)
                samples = sampler.walk_between_points(a, b, 10)
                for i in xrange(len(samples)):
                    imsave(args.output_dir + str(k) + '_' + str(i) + '.png', samples[i][0, :, :, :])
--------------------------------------------------------------------------------