├── .gitignore ├── README.md ├── classify.py ├── data ├── built_data_set ├── input │ ├── test │ │ ├── negative │ │ │ └── .gitkeep │ │ └── positive │ │ │ └── .gitkeep │ └── train │ │ ├── negative │ │ └── .gitkeep │ │ └── positive │ │ └── .gitkeep ├── testset.json └── trainingset.json ├── model.py └── tf_data.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | **/*.pyc 3 | data/input/**/*.jpg 4 | data/input/**/*.png 5 | data/output/ 6 | 7 | .DS_Store 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Image Classification 2 | This repo demonstrates a reusable image classification system that works with any set of images. It uses TensorFlow to build a neural network that performs the classification, and ImageMagick to preprocess the images. It also generates the image and label data files that TensorFlow reads. 3 | 4 | ## Prerequisites 5 | Make sure you have: 6 | * TensorFlow 7 | * ImageMagick 8 | * Wand (Python ImageMagick wrapper) 9 | * A set of pre-classified images (both a training set and a test set) 10 | 11 | ## How It Works 12 | 13 | 1. Put images into ```data/input/```, separating them into different directories for training data and test data (```data/input/train``` and ```data/input/test```, for example). 14 | 2. Further separate the images into directories representing the different classifications (```data/input/[train/test]/class1```, ```data/input/[train/test]/class2```, etc.). 15 | 3. Modify ```data/trainingset.json``` and ```data/testset.json``` to set the height and width of the images, and add the classification directories made in (2), giving each one a label.
The JSON files already contain an example that assumes 2 classifications in the directories ```data/input/[train/test]/positive``` and ```data/input/[train/test]/negative``` for a binary classification. (If the output files are renamed, their names must also be updated in ```tf_data.py```; if they are moved, pass their directory to ```classify.py``` with ```--train_dir```.) 16 | 4. From the ```data/``` directory, run ```./built_data_set [dataset].json``` once for each of ```trainingset.json``` and ```testset.json```. This builds the image set and label set into the output files specified in the JSON file (```output/TF[test/train].[labels & images].gz``` by default). 17 | 5. Note - set the number of classes you have defined in the function ```read_data_sets``` in ```tf_data.py``` and in ```NUM_CLASSES``` in ```model.py```. ```IMAGE_SIZE``` in ```model.py``` must also equal the width and height set in the JSON files (images must be square). 18 | 6. From the repository root, run ```python classify.py``` to train the model and test the data set. 19 | -------------------------------------------------------------------------------- /classify.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | 16 | """Trains and Evaluates the MNIST network using a feed dictionary.""" 17 | # pylint: disable=missing-docstring 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os.path 23 | import time 24 | 25 | import numpy 26 | from six.moves import xrange # pylint: disable=redefined-builtin 27 | import tensorflow as tf 28 | 29 | import tf_data 30 | import model 31 | 32 | 33 | # Basic model parameters as external flags. 34 | flags = tf.app.flags 35 | FLAGS = flags.FLAGS 36 | flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.') 37 | flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.') 38 | flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.') 39 | flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.') 40 | flags.DEFINE_integer('batch_size', 50, 'Batch size. ' 41 | 'Must divide evenly into the dataset sizes.') 42 | flags.DEFINE_string('train_dir', 'data/output', 'Directory to put the training data.') 43 | flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data ' 44 | 'for unit testing.') 45 | 46 | 47 | def placeholder_inputs(batch_size): 48 | """Generate placeholder variables to represent the input tensors. 49 | 50 | These placeholders are used as inputs by the rest of the model building 51 | code and will be fed from the downloaded data in the .run() loop, below. 52 | 53 | Args: 54 | batch_size: The batch size will be baked into both placeholders. 55 | 56 | Returns: 57 | images_placeholder: Images placeholder. 58 | labels_placeholder: Labels placeholder. 59 | """ 60 | # Note that the shapes of the placeholders match the shapes of the full 61 | # image and label tensors, except the first dimension is now batch_size 62 | # rather than the full size of the train or test data sets. 
63 | images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, 64 | model.IMAGE_PIXELS)) 65 | labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size)) 66 | return images_placeholder, labels_placeholder 67 | 68 | 69 | def fill_feed_dict(data_set, images_pl, labels_pl): 70 | """Fills the feed_dict for training the given step. 71 | 72 | A feed_dict takes the form of: 73 | feed_dict = { 74 | : , 75 | .... 76 | } 77 | 78 | Args: 79 | data_set: The set of images and labels, from input_data.read_data_sets() 80 | images_pl: The images placeholder, from placeholder_inputs(). 81 | labels_pl: The labels placeholder, from placeholder_inputs(). 82 | 83 | Returns: 84 | feed_dict: The feed dictionary mapping from placeholders to values. 85 | """ 86 | # Create the feed_dict for the placeholders filled with the next 87 | # `batch size ` examples. 88 | images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size, 89 | FLAGS.fake_data) 90 | 91 | feed_dict = { 92 | images_pl: images_feed, 93 | labels_pl: labels_feed, 94 | } 95 | return feed_dict 96 | 97 | 98 | def do_eval(sess, 99 | eval_correct, 100 | images_placeholder, 101 | labels_placeholder, 102 | data_set): 103 | """Runs one evaluation against the full epoch of data. 104 | 105 | Args: 106 | sess: The session in which the model has been trained. 107 | eval_correct: The Tensor that returns the number of correct predictions. 108 | images_placeholder: The images placeholder. 109 | labels_placeholder: The labels placeholder. 110 | data_set: The set of images and labels to evaluate, from 111 | input_data.read_data_sets(). 112 | """ 113 | # And run one epoch of eval. 114 | true_count = 0 # Counts the number of correct predictions. 
115 | steps_per_epoch = data_set.num_examples // FLAGS.batch_size 116 | num_examples = steps_per_epoch * FLAGS.batch_size 117 | for step in xrange(steps_per_epoch): 118 | feed_dict = fill_feed_dict(data_set, 119 | images_placeholder, 120 | labels_placeholder) 121 | true_count += sess.run(eval_correct, feed_dict=feed_dict) 122 | precision = true_count / num_examples 123 | print(' Num examples: %d Num correct: %d Precision @ 1: %0.04f' % 124 | (num_examples, true_count, precision)) 125 | 126 | 127 | def run_training(): 128 | """Train MNIST for a number of steps.""" 129 | # Get the sets of images and labels for training, validation, and 130 | # test on MNIST. 131 | data_sets = tf_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data) 132 | 133 | # Tell TensorFlow that the model will be built into the default Graph. 134 | with tf.Graph().as_default(): 135 | # Generate placeholders for the images and labels. 136 | images_placeholder, labels_placeholder = placeholder_inputs( 137 | FLAGS.batch_size) 138 | 139 | # Build a Graph that computes predictions from the inference model. 140 | logits = model.inference(images_placeholder, 141 | FLAGS.hidden1, 142 | FLAGS.hidden2) 143 | 144 | # Add to the Graph the Ops for loss calculation. 145 | loss = model.loss(logits, labels_placeholder) 146 | 147 | # Add to the Graph the Ops that calculate and apply gradients. 148 | train_op = model.training(loss, FLAGS.learning_rate) 149 | 150 | # Add the Op to compare the logits to the labels during evaluation. 151 | eval_correct = model.evaluation(logits, labels_placeholder) 152 | 153 | # Build the summary operation based on the TF collection of Summaries. 154 | summary_op = tf.merge_all_summaries() 155 | 156 | # Create a saver for writing training checkpoints. 157 | saver = tf.train.Saver() 158 | 159 | # Create a session for running Ops on the Graph. 160 | sess = tf.Session() 161 | 162 | # Run the Op to initialize the variables. 
163 | init = tf.initialize_all_variables() 164 | sess.run(init) 165 | 166 | # Instantiate a SummaryWriter to output summaries and the Graph. 167 | summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, 168 | graph_def=sess.graph_def) 169 | 170 | # And then after everything is built, start the training loop. 171 | for step in xrange(FLAGS.max_steps): 172 | start_time = time.time() 173 | 174 | # Fill a feed dictionary with the actual set of images and labels 175 | # for this particular training step. 176 | feed_dict = fill_feed_dict(data_sets.train, 177 | images_placeholder, 178 | labels_placeholder) 179 | 180 | # Run one step of the model. The return values are the activations 181 | # from the `train_op` (which is discarded) and the `loss` Op. To 182 | # inspect the values of your Ops or variables, you may include them 183 | # in the list passed to sess.run() and the value tensors will be 184 | # returned in the tuple from the call. 185 | _, loss_value = sess.run([train_op, loss], 186 | feed_dict=feed_dict) 187 | 188 | duration = time.time() - start_time 189 | 190 | # Write the summaries and print an overview fairly often. 191 | if step % 100 == 0: 192 | # Print status to stdout. 193 | print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration)) 194 | # Update the events file. 195 | summary_str = sess.run(summary_op, feed_dict=feed_dict) 196 | summary_writer.add_summary(summary_str, step) 197 | 198 | # Save a checkpoint and evaluate the model periodically. 199 | if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps: 200 | saver.save(sess, FLAGS.train_dir, global_step=step) 201 | # Evaluate against the training set. 202 | print('Training Data Eval:') 203 | do_eval(sess, 204 | eval_correct, 205 | images_placeholder, 206 | labels_placeholder, 207 | data_sets.train) 208 | # Evaluate against the validation set. 
209 | print('Validation Data Eval:') 210 | do_eval(sess, 211 | eval_correct, 212 | images_placeholder, 213 | labels_placeholder, 214 | data_sets.validation) 215 | # Evaluate against the test set. 216 | print('Test Data Eval:') 217 | do_eval(sess, 218 | eval_correct, 219 | images_placeholder, 220 | labels_placeholder, 221 | data_sets.test) 222 | 223 | 224 | def main(_): 225 | run_training() 226 | 227 | 228 | if __name__ == '__main__': 229 | tf.app.run() -------------------------------------------------------------------------------- /data/built_data_set: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os, sys, subprocess 4 | import tempfile 5 | import shutil 6 | import gzip, zlib 7 | import uuid, math 8 | import json 9 | from wand.image import Image, Color 10 | from pprint import pprint 11 | 12 | # Extend the wand image class to get pixel 13 | class DataImage(Image): 14 | @property 15 | def pixels(self): 16 | pixels = [] 17 | self.depth = 8 18 | blob = self.make_blob(format='RGB') 19 | for cursor in range(0, self.width * self.height * 3, 3): 20 | pixels.append(blob[cursor]) 21 | return bytearray(pixels) 22 | 23 | def convertForModel(self, imgWidth=-1, imgHeight=-1, extension="jpg"): 24 | if imgWidth == -1: imgWidth = self.width 25 | if imgHeight == -1: imgHeight = self.height 26 | self.auto_orient() 27 | ratio = float(self.width)/self.height 28 | target_ratio= float(imgWidth)/imgHeight 29 | 30 | old_width = self.width 31 | old_height = self.height 32 | if ratio > target_ratio: 33 | width = imgWidth 34 | height = int(old_height*imgWidth/old_width) 35 | else: 36 | width = int(old_width*imgHeight/old_height) 37 | height = imgHeight 38 | 39 | width = width+1 if (imgWidth-width)%2 == 1 else width 40 | height = height+1 if (imgHeight-height)%2 == 1 else height 41 | 42 | self.evaluate(operator='median', value=self.quantum_range*0.3) 43 | self.equalize() 44 | self.resize(width, height) 45 | 
self.format = extension 46 | self.strip() 47 | self.type = "grayscale" 48 | self.gravity = "center" 49 | self.border(Color("black"), (imgWidth-width)/2, (imgHeight-height)/2) 50 | 51 | def update_progress(progress): 52 | sys.stdout.write('\r[{0}{1}] {2}%'.format('#'*(progress/5), ' '* int(math.ceil((100.0-progress)/5)), progress)) 53 | sys.stdout.flush() 54 | 55 | def bytesFromInt(v): 56 | return bytearray([v >> i & 0xff for i in (24,16,8,0)]) 57 | 58 | def write(bytes, file): 59 | file.write(bytes) 60 | 61 | def getDirs(base): 62 | labels = [] 63 | count = 0 64 | 65 | for d in base: 66 | directory = d["directory"] 67 | label = d["label"] 68 | files = [] 69 | label_count = 0 70 | for i in os.listdir(directory): 71 | if i.endswith(".jpg") or i.endswith(".png") or i.endswith(".jpeg"): 72 | files.append("{}/{}".format(directory, i)) 73 | count += 1 74 | label_count += 1 75 | 76 | labels.append({ 77 | "count": label_count, 78 | "files": files, 79 | "label": label 80 | }) 81 | 82 | return { 83 | "count": count, 84 | "labels": labels 85 | } 86 | 87 | def buildTrainingDataSet(options): 88 | 89 | print "Packaging Images and Labels:" 90 | 91 | DUMP=options["dump_dir"] 92 | IMG_WIDTH=options["width"] 93 | IMG_HEIGHT=options["height"] 94 | IMG_EXTENSION=options["image_format"] 95 | DIRS=options["dirs"] 96 | IMAGE_OUTPUT=options["image_dest"] 97 | LABEL_OUTPUT=options["label_dest"] 98 | 99 | MAGIC_IMAGE=2051 100 | MAGIC_LABELS=2049 101 | 102 | if os.path.isdir(DUMP): shutil.rmtree(DUMP) 103 | os.makedirs(DUMP) 104 | 105 | if not os.path.exists(os.path.dirname(IMAGE_OUTPUT)): 106 | os.makedirs(os.path.dirname(IMAGE_OUTPUT)) 107 | 108 | if not os.path.exists(os.path.dirname(LABEL_OUTPUT)): 109 | os.makedirs(os.path.dirname(LABEL_OUTPUT)) 110 | 111 | dirs = getDirs(DIRS) 112 | 113 | image_file = open(IMAGE_OUTPUT, 'w') 114 | label_file = open(LABEL_OUTPUT, 'w') 115 | 116 | write(bytesFromInt(MAGIC_IMAGE), image_file) 117 | write(bytesFromInt(dirs["count"]), image_file) 118 | 
write(bytesFromInt(IMG_WIDTH), image_file) 119 | write(bytesFromInt(IMG_HEIGHT), image_file) 120 | 121 | write(bytesFromInt(MAGIC_LABELS), label_file) 122 | write(bytesFromInt(dirs["count"]), label_file) 123 | 124 | progress = 0 125 | for label in dirs["labels"]: 126 | for filename in label["files"]: 127 | uname = uuid.uuid4() 128 | with DataImage(filename=filename) as img: 129 | img.convertForModel(imgWidth=IMG_WIDTH, imgHeight=IMG_HEIGHT, extension=IMG_EXTENSION) 130 | img.save(filename="{}/{}.{}".format(DUMP, uname, IMG_EXTENSION)) 131 | 132 | progress += 1 133 | update_progress(int(100 * progress / dirs["count"])) 134 | pixels = img.pixels 135 | 136 | write(pixels, image_file) 137 | write(bytearray([label["label"]] * label["count"]), label_file) 138 | 139 | image_file.close() 140 | label_file.close() 141 | 142 | print 143 | 144 | image_size = os.path.getsize(IMAGE_OUTPUT) 145 | image_expected_size = dirs["count"] * (IMG_WIDTH * IMG_HEIGHT) + 4*4 146 | label_size = os.path.getsize(LABEL_OUTPUT) 147 | label_expected_size = dirs["count"] + 2*4 148 | 149 | if image_size != image_expected_size: 150 | print "Expected exported image file size to be {} bytes but found {} bytes".format(image_expected_size, image_size) 151 | elif label_size != label_expected_size: 152 | print "Expected exported label file size to be {} bytes but found {} bytes".format(label_expected_size, label_size) 153 | else: 154 | subprocess.check_call(["gzip", IMAGE_OUTPUT]) 155 | subprocess.check_call(["gzip", LABEL_OUTPUT]) 156 | print "Successfully built images and labels." 
157 | 158 | 159 | def main(): 160 | with open(sys.argv[1]) as data_file: 161 | data = json.load(data_file) 162 | 163 | buildTrainingDataSet(options=data) 164 | 165 | if __name__ == '__main__': 166 | main() 167 | 168 | -------------------------------------------------------------------------------- /data/input/test/negative/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ryanwebber/tensorflow-image-classification/b1bee24f920457646929b06db663d53639638adf/data/input/test/negative/.gitkeep -------------------------------------------------------------------------------- /data/input/test/positive/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ryanwebber/tensorflow-image-classification/b1bee24f920457646929b06db663d53639638adf/data/input/test/positive/.gitkeep -------------------------------------------------------------------------------- /data/input/train/negative/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ryanwebber/tensorflow-image-classification/b1bee24f920457646929b06db663d53639638adf/data/input/train/negative/.gitkeep -------------------------------------------------------------------------------- /data/input/train/positive/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ryanwebber/tensorflow-image-classification/b1bee24f920457646929b06db663d53639638adf/data/input/train/positive/.gitkeep -------------------------------------------------------------------------------- /data/testset.json: -------------------------------------------------------------------------------- 1 | { 2 | "image_format": "jpg", 3 | "width": 240, 4 | "height": 240, 5 | "dump_dir": "output/test_images", 6 | "image_dest": "output/TFtest.images", 7 | "label_dest": "output/TFtest.labels", 8 | 
"dirs": [{ 9 | "directory": "input/test/negative", 10 | "label": 0 11 | }, { 12 | "directory": "input/test/positive", 13 | "label": 1 14 | }] 15 | } -------------------------------------------------------------------------------- /data/trainingset.json: -------------------------------------------------------------------------------- 1 | { 2 | "image_format": "jpg", 3 | "width": 240, 4 | "height": 240, 5 | "dump_dir": "output/train_images", 6 | "image_dest": "output/TFtrain.images", 7 | "label_dest": "output/TFtrain.labels", 8 | "dirs": [{ 9 | "directory": "input/train/negative", 10 | "label": 0 11 | }, { 12 | "directory": "input/train/positive", 13 | "label": 1 14 | }] 15 | } -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Builds the MNIST network. 17 | 18 | Implements the inference/loss/training pattern for model building. 19 | 20 | 1. inference() - Builds the model as far as is required for running the network 21 | forward to make predictions. 22 | 2. loss() - Adds to the inference model the layers required to generate loss. 23 | 3. 
training() - Adds to the loss model the Ops required to generate and 24 | apply gradients. 25 | 26 | This file is used by the various "fully_connected_*.py" files and not meant to 27 | be run. 28 | """ 29 | from __future__ import absolute_import 30 | from __future__ import division 31 | from __future__ import print_function 32 | 33 | import math 34 | 35 | import tensorflow as tf 36 | 37 | # The MNIST dataset has 10 classes, representing the digits 0 through 9. 38 | NUM_CLASSES = 2 39 | 40 | # The MNIST images are always 28x28 pixels. 41 | IMAGE_SIZE = 240 42 | IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE 43 | 44 | 45 | def inference(images, hidden1_units, hidden2_units): 46 | """Build the MNIST model up to where it may be used for inference. 47 | 48 | Args: 49 | images: Images placeholder, from inputs(). 50 | hidden1_units: Size of the first hidden layer. 51 | hidden2_units: Size of the second hidden layer. 52 | 53 | Returns: 54 | softmax_linear: Output tensor with the computed logits. 55 | """ 56 | # Hidden 1 57 | with tf.name_scope('hidden1'): 58 | weights = tf.Variable( 59 | tf.truncated_normal([IMAGE_PIXELS, hidden1_units], 60 | stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))), 61 | name='weights') 62 | biases = tf.Variable(tf.zeros([hidden1_units]), 63 | name='biases') 64 | hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases) 65 | # Hidden 2 66 | with tf.name_scope('hidden2'): 67 | weights = tf.Variable( 68 | tf.truncated_normal([hidden1_units, hidden2_units], 69 | stddev=1.0 / math.sqrt(float(hidden1_units))), 70 | name='weights') 71 | biases = tf.Variable(tf.zeros([hidden2_units]), 72 | name='biases') 73 | hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases) 74 | # Linear 75 | with tf.name_scope('softmax_linear'): 76 | weights = tf.Variable( 77 | tf.truncated_normal([hidden2_units, NUM_CLASSES], 78 | stddev=1.0 / math.sqrt(float(hidden2_units))), 79 | name='weights') 80 | biases = tf.Variable(tf.zeros([NUM_CLASSES]), 81 | name='biases') 82 | logits = 
tf.matmul(hidden2, weights) + biases 83 | 84 | return logits 85 | 86 | 87 | def loss(logits, labels): 88 | """Calculates the loss from the logits and the labels. 89 | 90 | Args: 91 | logits: Logits tensor, float - [batch_size, NUM_CLASSES]. 92 | labels: Labels tensor, int32 - [batch_size]. 93 | 94 | Returns: 95 | loss: Loss tensor of type float. 96 | """ 97 | 98 | labels = tf.to_int64(labels) 99 | 100 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name='xentropy') 101 | loss = tf.reduce_mean(cross_entropy, name='xentropy_mean') 102 | return loss 103 | 104 | 105 | def training(loss, learning_rate): 106 | """Sets up the training Ops. 107 | 108 | Creates a summarizer to track the loss over time in TensorBoard. 109 | 110 | Creates an optimizer and applies the gradients to all trainable variables. 111 | 112 | The Op returned by this function is what must be passed to the 113 | `sess.run()` call to cause the model to train. 114 | 115 | Args: 116 | loss: Loss tensor, from loss(). 117 | learning_rate: The learning rate to use for gradient descent. 118 | 119 | Returns: 120 | train_op: The Op for training. 121 | """ 122 | # Add a scalar summary for the snapshot loss. 123 | tf.scalar_summary(loss.op.name, loss) 124 | # Create the gradient descent optimizer with the given learning rate. 125 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 126 | # Create a variable to track the global step. 127 | global_step = tf.Variable(0, name='global_step', trainable=False) 128 | # Use the optimizer to apply the gradients that minimize the loss 129 | # (and also increment the global step counter) as a single training step. 130 | train_op = optimizer.minimize(loss, global_step=global_step) 131 | return train_op 132 | 133 | 134 | def evaluation(logits, labels): 135 | """Evaluate the quality of the logits at predicting the label. 136 | 137 | Args: 138 | logits: Logits tensor, float - [batch_size, NUM_CLASSES]. 
139 | labels: Labels tensor, int32 - [batch_size], with values in the 140 | range [0, NUM_CLASSES). 141 | 142 | Returns: 143 | A scalar int32 tensor with the number of examples (out of batch_size) 144 | that were predicted correctly. 145 | """ 146 | # For a classifier model, we can use the in_top_k Op. 147 | # It returns a bool tensor with shape [batch_size] that is true for 148 | # the examples where the label is in the top k (here k=1) 149 | # of all logits for that example. 150 | correct = tf.nn.in_top_k(logits, labels, 1) 151 | # Return the number of true entries. 152 | return tf.reduce_sum(tf.cast(correct, tf.int32)) -------------------------------------------------------------------------------- /tf_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Functions for downloading and reading MNIST data.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from gzip import GzipFile 22 | import gzip 23 | import os 24 | import tempfile 25 | 26 | import numpy 27 | from six.moves import urllib 28 | from six.moves import xrange # pylint: disable=redefined-builtin 29 | import tensorflow as tf 30 | 31 | 32 | def _read32(bytestream): 33 | dt = numpy.dtype(numpy.uint32).newbyteorder('>') 34 | return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] 35 | 36 | 37 | def extract_images(filename): 38 | """Extract the images into a 4D uint8 numpy array [index, y, x, depth].""" 39 | print('Extracting', filename) 40 | with open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream: 41 | magic = _read32(bytestream) 42 | if magic != 2051: 43 | raise ValueError( 44 | 'Invalid magic number %d in MNIST image file: %s' % 45 | (magic, filename)) 46 | num_images = _read32(bytestream) 47 | rows = _read32(bytestream) 48 | cols = _read32(bytestream) 49 | buf = bytestream.read(rows * cols * num_images) 50 | data = numpy.frombuffer(buf, dtype=numpy.uint8) 51 | data = data.reshape(num_images, rows, cols, 1) 52 | return data 53 | 54 | 55 | def dense_to_one_hot(labels_dense, num_classes): 56 | """Convert class labels from scalars to one-hot vectors.""" 57 | num_labels = labels_dense.shape[0] 58 | index_offset = numpy.arange(num_labels) * num_classes 59 | labels_one_hot = numpy.zeros((num_labels, num_classes)) 60 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 61 | return labels_one_hot 62 | 63 | 64 | def extract_labels(filename, one_hot=False, num_classes=10): 65 | """Extract the labels into a 1D uint8 numpy array [index].""" 66 | print('Extracting', filename) 67 | with open(filename, 'rb') as f, gzip.GzipFile(fileobj=f) as bytestream: 68 | magic 
= _read32(bytestream) 69 | if magic != 2049: 70 | raise ValueError( 71 | 'Invalid magic number %d in MNIST label file: %s' % 72 | (magic, filename)) 73 | num_items = _read32(bytestream) 74 | buf = bytestream.read(num_items) 75 | labels = numpy.frombuffer(buf, dtype=numpy.uint8) 76 | if one_hot: 77 | return dense_to_one_hot(labels, num_classes) 78 | return labels 79 | 80 | 81 | class DataSet(object): 82 | 83 | def __init__(self, images, labels, fake_data=False, one_hot=False, 84 | dtype=tf.float32): 85 | """Construct a DataSet. 86 | one_hot arg is used only if fake_data is true. `dtype` can be either 87 | `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into 88 | `[0, 1]`. 89 | """ 90 | dtype = tf.as_dtype(dtype).base_dtype 91 | if dtype not in (tf.uint8, tf.float32): 92 | raise TypeError('Invalid image dtype %r, expected uint8 or float32' % 93 | dtype) 94 | if fake_data: 95 | self._num_examples = 10000 96 | self.one_hot = one_hot 97 | else: 98 | assert images.shape[0] == labels.shape[0], ( 99 | 'images.shape: %s labels.shape: %s' % (images.shape, 100 | labels.shape)) 101 | self._num_examples = images.shape[0] 102 | 103 | # Convert shape from [num examples, rows, columns, depth] 104 | # to [num examples, rows*columns] (assuming depth == 1) 105 | assert images.shape[3] == 1 106 | images = images.reshape(images.shape[0], 107 | images.shape[1] * images.shape[2]) 108 | if dtype == tf.float32: 109 | # Convert from [0, 255] -> [0.0, 1.0]. 
110 | images = images.astype(numpy.float32) 111 | images = numpy.multiply(images, 1.0 / 255.0) 112 | 113 | self._images = images 114 | self._labels = labels 115 | self._epochs_completed = 0 116 | self._index_in_epoch = 0 117 | 118 | @property 119 | def images(self): 120 | return self._images 121 | 122 | @property 123 | def labels(self): 124 | return self._labels 125 | 126 | @property 127 | def num_examples(self): 128 | return self._num_examples 129 | 130 | @property 131 | def epochs_completed(self): 132 | return self._epochs_completed 133 | 134 | def next_batch(self, batch_size, fake_data=False): 135 | """Return the next `batch_size` examples from this data set.""" 136 | if fake_data: 137 | fake_image = [1] * 784 138 | if self.one_hot: 139 | fake_label = [1] + [0] * 9 140 | else: 141 | fake_label = 0 142 | return [fake_image for _ in xrange(batch_size)], [ 143 | fake_label for _ in xrange(batch_size)] 144 | start = self._index_in_epoch 145 | self._index_in_epoch += batch_size 146 | if self._index_in_epoch > self._num_examples: 147 | # Finished epoch 148 | self._epochs_completed += 1 149 | # Shuffle the data 150 | perm = numpy.arange(self._num_examples) 151 | numpy.random.shuffle(perm) 152 | self._images = self._images[perm] 153 | self._labels = self._labels[perm] 154 | # Start next epoch 155 | start = 0 156 | self._index_in_epoch = batch_size 157 | assert batch_size <= self._num_examples 158 | end = self._index_in_epoch 159 | return self._images[start:end], self._labels[start:end] 160 | 161 | def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32): 162 | class DataSets(object): 163 | pass 164 | data_sets = DataSets() 165 | 166 | if fake_data: 167 | def fake(): 168 | return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype) 169 | data_sets.train = fake() 170 | data_sets.validation = fake() 171 | data_sets.test = fake() 172 | return data_sets 173 | 174 | TRAIN_IMAGES = 'TFtrain.images.gz' 175 | TRAIN_LABELS = 'TFtrain.labels.gz' 176 
| TEST_IMAGES = 'TFtest.images.gz' 177 | TEST_LABELS = 'TFtest.labels.gz' 178 | VALIDATION_SIZE = 100 179 | 180 | local_file = os.path.join(train_dir, TRAIN_IMAGES) 181 | train_images = extract_images(local_file) 182 | 183 | local_file = os.path.join(train_dir, TRAIN_LABELS) 184 | train_labels = extract_labels(local_file, one_hot=one_hot, num_classes=2) 185 | 186 | local_file = os.path.join(train_dir, TEST_IMAGES) 187 | test_images = extract_images(local_file) 188 | 189 | local_file = os.path.join(train_dir, TEST_LABELS) 190 | test_labels = extract_labels(local_file, one_hot=one_hot, num_classes=2) 191 | 192 | validation_images = train_images[:VALIDATION_SIZE] 193 | validation_labels = train_labels[:VALIDATION_SIZE] 194 | train_images = train_images[VALIDATION_SIZE:] 195 | train_labels = train_labels[VALIDATION_SIZE:] 196 | 197 | data_sets.train = DataSet(train_images, train_labels, dtype=dtype) 198 | data_sets.validation = DataSet(validation_images, validation_labels, 199 | dtype=dtype) 200 | data_sets.test = DataSet(test_images, test_labels, dtype=dtype) 201 | 202 | return data_sets 203 | 204 | --------------------------------------------------------------------------------