# /MNIST.py and /MNISTTester.py
import os

import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from PIL import Image, ImageFilter
from random import randint
from tensorflow.examples.tutorials.mnist import input_data

from TFUtils import TFUtils


class MNIST:
    """Base class that owns the placeholders, the session and the CNN graph.

    Its main purpose is building the CNN model; other models can be added
    as further ``build_*`` methods.
    """

    model_path = None
    data_path = None

    sess = None
    model = None
    mnist = None

    # (model tensor, accuracy op) pair built lazily by check_accuracy().
    _accuracy_cache = None

    # Input images as NHWC batches of 28x28 single-channel pictures and
    # their one-hot labels over 10 classes.
    X = tf.placeholder(tf.float32, [None, 28, 28, 1])
    Y = tf.placeholder(tf.float32, [None, 10])

    p_keep_conv = tf.placeholder(tf.float32)
    p_keep_hidden = tf.placeholder(tf.float32)

    def __init__(self, model_path=None, data_path=None):
        """Remember where the checkpoint and the MNIST data live."""
        self.model_path = model_path
        self.data_path = data_path

    def init_session(self):
        """Create the session and run the global variable initializer."""
        init = tf.global_variables_initializer()
        self.sess = tf.Session()
        self.sess.run(init)

    def print_status(self, text):
        """Print ``text`` preceded by a ``---`` separator line."""
        # Single-argument call form prints the same on Python 2 and 3.
        print('---')
        print(text)

    def load_training_data(self, data_path):
        """Read the MNIST data sets from ``data_path`` with one-hot labels."""
        print('Preparing MNIST data..')

        self.mnist = input_data.read_data_sets(data_path, one_hot=True)

    def build_feed_dict(self, X, Y, p_keep_conv=1., p_keep_hidden=1.):
        """Return a feed dict binding all four class-level placeholders."""
        return {
            self.X: X,
            self.Y: Y,
            self.p_keep_conv: p_keep_conv,
            self.p_keep_hidden: p_keep_hidden
        }

    def build_cnn_model(self, p_keep_conv=1., p_keep_hidden=1.):
        """Build the 3-conv + 2-FC network and store its logits in ``self.model``.

        NOTE(review): the keep-probabilities given here are passed straight
        into the dropout ops; the ``self.p_keep_conv`` / ``self.p_keep_hidden``
        placeholders are not consumed by any op built in this method, so
        values fed through ``build_feed_dict`` do not change the dropout rate.

        Returns:
            The logits tensor (also stored as ``self.model``).
        """
        W1 = TFUtils.xavier_init([3, 3, 1, 32], 'W1')
        W2 = TFUtils.xavier_init([3, 3, 32, 64], 'W2')
        W3 = TFUtils.xavier_init([3, 3, 64, 128], 'W3')
        W4 = TFUtils.xavier_init([128 * 4 * 4, 625], 'W4')
        W5 = TFUtils.xavier_init([625, 10], 'W5')

        with tf.name_scope('layer1'):
            # L1 Conv shape=(?, 28, 28, 32)
            #    Pool     ->(?, 14, 14, 32)
            L1 = TFUtils.build_cnn_layer(self.X, W1, p_keep_conv)
        with tf.name_scope('layer2'):
            # L2 Conv shape=(?, 14, 14, 64)
            #    Pool     ->(?, 7, 7, 64)
            L2 = TFUtils.build_cnn_layer(L1, W2, p_keep_conv)
        with tf.name_scope('layer3'):
            # L3 Conv shape=(?, 7, 7, 128)
            #    Pool     ->(?, 4, 4, 128)
            #    Reshape  ->(?, 128 * 4 * 4), i.e. the row count of W4
            reshape = [-1, W4.get_shape().as_list()[0]]
            L3 = TFUtils.build_cnn_layer(L2, W3, p_keep_conv, reshape=reshape)
        with tf.name_scope('layer4'):
            # L4 FC 4x4x128 inputs -> 625 outputs
            L4 = tf.nn.relu(tf.matmul(L3, W4))
            L4 = tf.nn.dropout(L4, p_keep_hidden)

        # Output(labels) FC 625 inputs -> 10 outputs
        self.model = tf.matmul(L4, W5, name='model')

        return self.model

    def save_model(self):
        """Save all global variables to ``model_path`` (no-op when it is None)."""
        if self.model_path is not None:
            self.print_status('Saving my model..')

            saver = tf.train.Saver(tf.global_variables())
            saver.save(self.sess, self.model_path)

    def load_model(self):
        """Build the CNN with default keep-probabilities and restore its weights."""
        self.build_cnn_model()

        saver = tf.train.Saver()
        saver.restore(self.sess, self.model_path)

    def check_accuracy(self, test_feed_dict=None):
        """Return the fraction of rows whose arg-max prediction matches ``Y``.

        The accuracy op is built once per model tensor and reused, so calling
        this every epoch no longer adds new nodes to the graph each time.
        """
        cached = self._accuracy_cache
        if cached is None or cached[0] is not self.model:
            hits = tf.equal(tf.argmax(self.model, 1), tf.argmax(self.Y, 1))
            accuracy = tf.reduce_mean(tf.cast(hits, tf.float32))
            cached = self._accuracy_cache = (self.model, accuracy)

        return self.sess.run(cached[1], feed_dict=test_feed_dict)


class MNISTTester(MNIST):
    """Load a trained model and run predictions.

    Can check accuracy on the test set, predict a random test-set item and
    predict the digit in an image file.
    """

    # (model tensor, [argmax op, softmax op]) built lazily by classify().
    _classify_cache = None

    def __init__(self, model_path=None, data_path=None):
        MNIST.__init__(self, model_path, data_path)

        self.init()

    def init(self):
        """Create the session, restore the model and optionally load MNIST."""
        self.print_status('Loading a model..')

        self.init_session()

        self.load_model()

        if self.data_path is not None:
            self.load_training_data(self.data_path)

    def classify(self, feed_dict):
        """Return ``(digit, probability)`` for the first row of the fed batch.

        ``probability`` is the softmax output at the predicted class. Both ops
        are created once and fetched in a single ``run`` call.
        """
        cached = self._classify_cache
        if cached is None or cached[0] is not self.model:
            ops = [tf.argmax(self.model, 1), tf.nn.softmax(self.model)]
            cached = self._classify_cache = (self.model, ops)

        numbers, probabilities = self.sess.run(cached[1], feed_dict)
        number = numbers[0]

        return number, probabilities[0][number]

    def accuracy_of_testset(self):
        """Print the model's accuracy over the whole MNIST test split."""
        self.print_status('Calculating accuracy of test set..')

        X = self.mnist.test.images.reshape(-1, 28, 28, 1)
        Y = self.mnist.test.labels
        test_feed_dict = self.build_feed_dict(X, Y)

        accuracy = self.check_accuracy(test_feed_dict)

        self.print_status('CNN accuracy of test set: %f' % accuracy)

    def predict_random(self, show_image=False):
        """Classify one randomly chosen test image and print the outcome."""
        # randint() includes both end points, so the last valid index is n - 1.
        num = randint(0, self.mnist.test.images.shape[0] - 1)
        image = self.mnist.test.images[num]
        label = self.mnist.test.labels[num]

        feed_dict = self.build_feed_dict(image.reshape(-1, 28, 28, 1), [label])

        (number, accuracy) = self.classify(feed_dict)
        # Decode the one-hot label with NumPy instead of adding a graph op.
        label = np.argmax(label)

        self.print_status('Predict random item: %d is %d, accuracy: %f' %
                          (label, number, accuracy))

        if show_image:
            plt.imshow(image.reshape(28, 28))
            plt.show()

    def predict(self, filename):
        """Classify the image stored at ``filename`` and print the outcome."""
        data = self.load_image(filename)

        number, accuracy = self.classify({self.X: data})

        self.print_status('%d is %s, accuracy: %f' % (number, os.path.basename(filename), accuracy))

    def load_image(self, filename):
        """Load an image file as a nested ``[1][28][28][1]`` list of floats.

        The picture is converted to 8-bit grayscale, resized to 28x28 and
        sharpened. Pixel values are inverted and scaled, so 255 maps to 0.0
        and 0 maps to 1.0.
        """
        img = Image.open(filename).convert('L')

        # resize to 28x28; LANCZOS is the current name of the ANTIALIAS filter
        img = img.resize((28, 28), Image.LANCZOS).filter(ImageFilter.SHARPEN)

        # normalization : 0..255 -> 1..0
        data = [(255 - x) * 1.0 / 255.0 for x in list(img.getdata())]

        # reshape -> [-1, 28, 28, 1]
        return np.reshape(data, (-1, 28, 28, 1)).tolist()
class MNISTTrainer(MNIST):
    """Train the CNN model and write TensorBoard logs.

    Another model can be used after adding it to the MNIST class.
    """

    train_op = None
    summary = None
    writer = None
    test_feed_dict = None

    def __init__(self, data_path=None, model_path=None, log_path=None):
        """Store the paths and load the MNIST data when ``data_path`` is given."""
        MNIST.__init__(self, model_path, data_path)

        self.log_path = log_path

        if data_path is not None:
            self.load_training_data(data_path)

    def init_log(self):
        """Prepare the test feed dict, merged summary and writer.

        Does nothing when no ``log_path`` was configured.
        """
        if self.log_path is not None:
            X = self.mnist.test.images.reshape(-1, 28, 28, 1)
            Y = self.mnist.test.labels
            self.test_feed_dict = self.build_feed_dict(X, Y)

            self.summary = tf.summary.merge_all()
            self.writer = tf.summary.FileWriter(self.log_path, self.sess.graph)

    def add_log(self, name, graph, type='histogram'):
        """Register a scalar (``type='scalar'``) or histogram summary for ``graph``."""
        if self.log_path is not None:
            if type == 'scalar':
                tf.summary.scalar(name, graph)
            else:
                tf.summary.histogram(name, graph)

    def write_log(self, epoch):
        """Evaluate the merged summary on the test feed and record it at ``epoch``."""
        if self.log_path is not None:
            summary = self.sess.run(self.summary, feed_dict=self.test_feed_dict)
            self.writer.add_summary(summary, epoch)

    def print_accuracy(self, epoch):
        """Print the accuracy on the test feed for the 1-based epoch number."""
        # test_feed_dict is only built by init_log() when log_path is set,
        # which is why this report is guarded by the same condition.
        if self.log_path is not None:
            accuracy = self.check_accuracy(self.test_feed_dict)
            # One pre-formatted string prints identically on Python 2 and 3.
            print('Epoch: %04d / Accuracy = %s' % (epoch + 1, accuracy))

    def build_training_op(self, learning_rate, decay):
        """Create the cross-entropy cost, its summaries and the RMSProp train op."""
        with tf.name_scope('cost'):
            # Keyword arguments: the positional form is rejected by
            # TensorFlow >= 1.0.
            cost = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=self.Y,
                                                        logits=self.model))

            self.add_log('Y', self.Y)
            self.add_log('cost', cost, 'scalar')

        self.train_op = tf.train.RMSPropOptimizer(learning_rate, decay).minimize(cost)

    def training_once(self, batch_size, p_keep_conv, p_keep_hidden):
        """Run one pass over the training split in batches of ``batch_size``."""
        # Whole batches only; a trailing partial batch is not requested.
        total_batch = self.mnist.train.num_examples // batch_size

        for step in range(total_batch):
            batch_xs, batch_ys = self.mnist.train.next_batch(batch_size)
            batch_xs = batch_xs.reshape(-1, 28, 28, 1)
            feed_dict = self.build_feed_dict(batch_xs, batch_ys, p_keep_conv, p_keep_hidden)

            self.sess.run(self.train_op, feed_dict=feed_dict)

    # training several times
    def training(self,
                 learning_rate=0.001,
                 decay=0.9,
                 training_epochs=15,
                 batch_size=100,
                 p_keep_conv=1.,
                 p_keep_hidden=1.):
        """Build the model, train for ``training_epochs`` epochs and save it.

        After every epoch a summary is written and the accuracy is printed;
        both steps are skipped when no ``log_path`` was configured.
        """
        self.print_status('Building CNN model..')

        self.build_cnn_model(p_keep_conv, p_keep_hidden)

        self.build_training_op(learning_rate, decay)

        self.print_status('Start training. Please be patient. :-)')

        self.init_session()

        # init summary for tensorboard
        self.init_log()

        # start training
        for epoch in range(training_epochs):
            self.training_once(batch_size, p_keep_conv, p_keep_hidden)

            self.write_log(epoch)

            self.print_accuracy(epoch)

        # TODO: save the best model only
        self.save_model()

        self.print_status('Learning Finished!')
p_keep_conv=0.8 25 | - p_keep_hidden=0.5 26 | - **test.py** 27 | - prediction test with MNIST test set 28 | - prediction test with image file 29 | - only for square images containing a single digit 30 | - image size does not matter 31 | 32 | ## Results 33 | 34 | ``` 35 | ➜ TensorFlow-MNIST# python train.py 36 | Preparing MNIST data.. 37 | Extracting mnist/data/train-images-idx3-ubyte.gz 38 | Extracting mnist/data/train-labels-idx1-ubyte.gz 39 | Extracting mnist/data/t10k-images-idx3-ubyte.gz 40 | Extracting mnist/data/t10k-labels-idx1-ubyte.gz 41 | --- 42 | Building CNN model.. 43 | --- 44 | Start training. Please be patient. :-) 45 | Epoch: 0001 / Accuracy = 0.9511 46 | Epoch: 0002 / Accuracy = 0.9634 47 | ... 48 | --- 49 | Saving my model.. 50 | --- 51 | Learning Finished! 52 | ``` 53 | 54 | ``` 55 | ➜ TensorFlow-MNIST# python test.py 56 | --- 57 | Loading a model.. 58 | Preparing MNIST data.. 59 | Extracting mnist/data/train-images-idx3-ubyte.gz 60 | Extracting mnist/data/train-labels-idx1-ubyte.gz 61 | Extracting mnist/data/t10k-images-idx3-ubyte.gz 62 | Extracting mnist/data/t10k-labels-idx1-ubyte.gz 63 | --- 64 | Calculating accuracy of test set.. 
# /TFUtils.py
import math

import tensorflow as tf


class TFUtils:
    """Small static helpers for building the CNN (utility package for lazy me)."""

    def __init__(self):
        """Nothing to set up; every helper is a static method."""

    # Xavier initialization
    @staticmethod
    def xavier_init(shape, name='', uniform=True):
        """Create a variable called ``name`` with Xavier/Glorot initialization.

        ``shape`` is read as ``[..., in, out]``: the last axis is the number of
        outputs, and any axes before the last two are treated as the
        receptive field (e.g. ``[3, 3, 1, 32]`` for a 3x3 convolution).

        Fan-in is the *product* of all axes but the last, and fan-out is the
        receptive-field size times the last axis. For a 2-D weight matrix
        this is simply ``in`` and ``out``.
        """
        receptive_field = 1
        for dim in shape[:-2]:
            receptive_field *= dim

        num_input = receptive_field * (shape[-2] if len(shape) > 1 else 1)
        num_output = receptive_field * shape[-1]

        # Plain Python math: these are constants, no graph ops are needed.
        if uniform:
            init_range = math.sqrt(6.0 / (num_input + num_output))
            init_value = tf.random_uniform_initializer(-init_range, init_range)
        else:
            # NOTE(review): 3.0 rather than the textbook 2.0 is presumably
            # compensation for the truncation of the normal - confirm.
            stddev = math.sqrt(3.0 / (num_input + num_output))
            init_value = tf.truncated_normal_initializer(stddev=stddev)

        return tf.get_variable(name, shape=shape, initializer=init_value)

    @staticmethod
    def conv2d(X, W, strides=None, padding='SAME'):
        """Apply ``tf.nn.conv2d``; strides default to ``[1, 1, 1, 1]``."""
        if strides is None:
            strides = [1, 1, 1, 1]

        return tf.nn.conv2d(X, W, strides=strides, padding=padding)

    @staticmethod
    def max_pool(X, ksize=None, strides=None, padding='SAME'):
        """Apply ``tf.nn.max_pool``; window and strides default to 2x2."""
        if ksize is None:
            ksize = [1, 2, 2, 1]

        if strides is None:
            strides = [1, 2, 2, 1]

        return tf.nn.max_pool(X, ksize=ksize, strides=strides, padding=padding)

    @staticmethod
    def build_cnn_layer(X, W, p_dropout=1., pool=True, reshape=None):
        """Build conv -> ReLU -> optional pool -> optional reshape -> optional dropout.

        Dropout is skipped only when ``p_dropout`` compares equal to 1;
        otherwise ``p_dropout`` is passed to ``tf.nn.dropout`` as the
        keep-probability.
        """
        L = tf.nn.relu(TFUtils.conv2d(X, W))

        if pool is True:
            L = TFUtils.max_pool(L)

        if reshape is not None:
            L = tf.reshape(L, reshape)

        if p_dropout == 1:
            return L
        else:
            return tf.nn.dropout(L, p_dropout)
-------------------------------------------------------------------------------- /mnist/data/train-images-idx3-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/golbin/TensorFlow-MNIST/575b78720ffc21211b48eb905cb827f40a142354/mnist/data/train-images-idx3-ubyte.gz -------------------------------------------------------------------------------- /mnist/data/train-labels-idx1-ubyte.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/golbin/TensorFlow-MNIST/575b78720ffc21211b48eb905cb827f40a142354/mnist/data/train-labels-idx1-ubyte.gz -------------------------------------------------------------------------------- /models/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "/Users/golbin/Documents/Working/tensorflow/TensorFlow-MNIST/models/mnist-cnn" 2 | all_model_checkpoint_paths: "/Users/golbin/Documents/Working/tensorflow/TensorFlow-MNIST/models/mnist-cnn" 3 | -------------------------------------------------------------------------------- /models/mnist-cnn.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/golbin/TensorFlow-MNIST/575b78720ffc21211b48eb905cb827f40a142354/models/mnist-cnn.data-00000-of-00001 -------------------------------------------------------------------------------- /models/mnist-cnn.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/golbin/TensorFlow-MNIST/575b78720ffc21211b48eb905cb827f40a142354/models/mnist-cnn.index -------------------------------------------------------------------------------- /models/mnist-cnn.meta: -------------------------------------------------------------------------------- 
# /test.py
import os

from MNISTTester import MNISTTester

####################
# directory settings
script_dir = os.path.dirname(os.path.abspath(__file__))

# The trailing '' keeps the trailing path separator the data path always had.
data_path = os.path.join(script_dir, 'mnist', 'data', '')
model_path = os.path.join(script_dir, 'models', 'mnist-cnn')

#####################################
# prediction test with MNIST test set
mnist = MNISTTester(
            model_path=model_path,
            data_path=data_path)

mnist.accuracy_of_testset()
mnist.predict_random()

#################################
# prediction test with image file
for image_name in ('digit-4.png', 'digit-2.png', 'digit-5.png'):
    mnist.predict(os.path.join(script_dir, 'imgs', image_name))


# /train.py
import os

from MNISTTrainer import MNISTTrainer


####################
# directory settings
script_dir = os.path.dirname(os.path.abspath(__file__))

data_path = os.path.join(script_dir, 'mnist', 'data', '')
model_path = os.path.join(script_dir, 'models', 'mnist-cnn')
log_path = os.path.join(script_dir, 'logs', 'mnist-cnn')

##########
# training
mnist = MNISTTrainer(
            data_path=data_path,
            model_path=model_path,
            log_path=log_path)

mnist.training(
    learning_rate=0.001,
    decay=0.9,
    training_epochs=10,
    batch_size=100,
    p_keep_conv=0.8,
    p_keep_hidden=0.5)