├── .gitignore
├── res.sh
├── README.md
├── mlp_vae.py
├── vae.py
└── autoencoder.py

/.gitignore:
--------------------------------------------------------------------------------
mnist.pkl.gz
temp/
*.pyc
*.swp
*.*~
*.png
*checkpoint*
--------------------------------------------------------------------------------
/res.sh:
--------------------------------------------------------------------------------
#!/bin/bash
for f in images/*.png; do
    convert "$f" -resize 200x200 "$f"
done
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## MNIST Autoencoder

Just an autoencoder example, specifically for the MNIST dataset.
It compresses each image from 784 dimensions down to 8.

## Usage
```bash
python autoencoder.py
```

This will automatically download the MNIST dataset. The batch size is set to 5000, so you may want to change it depending on your system.
A checkpoint is included in the `checkpoint` folder and is loaded automatically when you run the script. To train your own model, delete the checkpoint.

### Results

It didn't do too badly considering the images are compressed to only 8 dimensions. Below are some examples.

![img](http://i.imgur.com/Qa6HfhT.png) ![img](http://i.imgur.com/EGekJBm.png)

![img](http://i.imgur.com/HGo4Rso.png) ![img](http://i.imgur.com/WKnig11.png)

![img](http://i.imgur.com/GTM05PF.png) ![img](http://i.imgur.com/D0WpaNy.png)

![img](http://i.imgur.com/uBm7gGD.png) ![img](http://i.imgur.com/ba3vFsd.png)

--------------------------------------------------------------------------------
/mlp_vae.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import cPickle as pickle
import tensorflow as tf
import tensorflow.contrib.layers as tcl
import numpy as np
import requests
import random
import gzip
import os

batch_size = 128

'''
Leaky RELU
'''
def lrelu(x, leak=0.2, name="lrelu"):
    return tf.maximum(x, leak*x)


def encoder(x):

    # flatten the (784, 1) images so they can feed the dense layers
    x_flat = tf.reshape(x, [batch_size, 28*28])

    mean = tf.layers.dense(x_flat, 32, name='mean')

    stddev = tf.layers.dense(x_flat, 32, name='stddev')

    return mean, stddev

def decoder(z):
    print
    print 'z: ', z

    d_fc1 = tf.layers.dense(z, 7*7*32, name='d_fc1')
    d_fc1 = lrelu(d_fc1)
    print 'd_fc1: ', d_fc1
    # reshape to a 4-D tensor so it can feed the transposed convolutions
    d_fc1 = tf.reshape(d_fc1, [batch_size, 7, 7, 32])

    e_transpose_conv1 = tf.layers.conv2d_transpose(d_fc1, 16, 5, strides=2, name='e_transpose_conv1')
    e_transpose_conv1 = lrelu(e_transpose_conv1)
    print 'e_transpose_conv1: ', e_transpose_conv1

    e_transpose_conv2 = tf.layers.conv2d_transpose(e_transpose_conv1, 1, 5, strides=2, name='e_transpose_conv2')
    e_transpose_conv2 = tf.nn.sigmoid(e_transpose_conv2)
    e_transpose_conv2 = e_transpose_conv2[:,:28,:28,:]
    print 'e_transpose_conv2: ', e_transpose_conv2

    return e_transpose_conv2


def train(mnist_train, mnist_test):
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # placeholder for mnist images, kept flat as (784, 1)
        images = tf.placeholder(tf.float32, [batch_size, 28*28, 1])

        # encode images to a 32 dim mean and stddev
        z_mean, z_stddev = encoder(images)

        samples = tf.random_normal([batch_size, 32], 0, 1, dtype=tf.float32)
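        # reparameterization trick: sample unit Gaussian noise and scale/shift it
        # with the predicted stddev and mean so gradients can flow back through
        # z_mean and z_stddev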
        z_pred = z_mean + (z_stddev * samples)

        decoded = decoder(z_pred)

        #reconstructed_loss = -tf.reduce_sum(images*tf.log(1e-10 + decoded) + (1-images)*tf.log(1e-10+1-decoded), 1)
        # flatten the decoded images back to (784, 1) so the shapes match the placeholder
        reconstructed_loss = tf.nn.l2_loss(images - tf.reshape(decoded, [batch_size, 28*28, 1]))
        latent_loss = 0.5*tf.reduce_sum(tf.square(z_mean) + tf.square(z_stddev) - tf.log(tf.square(z_stddev)) - 1, 1)

        cost = tf.reduce_mean(reconstructed_loss + latent_loss)

        train_op = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(cost)

        # saver for the model
        saver = tf.train.Saver(tf.all_variables())

        init = tf.initialize_all_variables()
        sess = tf.Session()
        sess.run(init)

        try: os.mkdir('images/')
        except: pass
        try: os.mkdir('checkpoint/')
        except: pass

        ckpt = tf.train.get_checkpoint_state('checkpoint/')
        if ckpt and ckpt.model_checkpoint_path:
            try:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print 'Model restored'
            except:
                print 'Could not restore model'
                pass

        step = 0
        while True:
            step += 1

            # get random images from the training set
            batch_images = random.sample(mnist_train, batch_size)

            # send through the network
            _, loss_ = sess.run([train_op, cost], feed_dict={images: batch_images})
            print 'Step: ' + str(step) + ' Loss: ' + str(loss_)

            if step%1000 == 0:
                print
                print 'Saving model'
                print
                saver.save(sess, "checkpoint/checkpoint", global_step=global_step)

                # get random images from the test set
                batch_images = random.sample(mnist_test, batch_size)

                # encode them using the encoder, then decode them
                encode_decode = sess.run(decoded, feed_dict={images: batch_images})

                # write out a few
                c = 0
                for real, dec in zip(batch_images, encode_decode):
                    # the inputs are flat (784, 1) vectors, so reshape them before saving
                    real = np.reshape(real, (28, 28))
                    dec = np.squeeze(dec)
                    plt.imsave('images/'+str(step)+'_'+str(c)+'real.png', real)
                    plt.imsave('images/'+str(step)+'_'+str(c)+'dec.png', dec)
                    if c == 5:
                        break
                    c += 1

def main(argv=None):
    # mnist data in gz format
    url = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'

    # check if it's already downloaded
    if not os.path.isfile('mnist.pkl.gz'):
        print 'Downloading mnist...'
        with open('mnist.pkl.gz', 'wb') as f:
            r = requests.get(url)
            if r.status_code == 200:
                f.write(r.content)
            else:
                print 'Could not connect to ', url

    print 'opening mnist'
    f = gzip.open('mnist.pkl.gz', 'rb')
    train_set, val_set, test_set = pickle.load(f)

    mnist_train = []
    mnist_test = []

    print 'Reading mnist...'
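    # each split is an (images, labels) tuple; the labels are ignored here since
    # the autoencoder only has to reconstruct its input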
    # keep each image as a flat (784, 1) column vector for the dense encoder
    for t, l in zip(*train_set):
        mnist_train.append(np.reshape(t, (28*28, 1)))
    for t, l in zip(*val_set):
        mnist_train.append(np.reshape(t, (28*28, 1)))
    for t, l in zip(*test_set):
        mnist_test.append(np.reshape(t, (28*28, 1)))

    mnist_train = np.asarray(mnist_train)
    mnist_test = np.asarray(mnist_test)

    train(mnist_train, mnist_test)

if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/vae.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import cPickle as pickle
import tensorflow as tf
import numpy as np
import requests
import random
import gzip
import os

batch_size = 256

'''
Kullback Leibler divergence
https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
https://github.com/fastforwardlabs/vae-tf/blob/master/vae.py#L178
'''
def kullbackleibler(mu, log_sigma):
    # closed form KL divergence between N(mu, sigma^2) and N(0, 1):
    #   KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # with log_sigma = log(sigma), so log(sigma^2) = 2*log_sigma and sigma^2 = exp(2*log_sigma)
    return -0.5*tf.reduce_sum(1 + 2*log_sigma - mu**2 - tf.exp(2*log_sigma), 1)


'''
Leaky RELU
'''
def lrelu(x, leak=0.2, name="lrelu"):
    return tf.maximum(x, leak*x)


def encoder(x):

    e_conv1 = tf.layers.conv2d(x, 16, 5, strides=2, name='e_conv1')
    e_conv1 = lrelu(e_conv1)
    print 'conv1: ', e_conv1

    e_conv2 = tf.layers.conv2d(e_conv1, 32, 5, strides=2, name='e_conv2')
    e_conv2 = lrelu(e_conv2)
    print 'conv2: ', e_conv2

    e_conv2_flat = tf.reshape(e_conv2, [batch_size, -1])

    '''
    z is distributed as a multivariate normal with mean z_mean and diagonal
    covariance values sigma^2

    z ~ N(z_mean, np.exp(z_log_sigma)**2)
    '''
    z_mean = tf.layers.dense(e_conv2_flat, 32, name='mean')
    z_log_sigma = tf.layers.dense(e_conv2_flat, 32, name='stddev')

    return z_mean, z_log_sigma

def decoder(z):
    print
    print 'z: ', z

    d_fc1 = tf.layers.dense(z, 7*7*32, name='d_fc1')
    d_fc1 = lrelu(d_fc1)
    print 'd_fc1: ', d_fc1
    d_fc1 = tf.reshape(d_fc1, [batch_size, 7, 7, 32])

    e_transpose_conv1 = tf.layers.conv2d_transpose(d_fc1, 16, 5, strides=2, name='e_transpose_conv1')
    e_transpose_conv1 = lrelu(e_transpose_conv1)
    print 'e_transpose_conv1: ', e_transpose_conv1

    e_transpose_conv2 = tf.layers.conv2d_transpose(e_transpose_conv1, 1, 5, strides=2, name='e_transpose_conv2')
    e_transpose_conv2 = tf.nn.sigmoid(e_transpose_conv2)
    e_transpose_conv2 = e_transpose_conv2[:,:28,:28,:]
    print 'e_transpose_conv2: ', e_transpose_conv2

    return e_transpose_conv2


def train(mnist_train, mnist_test):
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # placeholder for mnist images
        images = tf.placeholder(tf.float32, [batch_size, 28, 28, 1])

        # encode images to a 32 dim mean and log sigma
        z_mean, z_log_sigma = encoder(images)

        # reparameterization trick
        epsilon = tf.random_normal(tf.shape(z_log_sigma), name='epsilon')
        z = z_mean + epsilon * tf.exp(z_log_sigma)  # N(mu, sigma**2)

        decoded = decoder(z)

        #reconstructed_loss = -tf.reduce_sum(images*tf.log(1e-10 + decoded) + (1-images)*tf.log(1e-10+1-decoded), 1)
        reconstructed_loss = tf.nn.l2_loss(images - decoded)
        latent_loss = kullbackleibler(z_mean, z_log_sigma)
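        # the cost is the usual VAE objective: reconstruction error plus the KL
        # term that keeps the approximate posterior close to the N(0, I) prior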
        cost = tf.reduce_mean(reconstructed_loss + latent_loss)

        train_op = tf.train.AdamOptimizer(learning_rate=1e-3).minimize(cost)

        # saver for the model
        saver = tf.train.Saver(tf.all_variables())

        init = tf.initialize_all_variables()
        sess = tf.Session()
        sess.run(init)

        try: os.makedirs('checkpoint/images/')
        except: pass

        ckpt = tf.train.get_checkpoint_state('checkpoint/')
        if ckpt and ckpt.model_checkpoint_path:
            try:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print 'Model restored'
            except:
                print 'Could not restore model'
                pass

        step = 0
        while True:
            step += 1

            # get random images from the training set
            batch_images = random.sample(mnist_train, batch_size)

            # send through the network
            _, loss_ = sess.run([train_op, cost], feed_dict={images: batch_images})
            print 'Step: ' + str(step) + ' Loss: ' + str(loss_)

            if step%500 == 0:
                print
                print 'Saving model'
                print
                saver.save(sess, "checkpoint/checkpoint", global_step=global_step)

                # get random images from the test set
                batch_images = random.sample(mnist_test, batch_size)

                # encode them using the encoder, then decode them
                encode_decode = sess.run(decoded, feed_dict={images: batch_images})

                # write out a few
                c = 0
                for real, dec in zip(batch_images, encode_decode):
                    dec, real = np.squeeze(dec), np.squeeze(real)
                    plt.imsave('checkpoint/images/'+str(step)+'_'+str(c)+'real.png', real)
                    plt.imsave('checkpoint/images/'+str(step)+'_'+str(c)+'dec.png', dec)
                    if c == 5:
                        break
                    c += 1

def main(argv=None):
    # mnist data in gz format
    url = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'

    # check if it's already downloaded
    if not os.path.isfile('mnist.pkl.gz'):
        print 'Downloading mnist...'
        with open('mnist.pkl.gz', 'wb') as f:
            r = requests.get(url)
            if r.status_code == 200:
                f.write(r.content)
            else:
                print 'Could not connect to ', url

    print 'opening mnist'
    f = gzip.open('mnist.pkl.gz', 'rb')
    train_set, val_set, test_set = pickle.load(f)

    mnist_train = []
    mnist_test = []

    print 'Reading mnist...'
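    # the validation split is folded into the training set; only the test set is
    # held out for the reconstruction samples written during training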
    # reshape mnist to make it easier for understanding convs
    for t, l in zip(*train_set):
        mnist_train.append(np.reshape(t, (28, 28, 1)))
    for t, l in zip(*val_set):
        mnist_train.append(np.reshape(t, (28, 28, 1)))
    for t, l in zip(*test_set):
        mnist_test.append(np.reshape(t, (28, 28, 1)))

    mnist_train = np.asarray(mnist_train)
    mnist_test = np.asarray(mnist_test)

    train(mnist_train, mnist_test)

if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/autoencoder.py:
--------------------------------------------------------------------------------
import tensorflow.contrib.slim as slim
import matplotlib.pyplot as plt
import cPickle as pickle
import tensorflow as tf
import numpy as np
import requests
import random
import time
import gzip
import os

batch_size = 5000

'''
Leaky RELU
'''
def lrelu(x, leak=0.2, name="lrelu"):
    return tf.maximum(x, leak*x)


def encoder(x):

    # convolutional layers with batch norm and a leaky relu activation
    e_conv1 = slim.convolution(x, 32, 2, stride=2, activation_fn=tf.identity, normalizer_fn=slim.batch_norm, scope='e_conv1')
    e_conv1 = lrelu(e_conv1)
    print 'conv1: ', e_conv1

    e_conv2 = slim.convolution(e_conv1, 64, 2, stride=2, activation_fn=tf.identity, normalizer_fn=slim.batch_norm, scope='e_conv2')
    e_conv2 = lrelu(e_conv2)
    print 'conv2: ', e_conv2

    e_conv3 = slim.convolution(e_conv2, 128, 2, stride=2, activation_fn=tf.identity, normalizer_fn=slim.batch_norm, scope='e_conv3')
    e_conv3 = lrelu(e_conv3)
    print 'conv3: ', e_conv3

    e_conv3_flat = tf.reshape(e_conv3, [batch_size, -1])

    # fully connected layers squeeze the representation down to 8 dimensions
    e_fc1 = slim.fully_connected(e_conv3_flat, 256, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='e_fc1')
    e_fc1 = lrelu(e_fc1)
    print 'fc1: ', e_fc1

    e_fc2 = slim.fully_connected(e_fc1, 64, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='e_fc2')
    e_fc2 = lrelu(e_fc2)
    print 'fc2: ', e_fc2

    e_fc3 = slim.fully_connected(e_fc2, 32, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='e_fc3')
    e_fc3 = lrelu(e_fc3)
    print 'fc3: ', e_fc3

    e_fc4 = slim.fully_connected(e_fc3, 8, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='e_fc4')
    e_fc4 = lrelu(e_fc4)
    print 'fc4: ', e_fc4
    return e_fc4

def decoder(x):
    print
    print 'x: ', x

    # mirror the encoder: fully connected layers expand the 8 dim code back out
    d_fc1 = slim.fully_connected(x, 32, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='d_fc1')
    d_fc1 = lrelu(d_fc1)
    print 'd_fc1: ', d_fc1

    d_fc2 = slim.fully_connected(d_fc1, 64, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='d_fc2')
    d_fc2 = lrelu(d_fc2)
    print 'd_fc2: ', d_fc2

    d_fc3 = slim.fully_connected(d_fc2, 256, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='d_fc3')
    d_fc3 = lrelu(d_fc3)
    print 'd_fc3: ', d_fc3

    d_fc3 = tf.reshape(d_fc3, [batch_size, 4, 4, 16])
    print 'd_fc3: ', d_fc3
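    # upsample with stride-2 transposed convolutions: 4x4 -> 8x8 -> 16x16 -> 32x32,
    # then crop back down to 28x28 at the end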
    e_transpose_conv1 = slim.convolution2d_transpose(d_fc3, 64, 2, stride=2, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='e_transpose_conv1')
    e_transpose_conv1 = lrelu(e_transpose_conv1)
    print 'e_transpose_conv1: ', e_transpose_conv1

    e_transpose_conv2 = slim.convolution2d_transpose(e_transpose_conv1, 32, 2, stride=2, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='e_transpose_conv2')
    e_transpose_conv2 = lrelu(e_transpose_conv2)
    print 'e_transpose_conv2: ', e_transpose_conv2

    e_transpose_conv3 = slim.convolution2d_transpose(e_transpose_conv2, 1, 2, stride=2, normalizer_fn=slim.batch_norm, activation_fn=tf.identity, scope='e_transpose_conv3')
    e_transpose_conv3 = lrelu(e_transpose_conv3)
    e_transpose_conv3 = e_transpose_conv3[:,:28,:28,:]
    print 'e_transpose_conv3: ', e_transpose_conv3
    return e_transpose_conv3


def train(mnist_train, mnist_test):
    with tf.Graph().as_default():
        global_step = tf.Variable(0, trainable=False, name='global_step')

        # placeholder for mnist images
        images = tf.placeholder(tf.float32, [batch_size, 28, 28, 1])

        # encode images to an 8 dim vector
        encoded = encoder(images)

        # decode the 8 dim vector back to a (28, 28) image
        decoded = decoder(encoded)

        loss = tf.nn.l2_loss(images - decoded)

        train_op = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss)

        # saver for the model
        saver = tf.train.Saver(tf.all_variables())

        init = tf.initialize_all_variables()
        sess = tf.Session()
        sess.run(init)

        try:
            os.mkdir('images/')
        except:
            pass
        try:
            os.mkdir('checkpoint/')
        except:
            pass

        ckpt = tf.train.get_checkpoint_state('checkpoint/')
        if ckpt and ckpt.model_checkpoint_path:
            try:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print 'Model restored'
            except:
                print 'Could not restore model'
                pass

        step = 0
        while True:
            step += 1

            # get random images from the training set
            batch_images = random.sample(mnist_train, batch_size)

            # send through the network and time the step
            s = time.time()
            _, loss_ = sess.run([train_op, loss], feed_dict={images: batch_images})
            t = time.time() - s
            print 'Step: ' + str(step) + ' Loss: ' + str(loss_) + ' time: ' + str(t)

            if step%100 == 0:
                print
                print 'Saving model'
                print
                saver.save(sess, "checkpoint/checkpoint", global_step=global_step)

                # get random images from the test set
                batch_images = random.sample(mnist_test, batch_size)

                # encode them using the encoder, then decode them
                encode_decode = sess.run(decoded, feed_dict={images: batch_images})

                # write out a few
                c = 0
                for real, dec in zip(batch_images, encode_decode):
                    dec, real = np.squeeze(dec), np.squeeze(real)
                    plt.imsave('images/'+str(step)+'_'+str(c)+'real.png', real)
                    plt.imsave('images/'+str(step)+'_'+str(c)+'dec.png', dec)
                    if c == 5:
                        break
                    c += 1

def main(argv=None):
    # mnist data in gz format
    url = 'http://deeplearning.net/data/mnist/mnist.pkl.gz'

    # check if it's already downloaded
    if not os.path.isfile('mnist.pkl.gz'):
        print 'Downloading mnist...'
        with open('mnist.pkl.gz', 'wb') as f:
            r = requests.get(url)
            if r.status_code == 200:
                f.write(r.content)
            else:
                print 'Could not connect to ', url

    print 'opening mnist'
    f = gzip.open('mnist.pkl.gz', 'rb')
    train_set, val_set, test_set = pickle.load(f)

    mnist_train = []
    mnist_test = []

    print 'Reading mnist...'
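    # mnist.pkl.gz holds three splits: 50k train, 10k validation, and 10k test images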
    # reshape mnist to make it easier for understanding convs
    for t, l in zip(*train_set):
        mnist_train.append(np.reshape(t, (28, 28, 1)))
    for t, l in zip(*val_set):
        mnist_train.append(np.reshape(t, (28, 28, 1)))
    for t, l in zip(*test_set):
        mnist_test.append(np.reshape(t, (28, 28, 1)))

    mnist_train = np.asarray(mnist_train)
    mnist_test = np.asarray(mnist_test)

    train(mnist_train, mnist_test)

if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------