├── README.md
├── bad_generated_balls.jpg
├── bad_real_balls.jpg
├── bouncing_balls.py
├── generated_balls.jpg
├── layer_def.py
├── main.py
├── main_all_conv.py
└── real_balls.jpg

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Variational Autoencoders
Variational autoencoders are pretty nice and, in my experience, a lot better than denoising autoencoders. They can be a bit tricky to train, though, so I made a small troubleshooting guide. To make the code a little more fun I used a dataset of bouncing-ball images instead of MNIST. The code to generate the bouncing-ball images is included and was originally from Ilya Sutskever's [Recurrent Temporal Restricted Boltzmann Machine](http://www.uoguelph.ca/~gwtaylor/publications/nips2008/rtrbm.pdf). I render the bouncing balls as 32x32x3 images where the ball is drawn in one color channel and the other two channels encode its x and y velocity.

# Troubleshooting

Here is a short list of the problems I had getting training to work. Hopefully this will save someone a little time.

## Getting NaNs a few steps into training
Check that the output of your network goes through a sigmoid layer. The reconstruction loss will go to NaN because of the logs if the network output can fall outside (0, 1).
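This is essentially the pattern `main.py` uses (variable names are taken from that file): squash the decoder output with a sigmoid and keep a small epsilon inside the logs so `log(0)` never shows up.

```python
x_prime = tf.nn.sigmoid(conv13)   # keep reconstructions strictly in (0, 1)
epsilon = 1e-8                    # guards the logs below against log(0)
loss_reconstruction = tf.reduce_sum(-x * tf.log(x_prime + epsilon) -
                                    (1.0 - x) * tf.log(1.0 - x_prime + epsilon))
```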
## Not NaNing, but not converging well either
OK, so this is really the whole reason I made this repo. For a while my loss and network appeared to be set up just fine, yet the model never learned anything beyond reconstructing the average image. After a bit of digging I found the problem was how I initialized my layers. Because I was reusing a network from another problem I was working on, I initialized the weights very small (around .001 std for both the conv and fully connected layers). This caused problems because, on the first few passes, the encoder produces tiny values for the mean and stddev, which makes the VAE (KL) part of the loss nearly zero. The optimizer gets stuck in that minimum and the reconstruction loss never really falls. Most tutorials seem to have no problem with this for one of two reasons: they use mini-batch normalization, or they use Xavier initialization. When I was first looking at this I didn't really want to do batch normalization and didn't think Xavier initialization really mattered. Then I proceeded to waste several hours.

For the example code I just set the fully connected layers to initialize with a .1 std and that fixed it. I exposed it as a flag so you can see that if it is set to .001 the model will not converge.
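The flag just sets the standard deviation of the truncated-normal initializer for the fully connected layers; the snippet below is pulled from `main.py` and from `fc_layer` in `layer_def.py`.

```python
tf.app.flags.DEFINE_float('weight_init', .1,
                          """weight init for fully connected layers""")

# inside fc_layer in layer_def.py
weights = _variable_with_weight_decay('weights', shape=[dim, hiddens],
                                      stddev=FLAGS.weight_init,
                                      wd=FLAGS.weight_decay)
```

Running `python main.py --weight_init=0.001` should reproduce the stuck behaviour shown in the pictures below, while the default of `0.1` trains fine.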
# Pictures!!!
True image:
![alt text](https://github.com/loliverhennigh/Variational-autoencoder-tricks-and-tips/blob/master/real_balls.jpg)
Generated image:
![alt text](https://github.com/loliverhennigh/Variational-autoencoder-tricks-and-tips/blob/master/generated_balls.jpg)
This is after only about 10 minutes of training on a CPU.

With the same amount of training time and a .001 std on the fully connected layers it produces this,
![alt text](https://github.com/loliverhennigh/Variational-autoencoder-tricks-and-tips/blob/master/bad_generated_balls.jpg)
when the true image is
![alt text](https://github.com/loliverhennigh/Variational-autoencoder-tricks-and-tips/blob/master/bad_real_balls.jpg)

--------------------------------------------------------------------------------
/bad_generated_balls.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loliverhennigh/Variational-autoencoder-tricks-and-tips/69506121e8632e451bb3bfeb19442eccb9cdaa92/bad_generated_balls.jpg
--------------------------------------------------------------------------------
/bad_real_balls.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loliverhennigh/Variational-autoencoder-tricks-and-tips/69506121e8632e451bb3bfeb19442eccb9cdaa92/bad_real_balls.jpg
--------------------------------------------------------------------------------
/bouncing_balls.py:
--------------------------------------------------------------------------------

#######################################################
#
# This code was taken from Ilya Sutskever's project on
# Recurrent Temporal Restricted Boltzmann Machines.
# The original source can be found at
# http://www.cs.utoronto.ca/~ilya/code/2008/RTRBM.tar
# There have been a few modifications to this code,
# including adding gravity and dampening.
#
#######################################################


from pylab import *

import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_bool('friction', False,
                         """whether there is friction in the system""")
tf.app.flags.DEFINE_integer('num_balls', 2,
                            """num of balls in the simulation""")


def norm(x): return sqrt((x**2).sum())
def sigmoid(x): return 1./(1.+exp(-x))


SIZE = 10
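# new_speeds solves the 1-D elastic collision along the line of impact:
# for masses m1, m2 with incoming speeds v1, v2 it returns outgoing speeds
# that conserve both momentum and kinetic energy,
#   new_v1 = ((m1 - m2)*v1 + 2*m2*v2) / (m1 + m2)
#   new_v2 = ((m2 - m1)*v2 + 2*m1*v1) / (m1 + m2)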
def new_speeds(m1, m2, v1, v2):
    new_v2 = (2*m1*v1 + v2*(m2-m1))/(m1+m2)
    new_v1 = new_v2 + (v2 - v1)
    return new_v1, new_v2

# size of bounding box: SIZE X SIZE.

def bounce_n(T=128, n=2, r=None, m=None):
    if r is None: r = array([4.0]*n)
    if m is None: m = array([1]*n)
    # r is to be rather small.
    X = zeros((T, n, 2), dtype='float')
    V = zeros((T, n, 2), dtype='float')
    v = randn(n, 2)
    v = (v / norm(v)*.5)*1.0
    good_config = False
    while not good_config:
        x = 2+rand(n, 2)*8
        good_config = True
        for i in range(n):
            for z in range(2):
                if x[i][z]-r[i] < 0:    good_config = False
                if x[i][z]+r[i] > SIZE: good_config = False

        # that's the main part.
        for i in range(n):
            for j in range(i):
                if norm(x[i]-x[j]) < r[i]+r[j]:
                    good_config = False

    eps = .5
    for t in range(T):
        # for how long do we show small simulation
        for i in range(n):
            X[t, i] = x[i]
            V[t, i] = v[i]

        for mu in range(int(1/eps)):
            for i in range(n):
                x[i] += eps*v[i]

            for i in range(n):
                for z in range(2):
                    if x[i][z]-r[i] < 0:    v[i][z] =  abs(v[i][z]) # want positive
                    if x[i][z]+r[i] > SIZE: v[i][z] = -abs(v[i][z]) # want negative

            for i in range(n):
                for j in range(i):
                    if norm(x[i]-x[j]) < r[i]+r[j]:
                        #if (x[i][0] > 0) and (x[i][0] < size) and (x[i][1] > 0) and (x[i][1] < size):
                        # if (x[i][0] > 0) and (x[i][0] < size) and (x[i][1] > 0) and (x[i][1] < size):
                        # the bouncing off part:
                        w = x[i]-x[j]
                        w = w / norm(w)

                        v_i = dot(w.transpose(), v[i])
                        v_j = dot(w.transpose(), v[j])

                        new_v_i, new_v_j = new_speeds(m[i], m[j], v_i, v_j)

                        v[i] += w*(new_v_i - v_i)
                        v[j] += w*(new_v_j - v_j)

    '''
    if flip:
        flip = False
        for i in range(n):
            for j in range(i):
                if norm(x[i]-x[j]) < r[i]+r[j]:
                    # if (x[i][0] > 0) and (x[i][0] < size) and (x[i][1] > 0) and (x[i][1] < size):
                    # the bouncing off part:
                    w = x[i]-x[j]
                    w = w / norm(w)

                    v_i = dot(w.transpose(), v[i])
                    v_j = dot(w.transpose(), v[j])

                    new_v_i, new_v_j = new_speeds(m[i], m[j], v_i, v_j)

                    v[i] += w*(new_v_i - v_i)
                    v[j] += w*(new_v_j - v_j)
    else:
        flip = True
        for i in range(n):
            for j in range(i):
                if norm(x[(n-1)-i]-x[(n-1)-j]) < r[i]+r[j]:
                    # if (x[i][0] > 0) and (x[i][0] < size) and (x[i][1] > 0) and (x[i][1] < size):
                    # the bouncing off part:
                    w = x[(n-1)-i]-x[(n-1)-j]
                    w = w / norm(w)

                    v_i = dot(w.transpose(), v[(n-1)-i])
                    v_j = dot(w.transpose(), v[(n-1)-j])

                    new_v_i, new_v_j = new_speeds(m[(n-1)-i], m[(n-1)-j], v_i, v_j)

                    v[(n-1)-i] += w*(new_v_i - v_i)
                    v[(n-1)-j] += w*(new_v_j - v_j)
    '''
    return X, V

def ar(x, y, z):
    return z/2+arange(x, y, z, dtype='float')

def matricize(X, V, res, r=None):

    T, n = shape(X)[0:2]
    if r is None: r = array([4.0]*n)

    A = zeros((T, res, res, 3), dtype='float')

    [I, J] = meshgrid(ar(0, 1, 1./res)*SIZE, ar(0, 1, 1./res)*SIZE)

    for t in range(T):
        for i in range(n):
            A[t, :, :, 1] += exp(-(((I-X[t,i,0])**2+(J-X[t,i,1])**2)/(r[i]**2))**4)
            A[t, :, :, 0] += 1.0 * (V[t,i,0] + .5) * exp(-(((I-X[t,i,0])**2+(J-X[t,i,1])**2)/(r[i]**2))**4)
            A[t, :, :, 2] += 1.0 * (V[t,i,1] + .5) * exp(-(((I-X[t,i,0])**2+(J-X[t,i,1])**2)/(r[i]**2))**4)

        A[t,:,:,0][A[t,:,:,0]>1] = 1
        A[t,:,:,1][A[t,:,:,1]>1] = 1
        A[t,:,:,2][A[t,:,:,2]>1] = 1
    return A

def bounce_mat(res, n=2, T=128, r=None):
    if r is None: r = array([1.2]*n)
    x, v = bounce_n(T, n, r)
    A = matricize(x, v, res, r)
    return A

def bounce_vec(res, n=2, T=128, r=None, m=None):
    if r is None: r = array([1.2]*n)
    x, v = bounce_n(T, n, r, m)
    V = matricize(x, v, res, r)
    return V

def show_single_V(V):
    res = int(sqrt(shape(V)[0]))
    show(V.reshape(res, res))

def show_V(V):
    T = len(V)
    res = int(sqrt(shape(V)[1]))
    for t in range(T):
        show(V[t].reshape(res, res))

def unsigmoid(x): return log(x) - log(1-x)

def show_A(A):
    T = len(A)
    for t in range(T):
        show(A[t])

--------------------------------------------------------------------------------
/generated_balls.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loliverhennigh/Variational-autoencoder-tricks-and-tips/69506121e8632e451bb3bfeb19442eccb9cdaa92/generated_balls.jpg
--------------------------------------------------------------------------------
/layer_def.py:
--------------------------------------------------------------------------------

"""functions used to construct different architectures
"""

import tensorflow as tf
import numpy as np

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_float('weight_decay', 0.0005,
                          """L2 weight decay coefficient""")

def _activation_summary(x):
  """Helper to create summaries for activations.

  Creates a summary that provides a histogram of activations.
  Creates a summary that measures the sparsity of activations.

  Args:
    x: Tensor
  Returns:
    nothing
  """
  tensor_name = x.op.name
  tf.histogram_summary(tensor_name + '/activations', x)
  tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x))

def _variable_on_cpu(name, shape, initializer):
  """Helper to create a Variable stored on CPU memory.

  Args:
    name: name of the variable
    shape: list of ints
    initializer: initializer for Variable

  Returns:
    Variable Tensor
  """
  with tf.device('/cpu:0'):
    var = tf.get_variable(name, shape, initializer=initializer)
  return var
def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.

  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.

  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
      decay is not added for this Variable.

  Returns:
    Variable Tensor
  """
  var = _variable_on_cpu(name, shape,
                         tf.truncated_normal_initializer(stddev=stddev))
  if wd:
    weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
    weight_decay.set_shape([])
    tf.add_to_collection('losses', weight_decay)
  return var

def conv_layer(inputs, kernel_size, stride, num_features, idx, linear=False):
  with tf.variable_scope('{0}_conv'.format(idx)) as scope:
    input_channels = inputs.get_shape()[3]

    weights = _variable_with_weight_decay('weights', shape=[kernel_size,kernel_size,input_channels,num_features], stddev=0.01, wd=FLAGS.weight_decay)
    biases = _variable_on_cpu('biases', [num_features], tf.constant_initializer(0.01))

    conv = tf.nn.conv2d(inputs, weights, strides=[1, stride, stride, 1], padding='SAME')
    conv_biased = tf.nn.bias_add(conv, biases)
    if linear:
      return conv_biased
    conv_rect = tf.nn.elu(conv_biased, name='{0}_conv'.format(idx))
    return conv_rect

def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, linear=False):
  with tf.variable_scope('{0}_trans_conv'.format(idx)) as scope:
    input_channels = inputs.get_shape()[3]

    weights = _variable_with_weight_decay('weights', shape=[kernel_size,kernel_size,num_features,input_channels], stddev=0.01, wd=FLAGS.weight_decay)
    biases = _variable_on_cpu('biases', [num_features], tf.constant_initializer(0.01))
    batch_size = tf.shape(inputs)[0]
    output_shape = tf.pack([tf.shape(inputs)[0], tf.shape(inputs)[1]*stride, tf.shape(inputs)[2]*stride, num_features])
    conv = tf.nn.conv2d_transpose(inputs, weights, output_shape, strides=[1,stride,stride,1], padding='SAME')
    conv_biased = tf.nn.bias_add(conv, biases)
    if linear:
      return conv_biased
    conv_rect = tf.nn.elu(conv_biased, name='{0}_transpose_conv'.format(idx))
    return conv_rect


def fc_layer(inputs, hiddens, idx, flat=False, linear=False):
  with tf.variable_scope('{0}_fc'.format(idx)) as scope:
    input_shape = inputs.get_shape().as_list()
    if flat:
      dim = input_shape[1]*input_shape[2]*input_shape[3]
      inputs_processed = tf.reshape(inputs, [-1, dim])
    else:
      dim = input_shape[1]
      inputs_processed = inputs

    weights = _variable_with_weight_decay('weights', shape=[dim, hiddens], stddev=FLAGS.weight_init, wd=FLAGS.weight_decay)
    biases = _variable_on_cpu('biases', [hiddens], tf.constant_initializer(FLAGS.weight_init))
    if linear:
      return tf.add(tf.matmul(inputs_processed, weights), biases, name=str(idx)+'_fc')

    ip = tf.add(tf.matmul(inputs_processed, weights), biases)
    return tf.nn.elu(ip, name=str(idx)+'_fc')

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------


import os.path
import time

import numpy as np
import tensorflow as tf
import cv2

import bouncing_balls as b
import layer_def as ld

import matplotlib.pyplot as plt


FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '../checkpoints/train_store',
                           """dir to store trained net""")
tf.app.flags.DEFINE_integer('hidden_size', 20,
                            """size of hidden layer""")
tf.app.flags.DEFINE_integer('max_step', 50000,
                            """max num of steps""")
tf.app.flags.DEFINE_float('keep_prob', 1.0,
                          """for dropout""")
tf.app.flags.DEFINE_float('beta', .1,
                          """beta constant for the VAE (KL) loss""")
tf.app.flags.DEFINE_float('lr', .001,
                          """learning rate""")
tf.app.flags.DEFINE_integer('batch_size', 128,
                            """batch size for training""")
tf.app.flags.DEFINE_float('weight_init', .1,
                          """weight init for fully connected layers""")


def train():
  """Train the VAE for a number of steps."""
  with tf.Graph().as_default():
    # make inputs
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3])

    # possible dropout inside
    keep_prob = tf.placeholder("float")

    # create network
    # encoding part first
    # conv1
    conv1 = ld.conv_layer(x, 3, 2, 8, "encode_1")
    # conv2
    conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
    # conv3
    conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3")
    # conv4
    conv4 = ld.conv_layer(conv3, 1, 1, 4, "encode_4")
    # fc5
    fc5 = ld.fc_layer(conv4, 128, "encode_5", True, False)
    # dropout maybe
    fc5_dropout = tf.nn.dropout(fc5, keep_prob)
    # y
    y = ld.fc_layer(fc5_dropout, (FLAGS.hidden_size) * 2, "encode_6", False, True)
    mean, stddev = tf.split(1, 2, y)
    stddev = tf.sqrt(tf.exp(stddev))
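    # Reparameterization trick: rather than sampling the latent code directly
    # from N(mean, stddev^2) (which would not be differentiable), sample
    # epsilon ~ N(0, 1) and compute y_sampled = mean + stddev * epsilon so
    # gradients can flow back into the encoder through mean and stddev.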
    # now decoding part
    # sample distribution
    epsilon = tf.random_normal(mean.get_shape())
    y_sampled = mean + epsilon * stddev
    # fc7
    fc7 = ld.fc_layer(y_sampled, 128, "decode_7", False, False)
    # fc8
    fc8 = ld.fc_layer(fc7, 4*8*8, "decode_8", False, False)
    conv9 = tf.reshape(fc8, [-1, 8, 8, 4])
    # conv10
    conv10 = ld.transpose_conv_layer(conv9, 1, 1, 8, "decode_9")
    # conv11
    conv11 = ld.transpose_conv_layer(conv10, 3, 2, 8, "decode_10")
    # conv12
    conv12 = ld.transpose_conv_layer(conv11, 3, 1, 8, "decode_11")
    # conv13
    conv13 = ld.transpose_conv_layer(conv12, 3, 2, 3, "decode_12", True)
    # x_prime
    x_prime = conv13
    x_prime = tf.nn.sigmoid(x_prime)

    # now calc loss
    epsilon = 1e-8
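    # Closed-form KL divergence between the approximate posterior
    # N(mean, stddev^2) and the unit Gaussian prior N(0, 1):
    #   KL = 0.5 * (mean^2 + stddev^2 - log(stddev^2) - 1)
    #      = 0.5 * (mean^2 + stddev^2 - 2*log(stddev) - 1)
    # summed over the latent dimensions and the batch.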
    # calc loss from vae
    kl_loss = 0.5 * (tf.square(mean) + tf.square(stddev) -
                     2.0 * tf.log(stddev + epsilon) - 1.0)
    loss_vae = FLAGS.beta * tf.reduce_sum(kl_loss)
    # log loss for reconstruction
    loss_reconstruction = tf.reduce_sum(-x * tf.log(x_prime + epsilon) -
                                        (1.0 - x) * tf.log(1.0 - x_prime + epsilon))
    # save for tensorboard
    tf.scalar_summary('loss_vae', loss_vae)
    tf.scalar_summary('loss_reconstruction', loss_reconstruction)
    # calc total loss
    loss = tf.reduce_sum(loss_vae + loss_reconstruction)

    # training
    train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss)

    # List of all Variables
    variables = tf.all_variables()

    # Build a saver
    saver = tf.train.Saver(tf.all_variables())

    # Summary op
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session()

    # init if this is the very first time training
    print("init network from scratch")
    sess.run(init)

    # Summary op
    graph_def = sess.graph.as_graph_def(add_shapes=True)
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, graph_def=graph_def)

    for step in xrange(FLAGS.max_step):
      dat = b.bounce_vec(32, FLAGS.num_balls, FLAGS.batch_size)
      t = time.time()
      _, loss_r = sess.run([train_op, loss], feed_dict={x:dat, keep_prob:FLAGS.keep_prob})
      elapsed = time.time() - t
      #print(elapsed)

      if step%500 == 0:
        _, loss_vae_r, loss_reconstruction_r, y_sampled_r, x_prime_r, kl_loss_dis, stddev_r = sess.run([train_op, loss_vae, loss_reconstruction, y_sampled, x_prime, kl_loss, stddev], feed_dict={x:dat, keep_prob:FLAGS.keep_prob})
        summary_str = sess.run(summary_op, feed_dict={x:dat, keep_prob:FLAGS.keep_prob})
        summary_writer.add_summary(summary_str, step)
        print("loss vae value at " + str(loss_vae_r))
        print("loss reconstruction value at " + str(loss_reconstruction_r))
        print("min sampled vector " + str(np.min(y_sampled_r)))
        print("max sampled vector " + str(np.max(y_sampled_r)))
        print("time per batch is " + str(elapsed))
        cv2.imwrite("real_balls.jpg", np.uint8(dat[0, :, :, :]*255))
        cv2.imwrite("generated_balls.jpg", np.uint8(x_prime_r[0, :, :, :]*255))
        kl_loss_dis = np.sort(np.sum(kl_loss_dis, axis=0))
        stddev_r = np.sort(np.sum(stddev_r, axis=0))
        #plt.plot(kl_loss_dis, label="step " + str(step))
        #plt.legend(loc = 'center left')
        #plt.savefig('kl_error_dis.png')
        plt.plot(stddev_r, label="step " + str(step))
        plt.legend(loc='center left')
        plt.savefig('stddev_r.png')

      assert not np.isnan(loss_r), 'Model diverged with loss = NaN'

      if step%1000 == 0:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
        print("saved to " + FLAGS.train_dir)
        print("step " + str(step))

def main(argv=None):  # pylint: disable=unused-argument
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()

if __name__ == '__main__':
  tf.app.run()

--------------------------------------------------------------------------------
/main_all_conv.py:
--------------------------------------------------------------------------------


import os.path
import time

import numpy as np
import tensorflow as tf
import cv2

import bouncing_balls as b
import layer_def as ld

import matplotlib.pyplot as plt


FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '../checkpoints/train_store',
                           """dir to store trained net""")
tf.app.flags.DEFINE_integer('hidden_size', 20,
                            """size of hidden layer""")
tf.app.flags.DEFINE_integer('max_step', 50000,
                            """max num of steps""")
tf.app.flags.DEFINE_float('keep_prob', .5,
                          """for dropout""")
tf.app.flags.DEFINE_float('beta', 0.5,
                          """constant for VAE loss""")
tf.app.flags.DEFINE_float('lr', .001,
                          """learning rate""")
tf.app.flags.DEFINE_integer('batch_size', 128,
                            """batch size for training""")
tf.app.flags.DEFINE_float('weight_init', .1,
                          """weight init for fully connected layers""")


def train():
  """Train the all-convolutional VAE for a number of steps."""
  with tf.Graph().as_default():
    # make inputs
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3])

    # possible dropout inside
    keep_prob = tf.placeholder("float")

    # create network
    # encoding part first
    # conv1
    conv1 = ld.conv_layer(x, 3, 2, 8, "encode_1")
    # conv2
    conv2 = ld.conv_layer(conv1, 3, 1, 8, "encode_2")
    # conv3
    conv3 = ld.conv_layer(conv2, 3, 2, 8, "encode_3")
    # conv4
    conv4 = ld.conv_layer(conv3, 3, 2, 2, "encode_5", True)
    mean, stddev = tf.split(3, 2, conv4)
    stddev = tf.sqrt(tf.exp(stddev))
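    # Unlike main.py, the latent code here stays spatial: conv4 is a 4x4x2
    # feature map, so mean and stddev are each 4x4x1 -- 16 latent variables
    # laid out on a grid rather than a flat vector.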
    # now decoding part
    # sample distribution
    epsilon = tf.random_normal(mean.get_shape())
    y_sampled = mean + epsilon * stddev
    # conv6
    conv6 = ld.transpose_conv_layer(y_sampled, 3, 2, 8, "decode_5")
    # conv7
    conv7 = ld.transpose_conv_layer(conv6, 3, 2, 8, "decode_6")
    # conv8
    conv8 = ld.transpose_conv_layer(conv7, 3, 1, 8, "decode_7")
    # conv9
    conv9 = ld.transpose_conv_layer(conv8, 3, 2, 3, "decode_8", True)
    # x_prime
    x_prime = conv9
    x_prime = tf.nn.sigmoid(x_prime)

    # now calc loss
    epsilon = 1e-8
    # calc loss from vae
    kl_loss = 0.5 * (tf.square(mean) + tf.square(stddev) -
                     2.0 * tf.log(stddev + epsilon) - 1.0)
    loss_vae = FLAGS.beta * tf.reduce_sum(kl_loss)
    # log loss for reconstruction
    loss_reconstruction = tf.reduce_sum(-x * tf.log(x_prime + epsilon) -
                                        (1.0 - x) * tf.log(1.0 - x_prime + epsilon))
    # save for tensorboard
    tf.scalar_summary('loss_vae', loss_vae)
    tf.scalar_summary('loss_reconstruction', loss_reconstruction)
    # calc total loss
    loss = tf.reduce_sum(loss_vae + loss_reconstruction)

    # training
    train_op = tf.train.AdamOptimizer(FLAGS.lr).minimize(loss)

    # List of all Variables
    variables = tf.all_variables()

    # Build a saver
    saver = tf.train.Saver(tf.all_variables())

    # Summary op
    summary_op = tf.merge_all_summaries()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session()

    # init if this is the very first time training
    print("init network from scratch")
    sess.run(init)

    # Summary op
    graph_def = sess.graph.as_graph_def(add_shapes=True)
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, graph_def=graph_def)

    for step in xrange(FLAGS.max_step):
      dat = b.bounce_vec(32, FLAGS.num_balls, FLAGS.batch_size)
      t = time.time()
      _, loss_r = sess.run([train_op, loss], feed_dict={x:dat, keep_prob:FLAGS.keep_prob})
      elapsed = time.time() - t

      if step%500 == 0:
        _, loss_vae_r, loss_reconstruction_r, y_sampled_r, x_prime_r, kl_loss_dis, stddev_r = sess.run([train_op, loss_vae, loss_reconstruction, y_sampled, x_prime, kl_loss, stddev], feed_dict={x:dat, keep_prob:FLAGS.keep_prob})
        summary_str = sess.run(summary_op, feed_dict={x:dat, keep_prob:FLAGS.keep_prob})
        summary_writer.add_summary(summary_str, step)
        print("loss vae value at " + str(loss_vae_r))
        print("loss reconstruction value at " + str(loss_reconstruction_r))
        print("min sampled vector " + str(np.min(y_sampled_r)))
        print("max sampled vector " + str(np.max(y_sampled_r)))
        print("time per batch is " + str(elapsed))
        cv2.imwrite("real_balls.jpg", np.uint8(dat[0, :, :, :]*255))
        cv2.imwrite("generated_balls.jpg", np.uint8(x_prime_r[0, :, :, :]*255))
        kl_loss_dis = np.sort(np.sum(kl_loss_dis.reshape(FLAGS.batch_size, 16), axis=0))
        stddev_r = np.sort(np.sum(stddev_r.reshape(FLAGS.batch_size, 16), axis=0))
        #plt.plot(kl_loss_dis, label="step " + str(step))
        #plt.legend(loc = 'center left')
        #plt.savefig('kl_error_dis.png')
        plt.plot(stddev_r, label="step " + str(step))
        plt.legend(loc='center left')
        plt.savefig('stddev_r.png')

      assert not np.isnan(loss_r), 'Model diverged with loss = NaN'

      if step%1000 == 0:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
        print("saved to " + FLAGS.train_dir)
        print("step " + str(step))

def main(argv=None):  # pylint: disable=unused-argument
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  train()

if __name__ == '__main__':
  tf.app.run()

--------------------------------------------------------------------------------
/real_balls.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/loliverhennigh/Variational-autoencoder-tricks-and-tips/69506121e8632e451bb3bfeb19442eccb9cdaa92/real_balls.jpg
--------------------------------------------------------------------------------