├── BUILD
├── README.md
├── __init__.py
├── cifar10.py
├── cifar10_eval.py
├── cifar10_input.py
├── cifar10_input_test.py
├── cifar10_multi_gpu_train.py
├── cifar10_train.py
└── compat.py

/BUILD:
--------------------------------------------------------------------------------
# Description:
# Example TensorFlow models for CIFAR-10
licenses(["notice"])  # Apache 2.0
exports_files(["LICENSE"])
py_library(
    name = "cifar10_input",
    srcs = ["cifar10_input.py"],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:internal"],
    deps = [
        "//tensorflow:tensorflow_py",
    ],
)
py_test(
    name = "cifar10_input_test",
    srcs = ["cifar10_input_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":cifar10_input",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)
py_library(
    name = "cifar10",
    srcs = ["cifar10.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":cifar10_input",
        "//tensorflow:tensorflow_py",
    ],
)
py_binary(
    name = "cifar10_eval",
    srcs = [
        "cifar10_eval.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        ":cifar10",
    ],
)
py_binary(
    name = "cifar10_train",
    srcs = [
        "cifar10_train.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        ":cifar10",
    ],
)
py_binary(
    name = "cifar10_multi_gpu_train",
    srcs = [
        "cifar10_multi_gpu_train.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        ":cifar10",
    ],
)
filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# tensorflow.cifar10
Examples of image recognition on the CIFAR-10 dataset with TensorFlow.

### **1 The CIFAR-10 dataset**
CIFAR-10 is a standard benchmark dataset for image recognition in machine learning. Official site: [The CIFAR-10 dataset](http://www.cs.toronto.edu/~kriz/cifar.html)

![cifar10](https://img-blog.csdn.net/20160226153743929)

The version downloaded and used here is:

![version](https://img-blog.csdn.net/20160226160254186)

After extraction (the code extracts the archive automatically), the contents are:

![cifar10 data](https://img-blog.csdn.net/20160226160343415)

![cifar10 data2](https://img-blog.csdn.net/20160226160405884)

### **2 Test code**
The test code is published on GitHub: [yhlleo](https://github.com/yhlleo/tensorflow.cifar10)

Main files and their purpose:

|File|Purpose|
|:--------------|:--------------|
|`cifar10_input.py`|Reads the CIFAR-10 dataset in its binary file format, from local files or after downloading it|
|`cifar10.py`|Builds the CIFAR-10 model|
|`cifar10_train.py`|Trains the CIFAR-10 model on a CPU or GPU|
|`cifar10_multi_gpu_train.py`|Trains the CIFAR-10 model on multiple GPUs|
|`cifar10_eval.py`|Evaluates the predictive performance of the CIFAR-10 model|


This code shows how to use TensorFlow to train and evaluate a convolutional neural network (CNN) on CPUs and GPUs.

### **3 Related pages and tutorials**
For a more detailed description, see: [Convolutional Neural Networks](http://tensorflow.org/tutorials/deep_cnn/)

The Chinese site Jikexueyuan also hosts a Chinese translation of this tutorial: [卷积神经网络 (Convolutional Neural Networks)](http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/deep_cnn.html)

The code comes from the official TensorFlow repository: [tensorflow/models/image/cifar10](https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/image/cifar10)

### **4 Notes on code changes**
Compared with the original source, the code published on GitHub makes the following main changes:

- **`cifar10.py`**

```python
#indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
indices = tf.reshape(range(FLAGS.batch_size), [FLAGS.batch_size, 1])

# or
indices = tf.reshape(tf.range(0, FLAGS.batch_size, 1), [FLAGS.batch_size, 1])
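
# Illustrative sketch only (an assumption, not part of the original
# repository): a pure-Python analogue of how loss() pairs these row indices
# with the labels to build one-hot targets, which is why `indices` needs the
# shape [batch_size, 1]. The helper name `one_hot_from_indices` is
# hypothetical.
def one_hot_from_indices(labels, num_classes):
  batch_size = len(labels)
  # Column of row numbers, mirroring the reshaped `indices` above.
  indices = [[i] for i in range(batch_size)]
  # Column of class ids, mirroring `sparse_labels` in loss().
  sparse_labels = [[label] for label in labels]
  # Each entry is a [row, class] coordinate, like the concatenated tensor.
  concated = [row + col for row, col in zip(indices, sparse_labels)]
  # Write 1.0 at every coordinate and leave 0.0 elsewhere.
  dense = [[0.0] * num_classes for _ in range(batch_size)]
  for row, col in concated:
    dense[row][col] = 1.0
  return dense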
```

Here, running the original code raises the following error:

```python
...
File ".../cifar10.py", line 271, in loss
indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
TypeError: range() takes at least 2 arguments (1 given)
```

- **`cifar10_input_test.py`**

```python
#self.assertEqual("%s:%d" % (filename, i), tf.compat.as_text(key))

import compat as cp
...

self.assertEqual("%s:%d" % (filename, i), cp.as_text(key))
```

Otherwise, my tests fail with this error:

```python
AttributeError: 'module' object has no attribute 'compat'
```

- **`cifar10_train.py`** and **`cifar10_multi_gpu_train.py`**

In the original code the maximum number of iterations, `max_steps`, is `1000000`, which takes several hours of training. I did not have the heart to put my battered laptop through that, so I changed it to `20000`.

The other changes, such as module imports and file paths, are easy to follow, so they are not listed here.

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/cifar10.py:
--------------------------------------------------------------------------------
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Builds the CIFAR-10 network.

Summary of available functions:

 # Compute input images and labels for training. If you would like to run
 # evaluations, use inputs() instead.
 inputs, labels = distorted_inputs()

 # Compute inference on the model inputs to make a prediction.
 predictions = inference(inputs)

 # Compute the total loss of the prediction with respect to the labels.
 loss = loss(predictions, labels)

 # Create a graph to run one step of training with respect to the loss.
 train_op = train(loss, global_step)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import re
import sys
import tarfile

import tensorflow.python.platform
from six.moves import urllib
import tensorflow as tf

#from tensorflow.models.image.cifar10 import cifar10_input
import cifar10_input

FLAGS = tf.app.flags.FLAGS

# Basic model parameters.
tf.app.flags.DEFINE_integer('batch_size', 128,
                            """Number of images to process in a batch.""")
tf.app.flags.DEFINE_string('data_dir', 'cifar10_data/',
                           """Path to the CIFAR-10 data directory.""")

# Global constants describing the CIFAR-10 data set.
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
NUM_CLASSES = cifar10_input.NUM_CLASSES
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL


# Constants describing the training process.
MOVING_AVERAGE_DECAY = 0.9999     # The decay to use for the moving average.
NUM_EPOCHS_PER_DECAY = 350.0      # Epochs after which learning rate decays.
LEARNING_RATE_DECAY_FACTOR = 0.1  # Learning rate decay factor.
INITIAL_LEARNING_RATE = 0.1       # Initial learning rate.
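

# Illustrative sketch only (an assumption, not part of the original file): how
# the constants above combine into the staircase learning-rate schedule that
# train() builds with tf.train.exponential_decay. The helper name
# `_staircase_lr_sketch` and its default arguments are hypothetical; the
# defaults simply repeat the values used in this module.
def _staircase_lr_sketch(step, batch_size=128, examples_per_epoch=50000,
                         epochs_per_decay=350.0, decay_factor=0.1,
                         initial_lr=0.1):
  num_batches_per_epoch = examples_per_epoch / batch_size
  decay_steps = int(num_batches_per_epoch * epochs_per_decay)
  # With staircase=True the exponent only changes every `decay_steps` steps.
  return initial_lr * decay_factor ** (step // decay_steps)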

# If a model is trained with multiple GPUs, prefix all Op names with tower_name
# to differentiate the operations. Note that this prefix is removed from the
# names of the summaries when visualizing a model.
TOWER_NAME = 'tower'

DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'


def _activation_summary(x):
  """Helper to create summaries for activations.

  Creates a summary that provides a histogram of activations.
  Creates a summary that measures the sparsity of activations.

  Args:
    x: Tensor
  Returns:
    nothing
  """
  # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
  # session. This helps the clarity of presentation on tensorboard.
  tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
  tf.histogram_summary(tensor_name + '/activations', x)
  tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x))


def _variable_on_cpu(name, shape, initializer):
  """Helper to create a Variable stored on CPU memory.

  Args:
    name: name of the variable
    shape: list of ints
    initializer: initializer for Variable

  Returns:
    Variable Tensor
  """
  with tf.device('/cpu:0'):
    var = tf.get_variable(name, shape, initializer=initializer)
  return var


def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.

  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.

  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.

  Returns:
    Variable Tensor
  """
  var = _variable_on_cpu(name, shape,
                         tf.truncated_normal_initializer(stddev=stddev))
  if wd:
    weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
  return var


def distorted_inputs():
  """Construct distorted input for CIFAR training using the Reader ops.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.

  Raises:
    ValueError: If no data_dir
  """
  if not FLAGS.data_dir:
    raise ValueError('Please supply a data_dir')
  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
  return cifar10_input.distorted_inputs(data_dir=data_dir,
                                        batch_size=FLAGS.batch_size)


def inputs(eval_data):
  """Construct input for CIFAR evaluation using the Reader ops.

  Args:
    eval_data: bool, indicating if one should use the train or eval data set.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.

  Raises:
    ValueError: If no data_dir
  """
  if not FLAGS.data_dir:
    raise ValueError('Please supply a data_dir')
  data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
  return cifar10_input.inputs(eval_data=eval_data, data_dir=data_dir,
                              batch_size=FLAGS.batch_size)


def inference(images):
  """Build the CIFAR-10 model.

  Args:
    images: Images returned from distorted_inputs() or inputs().

  Returns:
    Logits.
  """
  # We instantiate all variables using tf.get_variable() instead of
  # tf.Variable() in order to share variables across multiple GPU training runs.
  # If we only ran this model on a single GPU, we could simplify this function
  # by replacing all instances of tf.get_variable() with tf.Variable().
  #
  # conv1
  with tf.variable_scope('conv1') as scope:
    kernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64],
                                         stddev=1e-4, wd=0.0)
    conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv1 = tf.nn.relu(bias, name=scope.name)
    _activation_summary(conv1)

  # pool1
  pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                         padding='SAME', name='pool1')
  # norm1
  norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                    name='norm1')

  # conv2
  with tf.variable_scope('conv2') as scope:
    kernel = _variable_with_weight_decay('weights', shape=[5, 5, 64, 64],
                                         stddev=1e-4, wd=0.0)
    conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
    bias = tf.nn.bias_add(conv, biases)
    conv2 = tf.nn.relu(bias, name=scope.name)
    _activation_summary(conv2)

  # norm2
  norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                    name='norm2')
  # pool2
  pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
                         strides=[1, 2, 2, 1], padding='SAME', name='pool2')

  # local3
  with tf.variable_scope('local3') as scope:
    # Move everything into depth so we can perform a single matrix multiply.
    dim = 1
    for d in pool2.get_shape()[1:].as_list():
      dim *= d
    reshape = tf.reshape(pool2, [FLAGS.batch_size, dim])

    weights = _variable_with_weight_decay('weights', shape=[dim, 384],
                                          stddev=0.04, wd=0.004)
    biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
    local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
    _activation_summary(local3)

  # local4
  with tf.variable_scope('local4') as scope:
    weights = _variable_with_weight_decay('weights', shape=[384, 192],
                                          stddev=0.04, wd=0.004)
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
    local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name=scope.name)
    _activation_summary(local4)

  # softmax, i.e. softmax(WX + b)
  with tf.variable_scope('softmax_linear') as scope:
    weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES],
                                          stddev=1/192.0, wd=0.0)
    biases = _variable_on_cpu('biases', [NUM_CLASSES],
                              tf.constant_initializer(0.0))
    softmax_linear = tf.add(tf.matmul(local4, weights), biases, name=scope.name)
    _activation_summary(softmax_linear)

  return softmax_linear


def loss(logits, labels):
  """Add L2Loss to all the trainable variables.

  Add summary for "Loss" and "Loss/avg".
  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs or inputs(). 1-D tensor
            of shape [batch_size]

  Returns:
    Loss tensor of type float.
  """
  # Reshape the labels into a dense Tensor of
  # shape [batch_size, NUM_CLASSES].
  sparse_labels = tf.reshape(labels, [FLAGS.batch_size, 1])
  # indices = tf.reshape(tf.range(FLAGS.batch_size), [FLAGS.batch_size, 1])
  indices = tf.reshape(range(FLAGS.batch_size), [FLAGS.batch_size, 1])
  concated = tf.concat(1, [indices, sparse_labels])
  dense_labels = tf.sparse_to_dense(concated,
                                    [FLAGS.batch_size, NUM_CLASSES],
                                    1.0, 0.0)

  # Calculate the average cross entropy loss across the batch.
  cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
      logits, dense_labels, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)

  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  return tf.add_n(tf.get_collection('losses'), name='total_loss')


def _add_loss_summaries(total_loss):
  """Add summaries for losses in CIFAR-10 model.

  Generates moving average for all losses and associated summaries for
  visualizing the performance of the network.

  Args:
    total_loss: Total loss from loss().
  Returns:
    loss_averages_op: op for generating moving averages of losses.
  """
  # Compute the moving average of all individual losses and the total loss.
  loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
  losses = tf.get_collection('losses')
  loss_averages_op = loss_averages.apply(losses + [total_loss])

  # Attach a scalar summary to all individual losses and the total loss; do the
  # same for the averaged version of the losses.
  for l in losses + [total_loss]:
    # Name each loss as '(raw)' and name the moving average version of the loss
    # as the original loss name.
    tf.scalar_summary(l.op.name + ' (raw)', l)
    tf.scalar_summary(l.op.name, loss_averages.average(l))

  return loss_averages_op


def train(total_loss, global_step):
  """Train CIFAR-10 model.

  Create an optimizer and apply to all trainable variables. Add moving
  average for all trainable variables.

  Args:
    total_loss: Total loss from loss().
    global_step: Integer Variable counting the number of training steps
      processed.
  Returns:
    train_op: op for training.
  """
  # Variables that affect learning rate.
  num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
  decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

  # Decay the learning rate exponentially based on the number of steps.
  lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                  global_step,
                                  decay_steps,
                                  LEARNING_RATE_DECAY_FACTOR,
                                  staircase=True)
  tf.scalar_summary('learning_rate', lr)

  # Generate moving averages of all losses and associated summaries.
  loss_averages_op = _add_loss_summaries(total_loss)

  # Compute gradients.
  with tf.control_dependencies([loss_averages_op]):
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(total_loss)

  # Apply gradients.
  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

  # Add histograms for trainable variables.
  for var in tf.trainable_variables():
    tf.histogram_summary(var.op.name, var)

  # Add histograms for gradients. Compare against None explicitly rather than
  # relying on the truth value of a Tensor.
  for grad, var in grads:
    if grad is not None:
      tf.histogram_summary(var.op.name + '/gradients', grad)

  # Track the moving averages of all trainable variables.
  variable_averages = tf.train.ExponentialMovingAverage(
      MOVING_AVERAGE_DECAY, global_step)
  variables_averages_op = variable_averages.apply(tf.trainable_variables())

  with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
    train_op = tf.no_op(name='train')

  return train_op


def maybe_download_and_extract():
  """Download and extract the tarball from Alex's website."""
  dest_directory = FLAGS.data_dir
  if not os.path.exists(dest_directory):
    os.mkdir(dest_directory)
  filename = DATA_URL.split('/')[-1]
  filepath = os.path.join(dest_directory, filename)
  if not os.path.exists(filepath):
    def _progress(count, block_size, total_size):
      sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
          float(count * block_size) / float(total_size) * 100.0))
      sys.stdout.flush()
    filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath,
                                             reporthook=_progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    tarfile.open(filepath, 'r:gz').extractall(dest_directory)

--------------------------------------------------------------------------------
/cifar10_eval.py:
--------------------------------------------------------------------------------
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Evaluation for CIFAR-10.

Accuracy:
cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs
of data) as judged by cifar10_eval.py.

Speed:
On a single Tesla K40, cifar10_train.py processes a single batch of 128 images
in 0.25-0.35 sec (i.e. 350 - 600 images /sec). The model reaches ~86%
accuracy after 100K steps in 8 hours of training time.

Usage:
Please see the tutorial and website for how to download the CIFAR-10
data set, compile the program and train the model.

http://tensorflow.org/tutorials/deep_cnn/
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import math
import time

import tensorflow.python.platform
from tensorflow.python.platform import gfile
import numpy as np
import tensorflow as tf

#from tensorflow.models.image.cifar10 import cifar10
import cifar10

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('eval_dir', 'cifar10_eval/',
                           """Directory where to write event logs.""")
tf.app.flags.DEFINE_string('eval_data', 'test',
                           """Either 'test' or 'train_eval'.""")
tf.app.flags.DEFINE_string('checkpoint_dir', 'cifar10_train/',
                           """Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5,
                            """How often to run the eval.""")
tf.app.flags.DEFINE_integer('num_examples', 10000,
                            """Number of examples to run.""")
tf.app.flags.DEFINE_boolean('run_once', False,
                            """Whether to run eval only once.""")


def eval_once(saver, summary_writer, top_k_op, summary_op):
  """Run
  Eval once.

  Args:
    saver: Saver.
    summary_writer: Summary writer.
    top_k_op: Top K op.
    summary_op: Summary op.
  """
  with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
      # Restores from checkpoint
      saver.restore(sess, ckpt.model_checkpoint_path)
      # Assuming model_checkpoint_path looks something like:
      #   /my-favorite-path/cifar10_train/model.ckpt-0,
      # extract global_step from it.
      global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    else:
      print('No checkpoint file found')
      return

    # Start the queue runners.
    coord = tf.train.Coordinator()
    try:
      threads = []
      for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
        threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                         start=True))

      num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
      true_count = 0  # Counts the number of correct predictions.
      total_sample_count = num_iter * FLAGS.batch_size
      step = 0
      while step < num_iter and not coord.should_stop():
        predictions = sess.run([top_k_op])
        true_count += np.sum(predictions)
        step += 1

      # Compute precision @ 1.
      precision = true_count / total_sample_count
      print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))

      summary = tf.Summary()
      summary.ParseFromString(sess.run(summary_op))
      summary.value.add(tag='Precision @ 1', simple_value=precision)
      summary_writer.add_summary(summary, global_step)
    except Exception as e:  # pylint: disable=broad-except
      coord.request_stop(e)

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)


def evaluate():
  """Eval CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    # Get images and labels for CIFAR-10.
    eval_data = FLAGS.eval_data == 'test'
    images, labels = cifar10.inputs(eval_data=eval_data)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate predictions.
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    graph_def = tf.get_default_graph().as_graph_def()
    summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir,
                                            graph_def=graph_def)

    while True:
      eval_once(saver, summary_writer, top_k_op, summary_op)
      if FLAGS.run_once:
        break
      time.sleep(FLAGS.eval_interval_secs)


def main(argv=None):  # pylint: disable=unused-argument
  cifar10.maybe_download_and_extract()
  if gfile.Exists(FLAGS.eval_dir):
    gfile.DeleteRecursively(FLAGS.eval_dir)
  gfile.MakeDirs(FLAGS.eval_dir)
  evaluate()


if __name__ == '__main__':
  tf.app.run()

--------------------------------------------------------------------------------
/cifar10_input.py:
--------------------------------------------------------------------------------
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Routine for decoding the CIFAR-10 binary file format."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow.python.platform
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.python.platform import gfile

# Process images of this size. Note that this differs from the original CIFAR
# image size of 32 x 32. If one alters this number, then the entire model
# architecture will change and any model would need to be retrained.
IMAGE_SIZE = 24

# Global constants describing the CIFAR-10 data set.
NUM_CLASSES = 10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000


def read_cifar10(filename_queue):
  """Reads and parses examples from CIFAR10 data files.

  Recommendation: if you want N-way read parallelism, call this function
  N times. This will give you N independent Readers reading different
  files & positions within those files, which will give better mixing of
  examples.

  Args:
    filename_queue: A queue of strings with the filenames to read from.

  Returns:
    An object representing a single example, with the following fields:
      height: number of rows in the result (32)
      width: number of columns in the result (32)
      depth: number of color channels in the result (3)
      key: a scalar string Tensor describing the filename & record number
        for this example.
      label: an int32 Tensor with the label in the range 0..9.
      uint8image: a [height, width, depth] uint8 Tensor with the image data
  """

  class CIFAR10Record(object):
    pass
  result = CIFAR10Record()

  # Dimensions of the images in the CIFAR-10 dataset.
  # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
  # input format.
  label_bytes = 1  # 2 for CIFAR-100
  result.height = 32
  result.width = 32
  result.depth = 3
  image_bytes = result.height * result.width * result.depth
  # Every record consists of a label followed by the image, with a
  # fixed number of bytes for each.
  record_bytes = label_bytes + image_bytes

  # Read a record, getting filenames from the filename_queue. No
  # header or footer in the CIFAR-10 format, so we leave header_bytes
  # and footer_bytes at their default of 0.
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  result.key, value = reader.read(filename_queue)

  # Convert from a string to a vector of uint8 that is record_bytes long.
  record_bytes = tf.decode_raw(value, tf.uint8)

  # The first bytes represent the label, which we convert from uint8->int32.
  result.label = tf.cast(
      tf.slice(record_bytes, [0], [label_bytes]), tf.int32)

  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]),
                           [result.depth, result.height, result.width])
  # Convert from [depth, height, width] to [height, width, depth].
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])

  return result


def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size):
  """Construct a queued batch of images and labels.

  Args:
    image: 3-D Tensor of [height, width, 3] of type float32.
    label: 1-D Tensor of type int32.
    min_queue_examples: int32, minimum number of samples to retain
      in the queue that provides batches of examples.
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, height, width, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  # Create a queue that shuffles the examples, and then
  # read 'batch_size' images + labels from the example queue.
  num_preprocess_threads = 16
  images, label_batch = tf.train.shuffle_batch(
      [image, label],
      batch_size=batch_size,
      num_threads=num_preprocess_threads,
      capacity=min_queue_examples + 3 * batch_size,
      min_after_dequeue=min_queue_examples)

  # Display the training images in the visualizer.
  tf.image_summary('images', images)

  return images, tf.reshape(label_batch, [batch_size])


def distorted_inputs(data_dir, batch_size):
  """Construct distorted input for CIFAR training using the Reader ops.

  Args:
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
143 | """ 144 | filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) 145 | for i in xrange(1, 6)] 146 | for f in filenames: 147 | if not gfile.Exists(f): 148 | raise ValueError('Failed to find file: ' + f) 149 | 150 | # Create a queue that produces the filenames to read. 151 | filename_queue = tf.train.string_input_producer(filenames) 152 | 153 | # Read examples from files in the filename queue. 154 | read_input = read_cifar10(filename_queue) 155 | reshaped_image = tf.cast(read_input.uint8image, tf.float32) 156 | 157 | height = IMAGE_SIZE 158 | width = IMAGE_SIZE 159 | 160 | # Image processing for training the network. Note the many random 161 | # distortions applied to the image. 162 | 163 | # Randomly crop a [height, width] section of the image. 164 | distorted_image = tf.image.random_crop(reshaped_image, [height, width]) 165 | 166 | # Randomly flip the image horizontally. 167 | distorted_image = tf.image.random_flip_left_right(distorted_image) 168 | 169 | # Because these operations are not commutative, consider randomizing 170 | # the order in which they are applied. 171 | distorted_image = tf.image.random_brightness(distorted_image, 172 | max_delta=63) 173 | distorted_image = tf.image.random_contrast(distorted_image, 174 | lower=0.2, upper=1.8) 175 | 176 | # Subtract off the mean and divide by the standard deviation of the pixels. 177 | float_image = tf.image.per_image_whitening(distorted_image) 178 | 179 | # Ensure that the random shuffling has good mixing properties. 180 | min_fraction_of_examples_in_queue = 0.4 181 | min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * 182 | min_fraction_of_examples_in_queue) 183 | print ('Filling queue with %d CIFAR images before starting to train. ' 184 | 'This will take a few minutes.' % min_queue_examples) 185 | 186 | # Generate a batch of images and labels by building up a queue of examples. 
187 | return _generate_image_and_label_batch(float_image, read_input.label, 188 | min_queue_examples, batch_size) 189 | 190 | 191 | def inputs(eval_data, data_dir, batch_size): 192 | """Construct input for CIFAR evaluation using the Reader ops. 193 | 194 | Args: 195 | eval_data: bool, indicating if one should use the train or eval data set. 196 | data_dir: Path to the CIFAR-10 data directory. 197 | batch_size: Number of images per batch. 198 | 199 | Returns: 200 | images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. 201 | labels: Labels. 1D tensor of [batch_size] size. 202 | """ 203 | if not eval_data: 204 | filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) 205 | for i in xrange(1, 6)] 206 | num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN 207 | else: 208 | filenames = [os.path.join(data_dir, 'test_batch.bin')] 209 | num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL 210 | 211 | for f in filenames: 212 | if not gfile.Exists(f): 213 | raise ValueError('Failed to find file: ' + f) 214 | 215 | # Create a queue that produces the filenames to read. 216 | filename_queue = tf.train.string_input_producer(filenames) 217 | 218 | # Read examples from files in the filename queue. 219 | read_input = read_cifar10(filename_queue) 220 | reshaped_image = tf.cast(read_input.uint8image, tf.float32) 221 | 222 | height = IMAGE_SIZE 223 | width = IMAGE_SIZE 224 | 225 | # Image processing for evaluation. 226 | # Crop the central [height, width] of the image. 227 | resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, 228 | height, width) 229 | 230 | # Subtract off the mean and divide by the standard deviation of the pixels. 231 | float_image = tf.image.per_image_whitening(resized_image) 232 | 233 | # Ensure that the random shuffling has good mixing properties. 
234 | min_fraction_of_examples_in_queue = 0.4 235 | min_queue_examples = int(num_examples_per_epoch * 236 | min_fraction_of_examples_in_queue) 237 | 238 | # Generate a batch of images and labels by building up a queue of examples. 239 | return _generate_image_and_label_batch(float_image, read_input.label, 240 | min_queue_examples, batch_size) 241 | -------------------------------------------------------------------------------- /cifar10_input_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for cifar10 input.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | import tensorflow.python.platform 25 | 26 | import tensorflow as tf 27 | 28 | #from tensorflow.models.image.cifar10 import cifar10_input 29 | #from tensorflow.python.util import compat 30 | import compat as cp 31 | 32 | import cifar10_input 33 | 34 | class CIFAR10InputTest(tf.test.TestCase): 35 | 36 | def _record(self, label, red, green, blue): 37 | image_size = 32 * 32 38 | record = bytes(bytearray([label] + [red] * image_size + 39 | [green] * image_size + [blue] * image_size)) 40 | expected = [[[red, green, blue]] * 32] * 32 41 | return record, expected 42 | 43 | def testSimple(self): 44 | labels = [9, 3, 0] 45 | records = [self._record(labels[0], 0, 128, 255), 46 | self._record(labels[1], 255, 0, 1), 47 | self._record(labels[2], 254, 255, 0)] 48 | contents = b"".join([record for record, _ in records]) 49 | expected = [expected for _, expected in records] 50 | filename = os.path.join(self.get_temp_dir(), "cifar") 51 | open(filename, "wb").write(contents) 52 | 53 | with self.test_session() as sess: 54 | q = tf.FIFOQueue(99, [tf.string], shapes=()) 55 | q.enqueue([filename]).run() 56 | q.close().run() 57 | result = cifar10_input.read_cifar10(q) 58 | 59 | for i in range(3): 60 | key, label, uint8image = sess.run([ 61 | result.key, result.label, result.uint8image]) 62 | self.assertEqual("%s:%d" % (filename, i), cp.as_text(key)) 63 | self.assertEqual(labels[i], label) 64 | self.assertAllEqual(expected[i], uint8image) 65 | 66 | with self.assertRaises(tf.errors.OutOfRangeError): 67 | sess.run([result.key, result.uint8image]) 68 | 69 | 70 | if __name__ == "__main__": 71 | tf.test.main() 72 | -------------------------------------------------------------------------------- 
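The record layout exercised by the test above (one label byte followed by 3072 image bytes stored one colour plane at a time) can be sketched without TensorFlow. This is only an illustration under stated assumptions: it targets Python 3, `make_record` and `decode_record` are hypothetical helper names that do not exist in this repository, and nested lists stand in for tensors.

```python
# Hypothetical helpers illustrating the CIFAR-10 binary record layout.
# They mirror the reshape-then-transpose step in read_cifar10, but use
# plain Python lists instead of tensors. Python 3 is assumed.
HEIGHT, WIDTH, DEPTH = 32, 32, 3
LABEL_BYTES = 1
IMAGE_BYTES = HEIGHT * WIDTH * DEPTH
PLANE = HEIGHT * WIDTH  # number of bytes in one colour plane


def make_record(label, red, green, blue):
    """Build one record in which every pixel has the same colour."""
    return bytes(bytearray([label] + [red] * PLANE +
                           [green] * PLANE + [blue] * PLANE))


def decode_record(record):
    """Return (label, image), with image indexed as image[row][col][channel]."""
    if len(record) != LABEL_BYTES + IMAGE_BYTES:
        raise ValueError('unexpected record length: %d' % len(record))
    label = record[0]
    pixels = record[LABEL_BYTES:]
    # Stored order is [depth, height, width]; read it back as
    # [height, width, depth].
    image = [[[pixels[c * PLANE + r * WIDTH + col] for c in range(DEPTH)]
              for col in range(WIDTH)]
             for r in range(HEIGHT)]
    return label, image


label, image = decode_record(make_record(9, 0, 128, 255))
print(label, image[0][0])
```

The index expression `c * PLANE + r * WIDTH + col` plays the role of the `[depth, height, width]` reshape followed by the `[1, 2, 0]` transpose.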
/cifar10_multi_gpu_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A binary to train CIFAR-10 using multiple GPU's with synchronous updates. 17 | 18 | Accuracy: 19 | cifar10_multi_gpu_train.py achieves ~86% accuracy after 100K steps (256 20 | epochs of data) as judged by cifar10_eval.py. 21 | 22 | Speed: With batch_size 128. 23 | 24 | System | Step Time (sec/batch) | Accuracy 25 | -------------------------------------------------------------------- 26 | 1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours) 27 | 1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours) 28 | 2 Tesla K20m | 0.13-0.20 | ~84% at 30K steps (2.5 hours) 29 | 3 Tesla K20m | 0.13-0.18 | ~84% at 30K steps 30 | 4 Tesla K20m | ~0.10 | ~84% at 30K steps 31 | 32 | Usage: 33 | Please see the tutorial and website for how to download the CIFAR-10 34 | data set, compile the program and train the model. 
35 | 36 | http://tensorflow.org/tutorials/deep_cnn/ 37 | """ 38 | from __future__ import absolute_import 39 | from __future__ import division 40 | from __future__ import print_function 41 | 42 | from datetime import datetime 43 | import os.path 44 | import re 45 | import time 46 | 47 | # pylint: disable=unused-import,g-bad-import-order 48 | import tensorflow.python.platform 49 | from tensorflow.python.platform import gfile 50 | import numpy as np 51 | from six.moves import xrange # pylint: disable=redefined-builtin 52 | import tensorflow as tf 53 | #from tensorflow.models.image.cifar10 import cifar10 54 | import cifar10 55 | # pylint: disable=unused-import,g-bad-import-order 56 | 57 | FLAGS = tf.app.flags.FLAGS 58 | 59 | tf.app.flags.DEFINE_string('train_dir', 'cifar10_train_gpu/', 60 | """Directory where to write event logs """ 61 | """and checkpoint.""") 62 | tf.app.flags.DEFINE_integer('max_steps', 20000, 63 | """Number of batches to run.""") 64 | tf.app.flags.DEFINE_integer('num_gpus', 1, 65 | """How many GPUs to use.""") 66 | tf.app.flags.DEFINE_boolean('log_device_placement', False, 67 | """Whether to log device placement.""") 68 | 69 | 70 | def tower_loss(scope): 71 | """Calculate the total loss on a single tower running the CIFAR model. 72 | 73 | Args: 74 | scope: unique prefix string identifying the CIFAR tower, e.g. 'tower_0' 75 | 76 | Returns: 77 | Tensor of shape [] containing the total loss for a batch of data 78 | """ 79 | # Get images and labels for CIFAR-10. 80 | images, labels = cifar10.distorted_inputs() 81 | 82 | # Build inference Graph. 83 | logits = cifar10.inference(images) 84 | 85 | # Build the portion of the Graph calculating the losses. Note that we will 86 | # assemble the total_loss using a custom function below. 87 | _ = cifar10.loss(logits, labels) 88 | 89 | # Assemble all of the losses for the current tower only. 90 | losses = tf.get_collection('losses', scope) 91 | 92 | # Calculate the total loss for the current tower. 
93 | total_loss = tf.add_n(losses, name='total_loss') 94 | 95 | # Compute the moving average of all individual losses and the total loss. 96 | loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg') 97 | loss_averages_op = loss_averages.apply(losses + [total_loss]) 98 | 99 | # Attach a scalar summary to all individual losses and the total loss; do the 100 | # same for the averaged version of the losses. 101 | for l in losses + [total_loss]: 102 | # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training 103 | # session. This helps the clarity of presentation on tensorboard. 104 | loss_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '', l.op.name) 105 | # Name each loss as '(raw)' and name the moving average version of the loss 106 | # as the original loss name. 107 | tf.scalar_summary(loss_name + ' (raw)', l) 108 | tf.scalar_summary(loss_name, loss_averages.average(l)) 109 | 110 | with tf.control_dependencies([loss_averages_op]): 111 | total_loss = tf.identity(total_loss) 112 | return total_loss 113 | 114 | 115 | def average_gradients(tower_grads): 116 | """Calculate the average gradient for each shared variable across all towers. 117 | 118 | Note that this function provides a synchronization point across all towers. 119 | 120 | Args: 121 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 122 | is over towers. The inner list is over the (gradient, variable) pairs 123 | computed on that tower. 124 | Returns: 125 | List of pairs of (gradient, variable) where the gradient has been averaged 126 | across all towers. 127 | """ 128 | average_grads = [] 129 | for grad_and_vars in zip(*tower_grads): 130 | # Note that each grad_and_vars looks like the following: 131 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 132 | grads = [] 133 | for g, _ in grad_and_vars: 134 | # Add 0 dimension to the gradients to represent the tower. 
135 | expanded_g = tf.expand_dims(g, 0) 136 | 137 | # Append on a 'tower' dimension which we will average over below. 138 | grads.append(expanded_g) 139 | 140 | # Average over the 'tower' dimension. 141 | grad = tf.concat(0, grads) 142 | grad = tf.reduce_mean(grad, 0) 143 | 144 | # Keep in mind that the Variables are redundant because they are shared 145 | # across towers. So .. we will just return the first tower's pointer to 146 | # the Variable. 147 | v = grad_and_vars[0][1] 148 | grad_and_var = (grad, v) 149 | average_grads.append(grad_and_var) 150 | return average_grads 151 | 152 | 153 | def train(): 154 | """Train CIFAR-10 for a number of steps.""" 155 | with tf.Graph().as_default(), tf.device('/cpu:0'): 156 | # Create a variable to count the number of train() calls. This equals the 157 | # number of batches processed * FLAGS.num_gpus. 158 | global_step = tf.get_variable( 159 | 'global_step', [], 160 | initializer=tf.constant_initializer(0), trainable=False) 161 | 162 | # Calculate the learning rate schedule. 163 | num_batches_per_epoch = (cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / 164 | FLAGS.batch_size) 165 | decay_steps = int(num_batches_per_epoch * cifar10.NUM_EPOCHS_PER_DECAY) 166 | 167 | # Decay the learning rate exponentially based on the number of steps. 168 | lr = tf.train.exponential_decay(cifar10.INITIAL_LEARNING_RATE, 169 | global_step, 170 | decay_steps, 171 | cifar10.LEARNING_RATE_DECAY_FACTOR, 172 | staircase=True) 173 | 174 | # Create an optimizer that performs gradient descent. 175 | opt = tf.train.GradientDescentOptimizer(lr) 176 | 177 | # Calculate the gradients for each model tower. 178 | tower_grads = [] 179 | for i in xrange(FLAGS.num_gpus): 180 | with tf.device('/gpu:%d' % i): 181 | with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope: 182 | # Calculate the loss for one tower of the CIFAR model. This function 183 | # constructs the entire CIFAR model but shares the variables across 184 | # all towers. 
185 | loss = tower_loss(scope) 186 | 187 | # Reuse variables for the next tower. 188 | tf.get_variable_scope().reuse_variables() 189 | 190 | # Retain the summaries from the final tower. 191 | summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope) 192 | 193 | # Calculate the gradients for the batch of data on this CIFAR tower. 194 | grads = opt.compute_gradients(loss) 195 | 196 | # Keep track of the gradients across all towers. 197 | tower_grads.append(grads) 198 | 199 | # We must calculate the mean of each gradient. Note that this is the 200 | # synchronization point across all towers. 201 | grads = average_gradients(tower_grads) 202 | 203 | # Add a summary to track the learning rate. 204 | summaries.append(tf.scalar_summary('learning_rate', lr)) 205 | 206 | # Add histograms for gradients. 207 | for grad, var in grads: 208 | if grad is not None: 209 | summaries.append( 210 | tf.histogram_summary(var.op.name + '/gradients', grad)) 211 | 212 | # Apply the gradients to adjust the shared variables. 213 | apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) 214 | 215 | # Add histograms for trainable variables. 216 | for var in tf.trainable_variables(): 217 | summaries.append(tf.histogram_summary(var.op.name, var)) 218 | 219 | # Track the moving averages of all trainable variables. 220 | variable_averages = tf.train.ExponentialMovingAverage( 221 | cifar10.MOVING_AVERAGE_DECAY, global_step) 222 | variables_averages_op = variable_averages.apply(tf.trainable_variables()) 223 | 224 | # Group all updates into a single train op. 225 | train_op = tf.group(apply_gradient_op, variables_averages_op) 226 | 227 | # Create a saver. 228 | saver = tf.train.Saver(tf.all_variables()) 229 | 230 | # Build the summary operation from the last tower summaries. 231 | summary_op = tf.merge_summary(summaries) 232 | 233 | # Build an initialization operation to run below. 234 | init = tf.initialize_all_variables() 235 | 236 | # Start running operations on the Graph. 
allow_soft_placement must be set to 237 | # True to build towers on GPU, as some of the ops do not have GPU 238 | # implementations. 239 | sess = tf.Session(config=tf.ConfigProto( 240 | allow_soft_placement=True, 241 | log_device_placement=FLAGS.log_device_placement)) 242 | sess.run(init) 243 | 244 | # Start the queue runners. 245 | tf.train.start_queue_runners(sess=sess) 246 | 247 | summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, 248 | graph_def=sess.graph_def) 249 | 250 | for step in xrange(FLAGS.max_steps): 251 | start_time = time.time() 252 | _, loss_value = sess.run([train_op, loss]) 253 | duration = time.time() - start_time 254 | 255 | assert not np.isnan(loss_value), 'Model diverged with loss = NaN' 256 | 257 | if step % 10 == 0: 258 | num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus 259 | examples_per_sec = num_examples_per_step / duration 260 | sec_per_batch = duration / FLAGS.num_gpus 261 | 262 | format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 263 | 'sec/batch)') 264 | print (format_str % (datetime.now(), step, loss_value, 265 | examples_per_sec, sec_per_batch)) 266 | 267 | if step % 100 == 0: 268 | summary_str = sess.run(summary_op) 269 | summary_writer.add_summary(summary_str, step) 270 | 271 | # Save the model checkpoint periodically. 
272 | if step % 1000 == 0 or (step + 1) == FLAGS.max_steps: 273 | checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt') 274 | saver.save(sess, checkpoint_path, global_step=step) 275 | 276 | 277 | def main(argv=None): # pylint: disable=unused-argument 278 | cifar10.maybe_download_and_extract() 279 | if gfile.Exists(FLAGS.train_dir): 280 | gfile.DeleteRecursively(FLAGS.train_dir) 281 | gfile.MakeDirs(FLAGS.train_dir) 282 | train() 283 | 284 | 285 | if __name__ == '__main__': 286 | tf.app.run() 287 | 288 | -------------------------------------------------------------------------------- /cifar10_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A binary to train CIFAR-10 using a single GPU. 17 | 18 | Accuracy: 19 | cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of 20 | data) as judged by cifar10_eval.py. 21 | 22 | Speed: With batch_size 128. 
23 | 24 | System | Step Time (sec/batch) | Accuracy 25 | ------------------------------------------------------------------ 26 | 1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours) 27 | 1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours) 28 | 29 | Usage: 30 | Please see the tutorial and website for how to download the CIFAR-10 31 | data set, compile the program and train the model. 32 | 33 | http://tensorflow.org/tutorials/deep_cnn/ 34 | """ 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | from datetime import datetime 40 | import os.path 41 | import time 42 | 43 | import tensorflow.python.platform 44 | from tensorflow.python.platform import gfile 45 | 46 | import numpy as np 47 | from six.moves import xrange # pylint: disable=redefined-builtin 48 | import tensorflow as tf 49 | 50 | #from tensorflow.models.image.cifar10 import cifar10 51 | import cifar10 52 | 53 | FLAGS = tf.app.flags.FLAGS 54 | 55 | tf.app.flags.DEFINE_string('train_dir', 'cifar10_train/', 56 | """Directory where to write event logs """ 57 | """and checkpoint.""") 58 | tf.app.flags.DEFINE_integer('max_steps', 20000, 59 | """Number of batches to run.""") 60 | tf.app.flags.DEFINE_boolean('log_device_placement', False, 61 | """Whether to log device placement.""") 62 | 63 | 64 | def train(): 65 | """Train CIFAR-10 for a number of steps.""" 66 | with tf.Graph().as_default(): 67 | global_step = tf.Variable(0, trainable=False) 68 | 69 | # Get images and labels for CIFAR-10. 70 | images, labels = cifar10.distorted_inputs() 71 | 72 | # Build a Graph that computes the logits predictions from the 73 | # inference model. 74 | logits = cifar10.inference(images) 75 | 76 | # Calculate loss. 77 | loss = cifar10.loss(logits, labels) 78 | 79 | # Build a Graph that trains the model with one batch of examples and 80 | # updates the model parameters. 81 | train_op = cifar10.train(loss, global_step) 82 | 83 | # Create a saver. 
84 | saver = tf.train.Saver(tf.all_variables()) 85 | 86 | # Build the summary operation based on the TF collection of Summaries. 87 | summary_op = tf.merge_all_summaries() 88 | 89 | # Build an initialization operation to run below. 90 | init = tf.initialize_all_variables() 91 | 92 | # Start running operations on the Graph. 93 | sess = tf.Session(config=tf.ConfigProto( 94 | log_device_placement=FLAGS.log_device_placement)) 95 | sess.run(init) 96 | 97 | # Start the queue runners. 98 | tf.train.start_queue_runners(sess=sess) 99 | 100 | summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, 101 | graph_def=sess.graph_def) 102 | 103 | for step in xrange(FLAGS.max_steps): 104 | start_time = time.time() 105 | _, loss_value = sess.run([train_op, loss]) 106 | duration = time.time() - start_time 107 | 108 | assert not np.isnan(loss_value), 'Model diverged with loss = NaN' 109 | 110 | if step % 10 == 0: 111 | num_examples_per_step = FLAGS.batch_size 112 | examples_per_sec = num_examples_per_step / duration 113 | sec_per_batch = float(duration) 114 | 115 | format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 116 | 'sec/batch)') 117 | print (format_str % (datetime.now(), step, loss_value, 118 | examples_per_sec, sec_per_batch)) 119 | 120 | if step % 100 == 0: 121 | summary_str = sess.run(summary_op) 122 | summary_writer.add_summary(summary_str, step) 123 | 124 | # Save the model checkpoint periodically. 
125 | if step % 1000 == 0 or (step + 1) == FLAGS.max_steps: 126 | checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt') 127 | saver.save(sess, checkpoint_path, global_step=step) 128 | 129 | 130 | def main(argv=None): # pylint: disable=unused-argument 131 | cifar10.maybe_download_and_extract() 132 | if gfile.Exists(FLAGS.train_dir): 133 | gfile.DeleteRecursively(FLAGS.train_dir) 134 | gfile.MakeDirs(FLAGS.train_dir) 135 | train() 136 | 137 | 138 | if __name__ == '__main__': 139 | tf.app.run() 140 | 141 | -------------------------------------------------------------------------------- /compat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Functions for Python 2 vs. 3 compatibility.""" 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numbers 22 | import numpy as np 23 | import six 24 | 25 | 26 | def as_bytes(bytes_or_text): 27 | """Converts either bytes or unicode to `bytes`, using utf-8 encoding for text. 28 | 29 | Args: 30 | bytes_or_text: A `bytes`, `str`, or `unicode` object. 31 | 32 | Returns: 33 | A `bytes` object. 
34 | 35 | Raises: 36 | TypeError: If `bytes_or_text` is not a binary or unicode string. 37 | """ 38 | if isinstance(bytes_or_text, six.text_type): 39 | return bytes_or_text.encode('utf-8') 40 | elif isinstance(bytes_or_text, bytes): 41 | return bytes_or_text 42 | else: 43 | raise TypeError('Expected binary or unicode string, got %r' % bytes_or_text) 44 | 45 | 46 | def as_text(bytes_or_text): 47 | """Returns the given argument as a unicode string. 48 | 49 | Args: 50 | bytes_or_text: A `bytes`, `str`, or `unicode` object. 51 | 52 | Returns: 53 | A `unicode` (Python 2) or `str` (Python 3) object. 54 | 55 | Raises: 56 | TypeError: If `bytes_or_text` is not a binary or unicode string. 57 | """ 58 | if isinstance(bytes_or_text, six.text_type): 59 | return bytes_or_text 60 | elif isinstance(bytes_or_text, bytes): 61 | return bytes_or_text.decode('utf-8') 62 | else: 63 | raise TypeError('Expected binary or unicode string, got %r' % bytes_or_text) 64 | 65 | 66 | # Convert an object to a `str` in both Python 2 and 3 67 | if six.PY2: 68 | as_str = as_bytes 69 | else: 70 | as_str = as_text 71 | 72 | 73 | def as_str_any(value): 74 | """Converts to `str` as `str(value)`, but uses `as_str` for `bytes`. 75 | 76 | Args: 77 | value: An object that can be converted to `str`. 78 | 79 | Returns: 80 | A `str` object. 81 | """ 82 | if isinstance(value, bytes): 83 | return as_str(value) 84 | else: 85 | return str(value) 86 | 87 | 88 | # Numpy 1.8 scalars don't inherit from numbers.Integral in Python 3, so we 89 | # need to check them specifically. The same goes for Real and Complex. 90 | integral_types = (numbers.Integral, np.integer) 91 | real_types = (numbers.Real, np.integer, np.floating) 92 | complex_types = (numbers.Complex, np.number) 93 | 94 | 95 | # Either bytes or text 96 | bytes_or_text_types = (bytes, six.text_type) 97 | --------------------------------------------------------------------------------
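A short sketch of how the `as_bytes` / `as_text` helpers in `compat.py` behave. To stay self-contained it re-implements them for Python 3 only, with the built-in `str` standing in for `six.text_type`; that substitution is an assumption of this example and not how the file itself is written.

```python
# Python 3-only re-implementation of the two helpers, for illustration.
# compat.py itself uses six.text_type so that it also works on Python 2.
def as_bytes(bytes_or_text):
    """Return the argument as bytes, encoding text as UTF-8."""
    if isinstance(bytes_or_text, str):
        return bytes_or_text.encode('utf-8')
    elif isinstance(bytes_or_text, bytes):
        return bytes_or_text
    raise TypeError('Expected binary or unicode string, got %r' % (bytes_or_text,))


def as_text(bytes_or_text):
    """Return the argument as text, decoding bytes as UTF-8."""
    if isinstance(bytes_or_text, str):
        return bytes_or_text
    elif isinstance(bytes_or_text, bytes):
        return bytes_or_text.decode('utf-8')
    raise TypeError('Expected binary or unicode string, got %r' % (bytes_or_text,))


# A reader key comes back from a session as bytes; the test converts it to
# text before comparing it with a formatted string.
key = as_text(b'/tmp/cifar:0')
print(key == '%s:%d' % ('/tmp/cifar', 0))
print(as_bytes('cifar') == b'cifar')
```

Both conversions are idempotent, which is why `cifar10_input_test.py` can call `cp.as_text(key)` without checking which type the session returned.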