├── LICENSE
├── README.md
├── kaggle_mnist_alexnet.py
├── kaggle_mnist_alexnet_model.py
├── kaggle_mnist_input.py
├── ops.py
└── simple_kaggle_mnist_alexnet.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2016 Jireh

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AlexNet using TensorFlow
Implementation of AlexNet in TensorFlow, trained on the Kaggle MNIST data.

## Training
python kaggle_mnist_alexnet.py

## Testing
python kaggle_mnist_alexnet.py --is_train=False
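## Options
Hyperparameters are exposed as tf.app.flags, so they can be overridden on the command line; for example (values here are illustrative, not tuned):

python kaggle_mnist_alexnet.py --batch_size=64 --learning_rate=0.0005 --training_epoch=10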
--------------------------------------------------------------------------------
/kaggle_mnist_alexnet.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import os
import kaggle_mnist_alexnet_model as model
import kaggle_mnist_input as loader
import time
import csv

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_integer('training_epoch', 30, "training epoch")
tf.app.flags.DEFINE_integer('batch_size', 128, "batch size")
tf.app.flags.DEFINE_integer('validation_interval', 100, "validation interval")

tf.app.flags.DEFINE_float('dropout_keep_prob', 0.5, "dropout keep prob")
tf.app.flags.DEFINE_float('learning_rate', 0.001, "learning rate")
tf.app.flags.DEFINE_float('rms_decay', 0.9, "rms optimizer decay")
tf.app.flags.DEFINE_float('weight_decay', 0.0005, "l2 regularization weight decay")
tf.app.flags.DEFINE_string('train_path', '/tmp/train.csv', "path to download training data")
tf.app.flags.DEFINE_string('test_path', '/tmp/test.csv', "path to download test data")
tf.app.flags.DEFINE_integer('validation_size', 2000, "validation size in training data")
tf.app.flags.DEFINE_string('save_name', os.getcwd() + '/var.ckpt', "path to save variables")
tf.app.flags.DEFINE_boolean('is_train', True, "True for train, False for test")
tf.app.flags.DEFINE_string('test_result', 'result.csv', "test result file path")

image_size = 28
image_channel = 1
label_cnt = 10


def train():
    # build graph
    inputs, labels, dropout_keep_prob, learning_rate = model.input_placeholder(image_size, image_channel, label_cnt)
    logits = model.inference(inputs, dropout_keep_prob, label_cnt)

    accuracy = model.accuracy(logits, labels)
    loss = model.loss(logits, labels)
    train = tf.train.RMSPropOptimizer(learning_rate, FLAGS.rms_decay).minimize(loss)

    # session
    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)

    # ready for summary
    merged = tf.merge_all_summaries()
    train_writer = tf.train.SummaryWriter('./summary/train', sess.graph)
    validation_writer = tf.train.SummaryWriter('./summary/validation')

    # tf saver
    saver = tf.train.Saver()
    if os.path.isfile(FLAGS.save_name):
        saver.restore(sess, FLAGS.save_name)

    total_start_time = time.time()

    # load mnist data
    train_images, train_labels, train_range, validation_images, validation_labels, validation_indices = loader.load_mnist_train(
        FLAGS.validation_size, FLAGS.batch_size)

    total_train_len = len(train_images)
    i = 0
    cur_learning_rate = FLAGS.learning_rate
    for epoch in range(FLAGS.training_epoch):
        # step learning rate decay: divide by 10 at epochs 10 and 20
        # (0.001 -> 0.0001 -> 0.00001 over the default 30 epochs)
        if epoch % 10 == 0 and epoch > 0:
            cur_learning_rate /= 10
        epoch_start_time = time.time()
        for start, end in train_range:
            batch_start_time = time.time()
            train_x = train_images[start:end]
            train_y = train_labels[start:end]
            if i % 20 == 0:
                summary, _, loss_result = sess.run([merged, train, loss], feed_dict={inputs: train_x, labels: train_y,
                                                                                     dropout_keep_prob: FLAGS.dropout_keep_prob,
                                                                                     learning_rate: cur_learning_rate})
                train_writer.add_summary(summary, i)
            else:
                _, loss_result = sess.run([train, loss], feed_dict={inputs: train_x, labels: train_y,
                                                                    dropout_keep_prob: FLAGS.dropout_keep_prob,
                                                                    learning_rate: cur_learning_rate})
            print('[%s][training][epoch %d, step %d exec %.2f seconds] [file: %5d ~ %5d / %5d] loss : %3.10f' % (
                time.strftime("%Y-%m-%d %H:%M:%S"), epoch, i, (time.time() - batch_start_time), start, end,
                total_train_len, loss_result))

            if i % FLAGS.validation_interval == 0 and i > 0:
                validation_start_time = time.time()
                shuffle_indices = loader.shuffle_validation(validation_indices, FLAGS.batch_size)
                validation_x = validation_images[shuffle_indices]
                validation_y = validation_labels[shuffle_indices]
                summary, accuracy_result, loss_result = sess.run([merged, accuracy, loss],
                                                                 feed_dict={inputs: validation_x, labels: validation_y,
                                                                            dropout_keep_prob: 1.0})
                validation_writer.add_summary(summary, i)
                print('[%s][validation][epoch %d, step %d exec %.2f seconds] accuracy : %1.3f, loss : %3.10f' % (
                    time.strftime("%Y-%m-%d %H:%M:%S"), epoch, i, (time.time() - validation_start_time),
                    accuracy_result, loss_result))

            i += 1

        print("[%s][epoch exec %s seconds] epoch : %d" % (
            time.strftime("%Y-%m-%d %H:%M:%S"), (time.time() - epoch_start_time), epoch))
        saver.save(sess, FLAGS.save_name)
    print("[%s][total exec %s seconds]" % (time.strftime("%Y-%m-%d %H:%M:%S"), (time.time() - total_start_time)))
    train_writer.close()
    validation_writer.close()
def test():
    # build graph (image_size, image_channel, and label_cnt are module-level
    # constants; they are not defined as flags)
    inputs, labels, dropout_keep_prob, learning_rate = model.input_placeholder(image_size, image_channel, label_cnt)
    logits = model.inference(inputs, dropout_keep_prob, label_cnt)
    predict = tf.argmax(logits, 1)

    # session
    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)

    # tf saver
    saver = tf.train.Saver()
    if os.path.isfile(FLAGS.save_name):
        saver.restore(sess, FLAGS.save_name)

    i = 1

    # load test data
    test_images, test_ranges = loader.load_mnist_test(FLAGS.batch_size)

    # ready for result file
    test_result_file = open(FLAGS.test_result, 'wb')
    csv_writer = csv.writer(test_result_file)
    csv_writer.writerow(['ImageId', 'Label'])

    total_start_time = time.time()

    for file_start, file_end in test_ranges:
        test_x = test_images[file_start:file_end]
        predict_label = sess.run(predict, feed_dict={inputs: test_x, dropout_keep_prob: 1.0})

        for cur_predict in predict_label:
            csv_writer.writerow([i, cur_predict])
            print('[Result %s: %s]' % (i, cur_predict))
            i += 1
    print("[%s][total exec %s seconds]" % (time.strftime("%Y-%m-%d %H:%M:%S"), (time.time() - total_start_time)))


def main(_):
    if FLAGS.is_train:
        train()
    else:
        test()


if __name__ == '__main__':
    tf.app.run()
--------------------------------------------------------------------------------
/kaggle_mnist_alexnet_model.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import ops as op


def input_placeholder(image_size, image_channel, label_cnt):
    with tf.name_scope('inputlayer'):
        inputs = tf.placeholder("float", [None, image_size, image_size, image_channel], 'inputs')
        labels = tf.placeholder("float", [None, label_cnt], 'labels')
        dropout_keep_prob = tf.placeholder("float", None, 'keep_prob')
        learning_rate = tf.placeholder("float", None, name='learning_rate')

    return inputs, labels, dropout_keep_prob, learning_rate


def inference(inputs, dropout_keep_prob, label_cnt):
    # todo: change lrn parameters
    # conv layer 1
    with tf.name_scope('conv1layer'):
        conv1 = op.conv(inputs, 7, 96, 3)
        conv1 = op.lrn(conv1)
        conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 1, 1, 1], padding='VALID')

    # conv layer 2
    with tf.name_scope('conv2layer'):
        conv2 = op.conv(conv1, 5, 256, 1, 1.0)
        conv2 = op.lrn(conv2)
        conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 1, 1, 1], padding='VALID')

    # conv layer 3
    with tf.name_scope('conv3layer'):
        conv3 = op.conv(conv2, 3, 384, 1)

    # conv layer 4
    with tf.name_scope('conv4layer'):
        conv4 = op.conv(conv3, 3, 384, 1, 1.0)

    # conv layer 5
    with tf.name_scope('conv5layer'):
        conv5 = op.conv(conv4, 3, 256, 1, 1.0)
        conv5 = tf.nn.max_pool(conv5, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], padding='VALID')

    # fc layer 1
    with tf.name_scope('fc1layer'):
        fc1 = op.fc(conv5, 4096, 1.0)
        fc1 = tf.nn.dropout(fc1, dropout_keep_prob)

    # fc layer 2
    with tf.name_scope('fc2layer'):
        fc2 = op.fc(fc1, 4096, 1.0)
        fc2 = tf.nn.dropout(fc2, dropout_keep_prob)

    # fc layer 3 - output
    with tf.name_scope('fc3layer'):
        return op.fc(fc2, label_cnt, 1.0, None)


def accuracy(logits, labels):
    # accuracy
    with tf.name_scope('accuracy'):
        accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1)), tf.float32))
        tf.scalar_summary('accuracy', accuracy)
    return accuracy


def loss(logits, labels):
    with tf.name_scope('loss'):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits, labels))
        tf.scalar_summary('loss', loss)
    return loss


def train_rms_prop(loss, learning_rate, decay=0.9, momentum=0.0, epsilon=1e-10, use_locking=False, name='RMSProp'):
    return tf.train.RMSPropOptimizer(learning_rate, decay, momentum, epsilon, use_locking, name).minimize(loss)
--------------------------------------------------------------------------------
/kaggle_mnist_input.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import tensorflow as tf
import os
import sys
from six.moves import urllib

FLAGS = tf.app.flags.FLAGS
TRAIN_DATA_URL = 'http://file.hovits.com/dl/train.csv'
TEST_DATA_URL = 'http://file.hovits.com/dl/teset.csv'


def dense_to_one_hot(labels_dense, num_classes):
    num_labels = labels_dense.shape[0]
    index_offset = np.arange(num_labels) * num_classes
    labels_one_hot = np.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot
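# Worked example (illustrative): dense_to_one_hot(np.array([1, 0, 2]), 3) has
# index_offset = [0, 3, 6] and sets flat indices [1, 3, 8], giving
# [[0, 1, 0],
#  [1, 0, 0],
#  [0, 0, 1]]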
def load_mnist_train(validation_size=2000, batch_size=128):
    download_train()

    data = pd.read_csv(FLAGS.train_path)

    # first column is the label, the rest are the pixel values
    images = data.iloc[:, 1:].values
    images = images.astype(np.float32)

    # rescale [0, 255] pixel values to [0.0, 1.0]
    images = np.multiply(images, 1.0 / 255.0)

    image_size = images.shape[1]

    image_width = image_height = np.ceil(np.sqrt(image_size)).astype(np.uint8)
    images = images.reshape(-1, image_width, image_height, 1)

    labels_flat = data.iloc[:, 0].values.ravel()
    labels_count = np.unique(labels_flat).shape[0]

    labels = dense_to_one_hot(labels_flat, labels_count)
    labels = labels.astype(np.uint8)

    validation_images = images[:validation_size]
    validation_labels = labels[:validation_size]

    train_images = images[validation_size:]
    train_labels = labels[validation_size:]

    # (start, end) index pairs for each batch; list() so the tail can be appended
    train_range = list(zip(range(0, len(train_images), batch_size), range(batch_size, len(train_images), batch_size)))

    # cover the tail when the data size is not a multiple of batch_size,
    # e.g. 10 images with batch_size 4 give [(0, 4), (4, 8), (8, 10)]
    if len(train_images) % batch_size > 0:
        train_range.append((train_range[-1][1], len(train_images)))

    validation_indices = np.arange(len(validation_images))

    return train_images, train_labels, train_range, validation_images, validation_labels, validation_indices
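# test() in kaggle_mnist_alexnet.py calls loader.load_mnist_test(), which is
# missing from this file; the sketch below is an assumed implementation that
# mirrors load_mnist_train() for the label-free test csv (untested).
def load_mnist_test(batch_size=128):
    download_test()

    data = pd.read_csv(FLAGS.test_path)

    # every column is a pixel value; rescale to [0.0, 1.0]
    images = data.values.astype(np.float32)
    images = np.multiply(images, 1.0 / 255.0)

    image_width = image_height = np.ceil(np.sqrt(images.shape[1])).astype(np.uint8)
    images = images.reshape(-1, image_width, image_height, 1)

    # same (start, end) batching scheme as load_mnist_train, including the tail
    test_range = list(zip(range(0, len(images), batch_size), range(batch_size, len(images), batch_size)))
    if len(images) % batch_size > 0:
        test_range.append((test_range[-1][1], len(images)))

    return images, test_range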
def shuffle_validation(validation_indices, batch_size):
    np.random.shuffle(validation_indices)
    return validation_indices[0:batch_size]


def download_train():
    statinfo = download(FLAGS.train_path, TRAIN_DATA_URL)
    if statinfo:
        print('Training data is successfully downloaded', statinfo.st_size, 'bytes.')
    else:
        print('Training data was already downloaded')


def download_test():
    statinfo = download(FLAGS.test_path, TEST_DATA_URL)
    if statinfo:
        print('Test data is successfully downloaded', statinfo.st_size, 'bytes.')
    else:
        print('Test data was already downloaded')


def download(path, url):
    if not os.path.exists(path):
        # create the parent directory if needed (dirname, not basename)
        if not os.path.isdir(os.path.dirname(path)):
            os.makedirs(os.path.dirname(path))

        def _progress(count, block_size, total_size):
            sys.stdout.write(
                '\r>> Downloading %s %.1f%%' % (path, float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()

        file_path, _ = urllib.request.urlretrieve(url, path, _progress)
        print()
        return os.stat(file_path)
--------------------------------------------------------------------------------
/ops.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def conv(inputs, kernel_size, output_num, stride_size=1, init_bias=0.0, conv_padding='SAME', stddev=0.01,
         activation_func=tf.nn.relu):
    input_size = inputs.get_shape().as_list()[-1]
    conv_weights = tf.Variable(
        tf.random_normal([kernel_size, kernel_size, input_size, output_num], dtype=tf.float32, stddev=stddev),
        name='weights')
    # name must be passed as a keyword argument; the second positional
    # parameter of tf.Variable is `trainable`, not `name`
    conv_biases = tf.Variable(tf.constant(init_bias, shape=[output_num], dtype=tf.float32), name='biases')
    conv_layer = tf.nn.conv2d(inputs, conv_weights, [1, stride_size, stride_size, 1], padding=conv_padding)
    conv_layer = tf.nn.bias_add(conv_layer, conv_biases)
    if activation_func:
        conv_layer = activation_func(conv_layer)
    return conv_layer


def fc(inputs, output_size, init_bias=0.0, activation_func=tf.nn.relu, stddev=0.01):
    input_shape = inputs.get_shape().as_list()
    # 4-D input (coming from a conv layer) is flattened before the matmul
    if len(input_shape) == 4:
        fc_weights = tf.Variable(
            tf.random_normal([input_shape[1] * input_shape[2] * input_shape[3], output_size], dtype=tf.float32,
                             stddev=stddev),
            name='weights')
        inputs = tf.reshape(inputs, [-1, fc_weights.get_shape().as_list()[0]])
    else:
        fc_weights = tf.Variable(tf.random_normal([input_shape[-1], output_size], dtype=tf.float32, stddev=stddev),
                                 name='weights')

    fc_biases = tf.Variable(tf.constant(init_bias, shape=[output_size], dtype=tf.float32), name='biases')
    fc_layer = tf.matmul(inputs, fc_weights)
    fc_layer = tf.nn.bias_add(fc_layer, fc_biases)
    if activation_func:
        fc_layer = activation_func(fc_layer)
    return fc_layer


def lrn(inputs, depth_radius=2, alpha=0.0001, beta=0.75, bias=1.0):
    return tf.nn.local_response_normalization(inputs, depth_radius=depth_radius, alpha=alpha, beta=beta, bias=bias)
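# Minimal usage sketch for these helpers (hypothetical, for illustration only;
# nothing in this repo executes it):
#
#   x = tf.placeholder(tf.float32, [None, 28, 28, 1])
#   net = conv(x, 5, 32, 2)                     # 5x5 kernels, 32 filters, stride 2
#   net = lrn(net)
#   logits = fc(net, 10, activation_func=None)  # flattens, then linear output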
3 | """ 4 | 5 | import time 6 | import tensorflow as tf 7 | import kaggle_mnist_input as loader 8 | import os 9 | import csv 10 | 11 | FLAGS = tf.app.flags.FLAGS 12 | 13 | tf.app.flags.DEFINE_integer('training_epoch', 30, "training epoch") 14 | tf.app.flags.DEFINE_integer('batch_size', 128, "batch size") 15 | tf.app.flags.DEFINE_integer('validation_interval', 100, "validation interval") 16 | 17 | tf.app.flags.DEFINE_float('dropout_keep_prob', 0.5, "dropout keep prob") 18 | tf.app.flags.DEFINE_float('learning_rate', 0.001, "learning rate") 19 | tf.app.flags.DEFINE_float('rms_decay', 0.9, "rms optimizer decay") 20 | tf.app.flags.DEFINE_float('weight_decay', 0.0005, "l2 regularization weight decay") 21 | tf.app.flags.DEFINE_string('train_path', '/tmp/train.csv', "path to download training data") 22 | tf.app.flags.DEFINE_string('test_path', '/tmp/test.csv', "path to download test data") 23 | tf.app.flags.DEFINE_integer('validation_size', 2000, "validation size in training data") 24 | tf.app.flags.DEFINE_string('save_name', os.getcwd() + '/var.ckpt', "path to save variables") 25 | tf.app.flags.DEFINE_boolean('is_train', True, "True for train, False for test") 26 | tf.app.flags.DEFINE_string('test_result', 'result.csv', "test file path") 27 | 28 | image_size = 28 29 | image_channel = 1 30 | label_cnt = 10 31 | 32 | inputs = tf.placeholder("float", [None, image_size, image_size, image_channel]) 33 | labels = tf.placeholder("float", [None, label_cnt]) 34 | dropout_keep_prob = tf.placeholder("float", None) 35 | learning_rate_ph = tf.placeholder("float", None) 36 | 37 | # conv layer 1 38 | conv1_weights = tf.Variable(tf.random_normal([7, 7, image_channel, 96], dtype=tf.float32, stddev=0.01)) 39 | conv1_biases = tf.Variable(tf.constant(0.0, shape=[96], dtype=tf.float32)) 40 | conv1 = tf.nn.conv2d(inputs, conv1_weights, [1, 3, 3, 1], padding='SAME') 41 | conv1 = tf.nn.bias_add(conv1, conv1_biases) 42 | conv1_relu = tf.nn.relu(conv1) 43 | conv1_norm = tf.nn.local_response_normalization(conv1_relu, depth_radius=2, alpha=0.0001, beta=0.75, bias=1.0) 44 | conv1_pool = tf.nn.max_pool(conv1_norm, ksize=[1, 2, 2, 1], strides=[1, 1, 1, 1], padding='VALID') 45 | 46 | # conv layer 2 47 | conv2_weights = tf.Variable(tf.random_normal([5, 5, 96, 256], dtype=tf.float32, stddev=0.01)) 48 | conv2_biases = tf.Variable(tf.constant(1.0, shape=[256], dtype=tf.float32)) 49 | conv2 = tf.nn.conv2d(conv1_pool, conv2_weights, [1, 1, 1, 1], padding='SAME') 50 | conv2 = tf.nn.bias_add(conv2, conv2_biases) 51 | conv2_relu = tf.nn.relu(conv2) 52 | conv2_norm = tf.nn.local_response_normalization(conv2_relu) 53 | conv2_pool = tf.nn.max_pool(conv2_norm, ksize=[1, 2, 2, 1], strides=[1, 1, 1, 1], padding='VALID') 54 | 55 | # conv layer 3 56 | conv3_weights = tf.Variable(tf.random_normal([3, 3, 256, 384], dtype=tf.float32, stddev=0.01)) 57 | conv3_biases = tf.Variable(tf.constant(0.0, shape=[384], dtype=tf.float32)) 58 | conv3 = tf.nn.conv2d(conv2_pool, conv3_weights, [1, 1, 1, 1], padding='SAME') 59 | conv3 = tf.nn.bias_add(conv3, conv3_biases) 60 | conv3_relu = tf.nn.relu(conv3) 61 | 62 | # conv layer 4 63 | conv4_weights = tf.Variable(tf.random_normal([3, 3, 384, 384], dtype=tf.float32, stddev=0.01)) 64 | conv4_biases = tf.Variable(tf.constant(1.0, shape=[384], dtype=tf.float32)) 65 | conv4 = tf.nn.conv2d(conv3_relu, conv4_weights, [1, 1, 1, 1], padding='SAME') 66 | conv4 = tf.nn.bias_add(conv4, conv4_biases) 67 | conv4_relu = tf.nn.relu(conv4) 68 | 69 | # conv layer 5 70 | conv5_weights = tf.Variable(tf.random_normal([3, 3, 384, 
# fc layer 1
fc1_weights = tf.Variable(tf.random_normal([256 * 3 * 3, 4096], dtype=tf.float32, stddev=0.01))
fc1_biases = tf.Variable(tf.constant(1.0, shape=[4096], dtype=tf.float32))
conv5_reshape = tf.reshape(conv5_pool, [-1, fc1_weights.get_shape().as_list()[0]])
fc1 = tf.matmul(conv5_reshape, fc1_weights)
fc1 = tf.nn.bias_add(fc1, fc1_biases)
fc1_relu = tf.nn.relu(fc1)
fc1_drop = tf.nn.dropout(fc1_relu, dropout_keep_prob)

# fc layer 2
fc2_weights = tf.Variable(tf.random_normal([4096, 4096], dtype=tf.float32, stddev=0.01))
fc2_biases = tf.Variable(tf.constant(1.0, shape=[4096], dtype=tf.float32))
fc2 = tf.matmul(fc1_drop, fc2_weights)
fc2 = tf.nn.bias_add(fc2, fc2_biases)
fc2_relu = tf.nn.relu(fc2)
fc2_drop = tf.nn.dropout(fc2_relu, dropout_keep_prob)

# fc layer 3 - output
fc3_weights = tf.Variable(tf.random_normal([4096, label_cnt], dtype=tf.float32, stddev=0.01))
fc3_biases = tf.Variable(tf.constant(1.0, shape=[label_cnt], dtype=tf.float32))
fc3 = tf.matmul(fc2_drop, fc3_weights)
logits = tf.nn.bias_add(fc3, fc3_biases)

# loss
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits, labels))
# l2 regularization
regularizers = (tf.nn.l2_loss(conv1_weights) + tf.nn.l2_loss(conv1_biases) +
                tf.nn.l2_loss(conv2_weights) + tf.nn.l2_loss(conv2_biases) +
                tf.nn.l2_loss(conv3_weights) + tf.nn.l2_loss(conv3_biases) +
                tf.nn.l2_loss(conv4_weights) + tf.nn.l2_loss(conv4_biases) +
                tf.nn.l2_loss(conv5_weights) + tf.nn.l2_loss(conv5_biases) +
                tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) +
                tf.nn.l2_loss(fc2_weights) + tf.nn.l2_loss(fc2_biases) +
                tf.nn.l2_loss(fc3_weights) + tf.nn.l2_loss(fc3_biases))
loss += FLAGS.weight_decay * regularizers
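# Note: this is classic L2 weight decay folded into the objective,
#   total_loss = cross_entropy + 0.0005 * sum_v ||v||^2 / 2
# over all weights and biases (tf.nn.l2_loss already includes the 1/2 factor).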
# accuracy
predict = tf.argmax(logits, 1)
accuracy = tf.reduce_mean(tf.cast(tf.equal(predict, tf.argmax(labels, 1)), tf.float32))

# train
train = tf.train.RMSPropOptimizer(learning_rate_ph, FLAGS.rms_decay).minimize(loss)
# train = tf.train.MomentumOptimizer(learning_rate_ph, FLAGS.momentum).minimize(loss)

# session
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

# tf saver
saver = tf.train.Saver()
if os.path.isfile(FLAGS.save_name):
    saver.restore(sess, FLAGS.save_name)

total_start_time = time.time()

# begin training
if FLAGS.is_train:
    # load mnist data
    train_images, train_labels, train_range, validation_images, validation_labels, validation_indices = loader.load_mnist_train(
        FLAGS.validation_size, FLAGS.batch_size)

    total_train_len = len(train_images)
    i = 0
    learning_rate = FLAGS.learning_rate
    for epoch in range(FLAGS.training_epoch):
        # step learning rate decay: divide by 10 at epochs 10 and 20
        if epoch % 10 == 0 and epoch > 0:
            learning_rate /= 10
        epoch_start_time = time.time()
        for start, end in train_range:
            batch_start_time = time.time()
            trainX = train_images[start:end]
            trainY = train_labels[start:end]
            _, loss_result = sess.run([train, loss], feed_dict={inputs: trainX, labels: trainY,
                                                                dropout_keep_prob: FLAGS.dropout_keep_prob,
                                                                learning_rate_ph: learning_rate})
            print('[%s][training][epoch %d, step %d exec %.2f seconds] [file: %5d ~ %5d / %5d] loss : %3.10f' % (
                time.strftime("%Y-%m-%d %H:%M:%S"), epoch, i, (time.time() - batch_start_time), start, end,
                total_train_len, loss_result))

            if i % FLAGS.validation_interval == 0 and i > 0:
                validation_start_time = time.time()
                shuffle_indices = loader.shuffle_validation(validation_indices, FLAGS.batch_size)
                validationX = validation_images[shuffle_indices]
                validationY = validation_labels[shuffle_indices]
                accuracy_result, loss_result = sess.run([accuracy, loss],
                                                        feed_dict={inputs: validationX, labels: validationY,
                                                                   dropout_keep_prob: 1.0})
                print('[%s][validation][epoch %d, step %d exec %.2f seconds] accuracy : %1.3f, loss : %3.10f' % (
                    time.strftime("%Y-%m-%d %H:%M:%S"), epoch, i, (time.time() - validation_start_time),
                    accuracy_result, loss_result))

            i += 1

        print("[%s][epoch exec %s seconds] epoch : %d" % (
            time.strftime("%Y-%m-%d %H:%M:%S"), (time.time() - epoch_start_time), epoch))
        saver.save(sess, FLAGS.save_name)
# begin test
else:
    i = 1
    test_images, test_ranges = loader.load_mnist_test(FLAGS.batch_size)

    test_result_file = open(FLAGS.test_result, 'wb')
    csv_writer = csv.writer(test_result_file)
    csv_writer.writerow(['ImageId', 'Label'])

    for file_start, file_end in test_ranges:
        testX = test_images[file_start:file_end]
        predict_label = sess.run(predict, feed_dict={inputs: testX, dropout_keep_prob: 1.0})

        for cur_predict in predict_label:
            csv_writer.writerow([i, cur_predict])
            print('[Result %s: %s]' % (i, cur_predict))
            i += 1

print("[%s][total exec %s seconds]" % (time.strftime("%Y-%m-%d %H:%M:%S"), (time.time() - total_start_time)))
--------------------------------------------------------------------------------