├── .gitignore
├── LICENSE
├── README.md
├── mnist.py
├── mnist_train.py
├── mnist_train_estimator.py
├── mnist_train_slim.py
├── mnist_train_tfdata.py
├── predict.py
├── validation.py
└── validation_slim.py

/.gitignore:
--------------------------------------------------------------------------------
*.dat
*.jpg
*.png
*.tfrecords
.idea
__pycache__
train
log.txt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 jie

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# learn_tf
TensorFlow: learn and practice
--------------------------------------------------------------------------------
/mnist.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow import keras

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('image_height', 28, 'the height of image')
tf.app.flags.DEFINE_integer('image_width', 28, 'the width of image')
tf.app.flags.DEFINE_integer('batch_size', 128, 'Number of images to process in a batch')
TRAIN_EXAMPLES_NUM = 55000
VALIDATION_EXAMPLES_NUM = 5000
TEST_EXAMPLES_NUM = 10000


def parse_data(example_proto):
    features = {'img_raw': tf.FixedLenFeature([], tf.string, ''),
                'label': tf.FixedLenFeature([], tf.int64, 0)}
    parsed_features = tf.parse_single_example(example_proto, features)
    image = tf.decode_raw(parsed_features['img_raw'], tf.uint8)
    label = tf.cast(parsed_features['label'], tf.int64)
    image = tf.reshape(image, [FLAGS.image_height, FLAGS.image_width, 1])
    image = tf.cast(image, tf.float32)
    return image, label


def read_mnist_tfrecords(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(serialized_example, features={
        'img_raw': tf.FixedLenFeature([], tf.string, ''),
        'label': tf.FixedLenFeature([], tf.int64, 0)
    })
    image = tf.decode_raw(features['img_raw'], tf.uint8)
    label = tf.cast(features['label'], tf.int64)
    image = tf.reshape(image, [FLAGS.image_height, FLAGS.image_width, 1])
    return image, label


def inputs(filenames, examples_num, batch_size, shuffle):
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)
    with tf.name_scope('inputs'):
        filename_queue = tf.train.string_input_producer(filenames)
        image, label = read_mnist_tfrecords(filename_queue)
        image = tf.cast(image, tf.float32)
        min_fraction_of_examples_in_queue = 0.4
        min_queue_examples = int(min_fraction_of_examples_in_queue * examples_num)
        num_process_threads = 16
        if shuffle:
            images, labels = tf.train.shuffle_batch([image, label], batch_size=batch_size,
                                                    num_threads=num_process_threads,
                                                    capacity=min_queue_examples + batch_size * 3,
                                                    min_after_dequeue=min_queue_examples)
        else:
            images, labels = tf.train.batch([image, label], batch_size=batch_size,
                                            num_threads=num_process_threads,
                                            capacity=min_queue_examples + batch_size * 3)
    return images, labels


def inference(images, training):
    with tf.variable_scope('conv1'):
        conv1 = tf.layers.conv2d(inputs=images,
                                 filters=32,
                                 kernel_size=[5, 5],
                                 padding='same',
                                 activation=tf.nn.relu)

    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)  # 14*14*32

    with tf.variable_scope('conv2'):
        conv2 = tf.layers.conv2d(inputs=pool1,
                                 filters=64,
                                 kernel_size=[5, 5],
                                 padding='same',
                                 activation=tf.nn.relu)

    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)  # 7*7*64

    with tf.variable_scope('fc1'):
        pool2_flat = tf.reshape(pool2, [-1, 7*7*64])
        fc1 = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
        dropout1 = tf.layers.dropout(inputs=fc1, rate=0.4, training=training)

    with tf.variable_scope('logits'):
        logits = tf.layers.dense(inputs=dropout1, units=10)  # these logits feed the cross-entropy loss
        predict = tf.nn.softmax(logits)

    return logits, predict


def loss(logits, labels):
    labels = tf.cast(labels, tf.int64)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits, name='cross_entropy')
    cross_entropy_loss = tf.reduce_mean(cross_entropy)
    return cross_entropy_loss


def train(total_loss, global_step):
    num_batches_per_epoch = TRAIN_EXAMPLES_NUM / FLAGS.batch_size
    decay_steps = int(num_batches_per_epoch * 10)

    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(learning_rate=0.001,
                                    global_step=global_step,
                                    decay_steps=decay_steps,
                                    decay_rate=0.1,
                                    staircase=True)

    # opt = tf.train.GradientDescentOptimizer(lr)
    # opt = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.99)
    opt = tf.train.AdamOptimizer(learning_rate=lr)
    grad = opt.compute_gradients(total_loss)
    apply_grad_op = opt.apply_gradients(grad, global_step)

    return apply_grad_op


def model_slim(images, labels, is_training):
    net = slim.conv2d(images, 32, [5, 5], scope='conv1')
    net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool1')
    net = slim.conv2d(net, 64, [5, 5], scope='conv2')
    net = slim.max_pool2d(net, [2, 2], stride=2, scope='pool2')
    net = slim.flatten(net, scope='flatten')
    net = slim.fully_connected(net, 1024, scope='fully_connected1')
    net = slim.dropout(net, keep_prob=0.6, is_training=is_training)
    logits = slim.fully_connected(net, 10, activation_fn=None, scope='fully_connected2')

    prob = slim.softmax(logits)
    loss = slim.losses.sparse_softmax_cross_entropy(logits, labels)

    global_step = tf.train.get_or_create_global_step()
    num_batches_per_epoch = TRAIN_EXAMPLES_NUM / FLAGS.batch_size
    decay_steps = int(num_batches_per_epoch * 10)

    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(learning_rate=0.001,
                                    global_step=global_step,
                                    decay_steps=decay_steps,
                                    decay_rate=0.1,
                                    staircase=True)

    opt = tf.train.AdamOptimizer(learning_rate=lr)

    return opt, loss, prob


def model_fn(features, labels, mode):
    with tf.variable_scope('conv1'):
        conv1 = tf.layers.conv2d(inputs=features,
                                 filters=32,
                                 kernel_size=[5, 5],
                                 padding='same',
                                 activation=tf.nn.relu)

    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)  # 14*14*32

    with tf.variable_scope('conv2'):
        conv2 = tf.layers.conv2d(inputs=pool1,
                                 filters=64,
                                 kernel_size=[5, 5],
                                 padding='same',
                                 activation=tf.nn.relu)

    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)  # 7*7*64

    with tf.variable_scope('fc1'):
        pool2_flat = tf.reshape(pool2, [-1, 7*7*64])
        fc1 = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)
        dropout1 = tf.layers.dropout(inputs=fc1, rate=0.4, training=mode == tf.estimator.ModeKeys.TRAIN)

    with tf.variable_scope('logits'):
        logits = tf.layers.dense(inputs=dropout1, units=10)  # these logits feed the cross-entropy loss
        predict = tf.nn.softmax(logits)

    predictions = {
        # Generate predictions (for PREDICT and EVAL mode)
        "classes": tf.argmax(input=logits, axis=1),
        # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
        # `logging_hook`.
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions["classes"])
    tf.summary.scalar('accuracy', accuracy[1])

    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_global_step()
        train_op = train(loss, global_step)
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    # Add evaluation metrics (for EVAL mode)
    eval_metric_ops = {"eval_accuracy": accuracy}
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
182 | "probabilities": tf.nn.softmax(logits, name="softmax_tensor") 183 | } 184 | 185 | if mode == tf.estimator.ModeKeys.PREDICT: 186 | return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions) 187 | 188 | loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits) 189 | accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions["classes"]) 190 | tf.summary.scalar('accuracy', accuracy[1]) 191 | 192 | if mode == tf.estimator.ModeKeys.TRAIN: 193 | global_step = tf.train.get_global_step() 194 | train_op = train(loss, global_step) 195 | return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) 196 | 197 | # Add evaluation metrics (for EVAL mode) 198 | eval_metric_ops = {"eval_accuracy": accuracy} 199 | return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) 200 | 201 | 202 | def input_fn(filenames, training): 203 | dataset = tf.data.TFRecordDataset(filenames) 204 | dataset = dataset.map(parse_data) 205 | 206 | if training: 207 | dataset = dataset.shuffle(buffer_size=50000) 208 | dataset = dataset.batch(FLAGS.batch_size) 209 | if training: 210 | dataset = dataset.repeat() 211 | 212 | iterator = dataset.make_one_shot_iterator() 213 | features, labels = iterator.get_next() 214 | return features, labels 215 | 216 | 217 | def model_keras(): 218 | model = keras.Sequential() 219 | model.add(keras.layers.Conv2D(filters=32, 220 | kernel_size=[5, 5], 221 | padding='same', 222 | activation=tf.nn.relu, 223 | input_shape=[FLAGS.image_height, FLAGS.image_width, 1])) 224 | model.add(keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)) 225 | model.add(keras.layers.Conv2D(filters=64, 226 | kernel_size=[5, 5], 227 | padding='same', 228 | activation=tf.nn.relu)) 229 | model.add(keras.layers.MaxPool2D(pool_size=[2, 2], strides=2)) 230 | model.add(keras.layers.Flatten(input_shape=[7, 7, 64])) 231 | model.add(keras.layers.Dense(units=1024, activation=tf.nn.relu)) 232 | model.add(keras.layers.Dropout(rate=0.4)) 233 | model.add(keras.layers.Dense(units=10)) 234 | model.add(keras.layers.Activation(tf.nn.softmax)) 235 | 236 | opt = keras.optimizers.Adam(0.001) 237 | model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics=['accuracy']) 238 | 239 | return model 240 | -------------------------------------------------------------------------------- /mnist_train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import mnist 4 | 5 | FLAGS = tf.app.flags.FLAGS 6 | tf.app.flags.DEFINE_integer('max_step', 1200, 'Number of steps to run trainer') 7 | tf.app.flags.DEFINE_string('train_dir', './train', 'Directory where to write event logs and checkpoint') 8 | 9 | 10 | def train(): 11 | images, labels = mnist.inputs(['train_img.tfrecords'], mnist.TRAIN_EXAMPLES_NUM, 12 | FLAGS.batch_size, shuffle=True) 13 | global_step = tf.train.get_or_create_global_step() 14 | 15 | logits, pred = mnist.inference(images, training=True) 16 | loss = mnist.loss(logits, labels) 17 | train_op = mnist.train(loss, global_step) 18 | saver = tf.train.Saver() 19 | with tf.Session() as sess: 20 | init_op = tf.group( 21 | tf.local_variables_initializer(), 22 | tf.global_variables_initializer()) 23 | sess.run(init_op) 24 | ckpt = os.path.join(FLAGS.train_dir, 'model.ckpt') 25 | 26 | coord = tf.train.Coordinator() 27 | threads = tf.train.start_queue_runners(sess, coord=coord) 28 | 29 | for i in range(1, FLAGS.max_step + 1): 30 | _, train_loss, predict, label = 
/mnist_train.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import os
import mnist

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('max_step', 1200, 'Number of steps to run trainer')
tf.app.flags.DEFINE_string('train_dir', './train', 'Directory where to write event logs and checkpoint')


def train():
    images, labels = mnist.inputs(['train_img.tfrecords'], mnist.TRAIN_EXAMPLES_NUM,
                                  FLAGS.batch_size, shuffle=True)
    global_step = tf.train.get_or_create_global_step()

    logits, pred = mnist.inference(images, training=True)
    loss = mnist.loss(logits, labels)
    train_op = mnist.train(loss, global_step)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        init_op = tf.group(
            tf.local_variables_initializer(),
            tf.global_variables_initializer())
        sess.run(init_op)
        ckpt = os.path.join(FLAGS.train_dir, 'model.ckpt')

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess, coord=coord)

        for i in range(1, FLAGS.max_step + 1):
            _, train_loss, predict, label = sess.run([train_op, loss, pred, labels])
            # print(predict, '\n', label)
            if i % 100 == 0:
                print('step: {}, loss: {}'.format(i, train_loss))
                # print(predict, '\n', label)
                saver.save(sess, ckpt, global_step=i)

        coord.request_stop()
        coord.join(threads)


if __name__ == '__main__':
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    train()
--------------------------------------------------------------------------------
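Note: every training and evaluation script reads ./train_img.tfrecords, ./validation_img.tfrecords or ./test_img.tfrecords, but the conversion script is not part of the repo (the files themselves are excluded by .gitignore). Below is a minimal sketch of one way they could be produced, assuming the standard MNIST split from tensorflow.examples.tutorials.mnist; the write_tfrecords helper is illustrative. It matches the img_raw/label schema expected by mnist.parse_data and mnist.read_mnist_tfrecords:

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


def write_tfrecords(images, labels, filename):
    # images: float32 in [0, 1], shape [N, 784]; stored as raw uint8 bytes,
    # which is what tf.decode_raw(..., tf.uint8) in mnist.py expects.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for img, label in zip(images, labels):
            img_raw = (img * 255).astype(np.uint8).tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
                'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)]))
            }))
            writer.write(example.SerializeToString())


data = input_data.read_data_sets('MNIST_data', one_hot=False)
write_tfrecords(data.train.images, data.train.labels, './train_img.tfrecords')
write_tfrecords(data.validation.images, data.validation.labels, './validation_img.tfrecords')
write_tfrecords(data.test.images, data.test.labels, './test_img.tfrecords')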
/mnist_train_estimator.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import mnist
import os

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('max_step', 1000, 'Number of steps to run trainer')
tf.app.flags.DEFINE_string('train_dir', './train', 'Directory where to write event logs and checkpoint')

tf.logging.set_verbosity(tf.logging.INFO)

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


def train():
    my_checkpoint_config = tf.estimator.RunConfig(save_checkpoints_steps=100, keep_checkpoint_max=5)

    mnist_classifier = tf.estimator.Estimator(model_fn=mnist.model_fn, model_dir=FLAGS.train_dir,
                                              config=my_checkpoint_config)
    tensor_to_log = {'probabilities': 'softmax_tensor'}
    logging_hook = tf.train.LoggingTensorHook(tensors=tensor_to_log, every_n_iter=100)

    for i in range(FLAGS.max_step // 100):
        mnist_classifier.train(input_fn=lambda: mnist.input_fn(['./train_img.tfrecords'], True),
                               # hooks=[logging_hook],
                               steps=100)

        eval_results = mnist_classifier.evaluate(input_fn=lambda: mnist.input_fn(['./validation_img.tfrecords'], False))
        print(eval_results)


if __name__ == '__main__':
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    train()
--------------------------------------------------------------------------------
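Note: predict.py below only covers the raw-Session checkpoint path. For inference with the Estimator variant, a minimal sketch — predict_input_fn and the zero-image placeholder are illustrative assumptions, and it presumes a checkpoint written to ./train by mnist_train_estimator.py:

import numpy as np
import tensorflow as tf

import mnist


def predict_input_fn(images):
    # images: float32 array of shape [N, 28, 28, 1]; model_fn receives the batch directly.
    dataset = tf.data.Dataset.from_tensor_slices(images.astype(np.float32))
    return dataset.batch(128).make_one_shot_iterator().get_next()


classifier = tf.estimator.Estimator(model_fn=mnist.model_fn, model_dir='./train')
some_images = np.zeros([3, 28, 28, 1], dtype=np.float32)  # placeholder; replace with real images
for result in classifier.predict(input_fn=lambda: predict_input_fn(some_images)):
    print(result['classes'], result['probabilities'].max())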
/mnist_train_slim.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import mnist
from tensorflow.contrib import slim

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('max_step', 12000, 'Number of steps to run trainer')
tf.app.flags.DEFINE_string('train_dir', './train', 'Directory where to write event logs and checkpoint')

tf.logging.set_verbosity(tf.logging.INFO)


def train():
    train_images, train_labels = mnist.input_fn(['./train_img.tfrecords'], True)

    # model_slim returns the optimizer, the loss and the softmax output;
    # create_train_op pairs the loss with that optimizer.
    train_op, loss, pred = mnist.model_slim(train_images, train_labels, is_training=True)
    train_tensor = slim.learning.create_train_op(loss, train_op)
    result = slim.learning.train(train_tensor, FLAGS.train_dir, number_of_steps=FLAGS.max_step, log_every_n_steps=100)
    print('final step loss: {}'.format(result))


if __name__ == '__main__':
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    train()
--------------------------------------------------------------------------------
/mnist_train_tfdata.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import mnist
import numpy as np
import os
import math

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('max_step', 10000, 'Number of steps to run trainer')
tf.app.flags.DEFINE_string('train_dir', './train', 'Directory where to write event logs and checkpoint')


def evaluate(sess, top_k_op, training, examples):
    iter_per_epoch = int(math.ceil(examples / FLAGS.batch_size))
    # total_sample = iter_per_epoch * FLAGS.batch_size
    correct_predict = 0
    step = 0

    while step < iter_per_epoch:
        predict = sess.run(top_k_op, feed_dict={training: False})
        correct_predict += np.sum(predict)
        step += 1

    precision = correct_predict / examples
    return precision


def train():
    filenames = tf.placeholder(tf.string, [None])
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(mnist.parse_data)
    dataset = dataset.shuffle(buffer_size=50000)
    dataset = dataset.batch(FLAGS.batch_size)
    dataset = dataset.repeat()

    iterator = dataset.make_initializable_iterator()

    global_step = tf.train.get_or_create_global_step()
    images, labels = iterator.get_next()
    logits, pred = mnist.inference(images, training=True)
    loss = mnist.loss(logits, labels)
    train_op = mnist.train(loss, global_step)

    with tf.train.MonitoredTrainingSession(
            checkpoint_dir=FLAGS.train_dir,
            hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_step), tf.train.NanTensorHook(loss)],
            save_checkpoint_steps=100
    ) as mon_sess:
        mon_sess.run(iterator.initializer, feed_dict={filenames: ['train_img.tfrecords']})
        while not mon_sess.should_stop():
            _, train_loss, train_step, label = mon_sess.run([train_op, loss, global_step, labels])
            if train_step % 100 == 0:
                print('step: {}, loss: {}'.format(train_step, train_loss))


def train_and_validation():
    training_dataset = tf.data.TFRecordDataset(['./train_img.tfrecords'])
    validation_dataset = tf.data.TFRecordDataset(['./validation_img.tfrecords'])
    test_dataset = tf.data.TFRecordDataset(['./test_img.tfrecords'])

    training_dataset = training_dataset.map(mnist.parse_data)
    training_dataset = training_dataset.shuffle(50000).batch(FLAGS.batch_size).repeat()
    validation_dataset = validation_dataset.map(mnist.parse_data).batch(FLAGS.batch_size)
    test_dataset = test_dataset.map(mnist.parse_data).batch(FLAGS.batch_size)

    iterator = tf.data.Iterator.from_structure(output_types=training_dataset.output_types,
                                               output_shapes=training_dataset.output_shapes)

    training_init_op = iterator.make_initializer(training_dataset)
    validation_init_op = iterator.make_initializer(validation_dataset)
    test_init_op = iterator.make_initializer(test_dataset)
    images, labels = iterator.get_next()

    training = tf.placeholder(dtype=tf.bool)
    logits, pred = mnist.inference(images, training=training)
    loss = mnist.loss(logits, labels)
    top_k_op = tf.nn.in_top_k(logits, labels, 1)
    global_step = tf.train.get_or_create_global_step()
    train_op = mnist.train(loss, global_step)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(training_init_op)
        print('begin to train!')
        ckpt = os.path.join(FLAGS.train_dir, 'model.ckpt')
        train_step = 0
        while train_step < FLAGS.max_step:
            _, train_loss, step, label = sess.run([train_op, loss, global_step, labels], feed_dict={training: True})
            train_step += 1
            if train_step % 100 == 0:
                saver.save(sess, ckpt, train_step)
            if train_step % 1000 == 0:
                precision = evaluate(sess, top_k_op, training, mnist.TRAIN_EXAMPLES_NUM)
                print('step: {}, loss: {}, training precision: {}'.format(train_step, train_loss, precision))
                sess.run(validation_init_op)
                precision = evaluate(sess, top_k_op, training, mnist.VALIDATION_EXAMPLES_NUM)
                print('step: {}, loss: {}, validation precision: {}'.format(train_step, train_loss, precision))
                sess.run(training_init_op)
        sess.run(test_init_op)
        precision = evaluate(sess, top_k_op, training, mnist.TEST_EXAMPLES_NUM)
        print('final test precision: {}'.format(precision))


if __name__ == '__main__':
    if tf.gfile.Exists(FLAGS.train_dir):
        tf.gfile.DeleteRecursively(FLAGS.train_dir)
    tf.gfile.MakeDirs(FLAGS.train_dir)
    # train()
    train_and_validation()
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import cv2
import mnist
import numpy as np


def pred(filename, train_dir):
    img = cv2.imread(filename, flags=cv2.IMREAD_GRAYSCALE)
    img = tf.cast(img, tf.float32)
    img = tf.reshape(img, [-1, 28, 28, 1])

    logits, predict = mnist.inference(img, training=False)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('no checkpoint file')
            return
        pre = sess.run(predict)
        print('model:{}, file:{}, label: {} ({:.2f}%)'.
              format(ckpt.model_checkpoint_path, filename, np.argmax(pre[0]), np.max(pre[0]) * 100))


if __name__ == '__main__':
    pred('./img_test/2_2098.jpg', './train')
--------------------------------------------------------------------------------
/validation.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import mnist
import numpy as np
import time
import math


def eval_once(saver, top_k_op):
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
        else:
            print('no checkpoint file')
            return

        coord = tf.train.Coordinator()
        try:
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            iter_per_epoch = int(math.ceil(mnist.VALIDATION_EXAMPLES_NUM / FLAGS.batch_size))

            total_sample = iter_per_epoch * FLAGS.batch_size
            correct_predict = 0
            step = 0

            while step < iter_per_epoch and not coord.should_stop():
                predict = sess.run(top_k_op)
                correct_predict += np.sum(predict)
                step += 1

            precision = correct_predict / total_sample
            print('step: {}, model: {}, precision: {}'.format(global_step, ckpt.model_checkpoint_path, precision))

        except Exception as e:
            print('exception: ', e)
            coord.request_stop(e)
        finally:
            coord.request_stop()
            coord.join(threads)


def evaluation():
    images, labels = mnist.inputs(['./validation_img.tfrecords'], mnist.VALIDATION_EXAMPLES_NUM,
                                  batch_size=FLAGS.batch_size, shuffle=False)
    logits, pred = mnist.inference(images, training=False)
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    saver = tf.train.Saver()

    while True:
        eval_once(saver, top_k_op)
        if FLAGS.run_once:
            break
        time.sleep(FLAGS.eval_interval_secs)


if __name__ == '__main__':
    FLAGS = tf.app.flags.FLAGS
    tf.app.flags.DEFINE_integer('eval_interval_secs', 100, 'How often to run the eval')
    tf.app.flags.DEFINE_string('train_dir', './train', 'Directory where to write event logs and checkpoint')
    tf.app.flags.DEFINE_boolean('run_once', True, 'whether to run eval only once')

    evaluation()
--------------------------------------------------------------------------------
/validation_slim.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import mnist
import math
from tensorflow.contrib import slim

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('train_dir', './train', 'Directory where to write event logs and checkpoint')

tf.logging.set_verbosity(tf.logging.INFO)


def validation():
    validation_images, validation_labels = mnist.input_fn(['./validation_img.tfrecords'], False)
    _, loss, pred = mnist.model_slim(validation_images, validation_labels, is_training=False)
    prediction = tf.argmax(pred, axis=1)

    # Choose the metrics to compute:
    value_op, update_op = tf.metrics.accuracy(validation_labels, prediction)
    num_batches = math.ceil(mnist.VALIDATION_EXAMPLES_NUM / FLAGS.batch_size)

    print('Running evaluation...')
    # Only load the latest checkpoint
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.train_dir)

    metric_values = slim.evaluation.evaluate_once(
        num_evals=num_batches,
        master='',
        checkpoint_path=checkpoint_path,
        logdir=FLAGS.train_dir,
        eval_op=update_op,
        final_op=value_op)
    print(metric_values)


if __name__ == '__main__':
    validation()
--------------------------------------------------------------------------------