├── nets ├── __init__.py ├── nasnet │ ├── __init__.py │ ├── nasnet_utils_test.py │ └── README.md ├── mobilenet_v1.png ├── mobilenet │ ├── madds_top1_accuracy.png │ ├── mnet_v1_vs_v2_pixel1_latency.png │ └── README.md ├── inception.py ├── inception_utils.py ├── nets_factory_test.py ├── lenet.py ├── cyclegan_test.py ├── dcgan_test.py ├── cifarnet.py ├── mobilenet_v1_eval.py ├── overfeat.py ├── pix2pix_test.py ├── alexnet.py ├── i3d_test.py ├── s3dg_test.py ├── i3d.py ├── overfeat_test.py ├── alexnet_test.py └── nets_factory.py ├── datasets ├── __init__.py ├── dataset_factory.py ├── download_mscoco.sh ├── garbage.py ├── preprocess_imagenet_validation_data.py ├── cifar10.py ├── flowers.py ├── mnist.py ├── download_imagenet.sh ├── download_and_convert_imagenet.sh ├── visualwakewords.py ├── dataset_utils.py ├── build_visualwakewords_data.py └── download_and_convert_cifar10.py ├── deployment ├── __init__.py ├── tensorflow_determinism-0.1.0-py3-none-any.whl ├── tensorflow_determinism-0.3.0-py3-none-any.whl └── tensorflow_determinism-0.1.0-py3-none-any.whl.zip ├── preprocessing ├── __init__.py ├── lenet_preprocessing.py ├── preprocessing_factory.py └── cifarnet_preprocessing.py ├── pre_train └── README.md ├── README.md ├── params_printer.py ├── model3_config.py ├── model2_config.py ├── model3.py ├── model2.py ├── model1.py ├── model.py ├── export_to_saved_model.py ├── eval_image_classifier_original.py ├── model1_config.py ├── eval_image_classifier.py ├── validation_confusion_matrix.py └── convert_garbage_data.py /nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /deployment/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nets/nasnet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nets/mobilenet_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skypeaceLL/sky-flink-phase2-python/HEAD/nets/mobilenet_v1.png -------------------------------------------------------------------------------- /nets/mobilenet/madds_top1_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skypeaceLL/sky-flink-phase2-python/HEAD/nets/mobilenet/madds_top1_accuracy.png -------------------------------------------------------------------------------- /nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skypeaceLL/sky-flink-phase2-python/HEAD/nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png -------------------------------------------------------------------------------- /deployment/tensorflow_determinism-0.1.0-py3-none-any.whl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/skypeaceLL/sky-flink-phase2-python/HEAD/deployment/tensorflow_determinism-0.1.0-py3-none-any.whl -------------------------------------------------------------------------------- /deployment/tensorflow_determinism-0.3.0-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skypeaceLL/sky-flink-phase2-python/HEAD/deployment/tensorflow_determinism-0.3.0-py3-none-any.whl -------------------------------------------------------------------------------- /deployment/tensorflow_determinism-0.1.0-py3-none-any.whl.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/skypeaceLL/sky-flink-phase2-python/HEAD/deployment/tensorflow_determinism-0.1.0-py3-none-any.whl.zip -------------------------------------------------------------------------------- /pre_train/README.md: -------------------------------------------------------------------------------- 1 | There are 3 pre-trained files under the pre_train directory: 2 | ``` 3 | pre_train/resnet_v1_101.ckpt 4 | pre_train/inception_v4.ckpt 5 | pre_train/inception_v3.ckpt 6 | ``` 7 | These files are too large to upload to the git server; please download them from the following URLs and extract them into the pre_train directory: 8 | http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz 9 | http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz 10 | http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Apache Flink Geek Challenge: Garbage Image Classification (Semifinal Python code) 2 | 3 | # 1. Code description 4 | 1. Built on TensorFlow slim (open source). 5 | 2. model.py is the main entry point. 6 | 3. Three model files are trained in total, based respectively on the ImageNet pre-trained checkpoints of resnet_v1_101 (resnet_v1_50 also works), inception_v4, and inception_v3. 7 | 4. See model.py, model1_config.py, model2_config.py, and model3_config.py for the specific training parameters of each model. 8 | 5. During training, multiple checkpoint files are written at a specified time interval; after training, the val_acc of each checkpoint is checked one by one, and the checkpoint with the maximum value is used to export the saved models (3 TF SavedModels). 9 | 6. There are 3 pre-trained files under the pre_train directory: 10 | ``` 11 | pre_train/resnet_v1_101.ckpt 12 | pre_train/inception_v4.ckpt 13 | pre_train/inception_v3.ckpt 14 | ``` 15 | These files are too large to upload to the git server; please download them from the following URLs and extract them into the pre_train directory: 16 | http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz 17 | http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz 18 | http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz 19 |
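For example, the three checkpoints could be fetched and unpacked like this (a minimal sketch, assuming it is run from the repository root and that the pre_train directory already exists):
```shell
cd pre_train
# Each archive unpacks to the single .ckpt file referenced by the matching model*_config.py.
for URL in \
  http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz \
  http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz \
  http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz; do
  wget "${URL}"
  tar -xzf "$(basename "${URL}")"
  rm "$(basename "${URL}")"
done
```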
20 | # 2. License statement 21 | You may use this code for learning and research, but you must not use it for any commercial purpose or competition entry (the TensorFlow slim portion of the code is not bound by this license statement). 22 | -------------------------------------------------------------------------------- /params_printer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | def output(): 8 | FLAGS = tf.app.flags.FLAGS 9 | print("model_name: %s" % FLAGS.model_name) 10 | print("checkpoint_path: %s" % FLAGS.checkpoint_path) 11 | print("==================================================") 12 | print("max_number_of_steps: %d" % FLAGS.max_number_of_steps) 13 | print("save_interval_secs: %d" % FLAGS.save_interval_secs) 14 | print("max_to_keep: %d" % FLAGS.max_to_keep) 15 | print("batch_size: %d" % FLAGS.batch_size) 16 | print("optimizer: %s" % FLAGS.optimizer) 17 | print("weight_decay: %f" % FLAGS.weight_decay) 18 | print("opt_epsilon: %.8f" % FLAGS.opt_epsilon) 19 | print("learning_rate: %f" % FLAGS.learning_rate) 20 | print("end_learning_rate: %f" % FLAGS.end_learning_rate) 21 | print("learning_rate_decay_type: %s" % FLAGS.learning_rate_decay_type) 22 | print("learning_rate_decay_factor: %f" % FLAGS.learning_rate_decay_factor) 23 | print("num_steps_per_decay: %f" % FLAGS.num_steps_per_decay) 24 | 25 | -------------------------------------------------------------------------------- /model3_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import os 7 | 8 | def init_params(): 9 | 10 | FLAGS = tf.app.flags.FLAGS 11 | _checkpoint_path = os.path.join(FLAGS.script_root_dir, "pre_train", "inception_v3.ckpt") 12 | 13 | ##################### 14 | # Basic Flags # 15 | ##################### 16 | FLAGS.checkpoint_path = _checkpoint_path 17 | FLAGS.max_number_of_steps = 3600 # 18 | FLAGS.save_interval_secs = 40 # 19 | FLAGS.max_to_keep = 12 20 | FLAGS.train_image_size = 299 21 | FLAGS.model_name = "inception_v3" 22 | FLAGS.preprocessing_name = None 23 | FLAGS.checkpoint_exclude_scopes = 'InceptionV3/Logits,InceptionV3/AuxLogits' 24 | FLAGS.trainable_scopes = 'InceptionV3/Logits,InceptionV3/AuxLogits' 25 | FLAGS.optimizer = "rmsprop" 26 | FLAGS.opt_epsilon = 1.0 27 | FLAGS.learning_rate = 0.01 28 | FLAGS.end_learning_rate = 0.00001 29 | FLAGS.learning_rate_decay_type = "fixed" 30 | FLAGS.learning_rate_decay_factor = 0.94 31 | FLAGS.num_steps_per_decay = 336.8 32 | 33 | -------------------------------------------------------------------------------- /model2_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import os 7 | 8 | def init_params(): 9 | 10 | FLAGS = tf.app.flags.FLAGS 11 | _checkpoint_path = os.path.join(FLAGS.script_root_dir, "pre_train", "inception_v4.ckpt") 12 | 13 | ##################### 14 | # Basic Flags # 15 | ##################### 16 | FLAGS.checkpoint_path = _checkpoint_path 17 | FLAGS.max_number_of_steps = 3400 # 4000: 25 epochs 18 | FLAGS.save_interval_secs = 85 #68 19 | FLAGS.max_to_keep = 10 20 | FLAGS.train_image_size = 299 21 | FLAGS.model_name = "inception_v4" 22 | FLAGS.preprocessing_name = None 23 | FLAGS.checkpoint_exclude_scopes = 
'InceptionV4/Logits,InceptionV4/AuxLogits' 24 | FLAGS.trainable_scopes = 'InceptionV4/Logits,InceptionV4/AuxLogits' 25 | FLAGS.optimizer = "rmsprop" 26 | FLAGS.opt_epsilon = 1.0 27 | FLAGS.learning_rate = 0.01 28 | FLAGS.end_learning_rate = 0.00001 29 | FLAGS.learning_rate_decay_type = "fixed" 30 | FLAGS.learning_rate_decay_factor = 0.94 31 | FLAGS.num_steps_per_decay = 336.8 32 | -------------------------------------------------------------------------------- /model3.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from time import sleep 6 | from time import time 7 | from datetime import datetime 8 | import model3_config 9 | import params_printer 10 | 11 | def execute(): 12 | 13 | model3_config.init_params() 14 | params_printer.output() 15 | 16 | import train_image_classifier 17 | import eval_image_classifier_original 18 | import eval_image_classifier 19 | import validation_confusion_matrix 20 | import export_to_saved_model 21 | 22 | print("Begin %s" % format(datetime.now().isoformat())) 23 | t_begin = time() 24 | 25 | print("Begin model3 train ...") 26 | t1 = time() 27 | train_image_classifier.train() 28 | t2 = time() 29 | print("Model3 end train %d s" % (t2-t1)) 30 | 31 | sleep(1) 32 | 33 | print("Begin model3 validations and find a lucky check point ...") 34 | t1 = time() 35 | eval_image_classifier_original.print_train_acc() 36 | max_checkpoint_path = eval_image_classifier.find_max_accuracy_checkpoint() 37 | validation_confusion_matrix.execute(max_checkpoint_path, "model3") 38 | 39 | t2 = time() 40 | print("Model3 end validations %d s" %(t2 - t1)) 41 | 42 | print("Begin model3 export to saved model ...") 43 | t1 = time() 44 | export_to_saved_model.export(max_checkpoint_path, "model3") 45 | t2 = time() 46 | print("End model3 export to saved model %d s" %(t2 - t1)) 47 | 48 | print("Done of all %s" % format(datetime.now().isoformat())) 49 | t_end = time() 50 | 51 | print("Model3 all end: %d" % (t_end - t_begin)) 52 | 53 | -------------------------------------------------------------------------------- /model2.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from time import sleep 6 | from time import time 7 | from datetime import datetime 8 | import model2_config 9 | import params_printer 10 | 11 | def execute(): 12 | 13 | model2_config.init_params() 14 | params_printer.output() 15 | 16 | import train_image_classifier 17 | import eval_image_classifier_original 18 | import eval_image_classifier 19 | import validation_confusion_matrix 20 | import export_to_saved_model 21 | 22 | #""" 23 | print("Begin %s" % format(datetime.now().isoformat())) 24 | t_begin = time() 25 | 26 | print("Begin model2 train ...") 27 | t1 = time() 28 | train_image_classifier.train() 29 | t2 = time() 30 | print("Model2 end train %d s" % (t2-t1)) 31 | 32 | sleep(1) 33 | 34 | print("Begin model2 validations and find a lucky check point ...") 35 | t1 = time() 36 | eval_image_classifier_original.print_train_acc() 37 | max_checkpoint_path = eval_image_classifier.find_max_accuracy_checkpoint() 38 | validation_confusion_matrix.execute(max_checkpoint_path, "model2") 39 | 40 | t2 = time() 41 | print("Model2 end validations %d s" %(t2 - t1)) 42 | 43 | print("Begin model2 export to saved model ...") 44 | 
t1 = time() 45 | export_to_saved_model.export(max_checkpoint_path, "model2") 46 | t2 = time() 47 | print("End model2 export to saved model %d s" %(t2 - t1)) 48 | 49 | print("Done of all %s" % format(datetime.now().isoformat())) 50 | t_end = time() 51 | 52 | print("Model2 all end: %d" % (t_end - t_begin)) 53 | #""" 54 | -------------------------------------------------------------------------------- /model1.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from time import sleep 6 | from time import time 7 | from datetime import datetime 8 | import model1_config 9 | import params_printer 10 | 11 | def execute(): 12 | 13 | model1_config.init_params() 14 | params_printer.output() 15 | #return 16 | import train_image_classifier 17 | import eval_image_classifier_original 18 | import eval_image_classifier 19 | import validation_confusion_matrix 20 | import export_to_saved_model 21 | #""" 22 | print("Begin %s" % format(datetime.now().isoformat())) 23 | t_begin = time() 24 | 25 | print("Begin model1 train ...") 26 | t1 = time() 27 | train_image_classifier.train() 28 | t2 = time() 29 | print("Model1 end train %d s" % (t2-t1)) 30 | 31 | sleep(1) 32 | 33 | print("Begin model1 validations and find a lucky check point ...") 34 | t1 = time() 35 | eval_image_classifier_original.print_train_acc() 36 | max_checkpoint_path = eval_image_classifier.find_max_accuracy_checkpoint() 37 | validation_confusion_matrix.execute(max_checkpoint_path, "model1") 38 | 39 | t2 = time() 40 | print("Model1 end validations %d s" %(t2 - t1)) 41 | 42 | print("Begin model1 export to saved model ...") 43 | t1 = time() 44 | export_to_saved_model.export(max_checkpoint_path, "model1") 45 | t2 = time() 46 | print("End model1 export to saved model %d s" %(t2 - t1)) 47 | 48 | print("Done of all %s" % format(datetime.now().isoformat())) 49 | t_end = time() 50 | 51 | print("Model1 all end: %d" % (t_end - t_begin)) 52 | #""" 53 | -------------------------------------------------------------------------------- /preprocessing/lenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities for preprocessing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def preprocess_image(image, output_height, output_width, is_training): 27 | """Preprocesses the given image. 28 | 29 | Args: 30 | image: A `Tensor` representing an image of arbitrary size. 31 | output_height: The height of the image after preprocessing. 
32 | output_width: The width of the image after preprocessing. 33 | is_training: `True` if we're preprocessing the image for training and 34 | `False` otherwise. 35 | 36 | Returns: 37 | A preprocessed image. 38 | """ 39 | image = tf.to_float(image) 40 | image = tf.image.resize_image_with_crop_or_pad( 41 | image, output_width, output_height) 42 | image = tf.subtract(image, 128.0) 43 | image = tf.div(image, 128.0) 44 | return image 45 | -------------------------------------------------------------------------------- /nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_resnet_v2 import inception_resnet_v2_base 25 | from nets.inception_v1 import inception_v1 26 | from nets.inception_v1 import inception_v1_arg_scope 27 | from nets.inception_v1 import inception_v1_base 28 | from nets.inception_v2 import inception_v2 29 | from nets.inception_v2 import inception_v2_arg_scope 30 | from nets.inception_v2 import inception_v2_base 31 | from nets.inception_v3 import inception_v3 32 | from nets.inception_v3 import inception_v3_arg_scope 33 | from nets.inception_v3 import inception_v3_base 34 | from nets.inception_v4 import inception_v4 35 | from nets.inception_v4 import inception_v4_arg_scope 36 | from nets.inception_v4 import inception_v4_base 37 | # pylint: enable=unused-import 38 | -------------------------------------------------------------------------------- /datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """A factory-pattern class which returns classification image/label pairs.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from datasets import cifar10 22 | from datasets import flowers 23 | from datasets import imagenet 24 | from datasets import mnist 25 | from datasets import visualwakewords 26 | from datasets import garbage 27 | 28 | datasets_map = { 29 | 'cifar10': cifar10, 30 | 'flowers': flowers, 31 | 'imagenet': imagenet, 32 | 'mnist': mnist, 33 | 'visualwakewords': visualwakewords, 34 | 'garbage': garbage 35 | } 36 | 37 | 38 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None): 39 | """Given a dataset name and a split_name returns a Dataset. 40 | 41 | Args: 42 | name: String, the name of the dataset. 43 | split_name: A train/test split name. 44 | dataset_dir: The directory where the dataset files are stored. 45 | file_pattern: The file pattern to use for matching the dataset source files. 46 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default 47 | reader defined by each dataset is used. 48 | 49 | Returns: 50 | A `Dataset` class. 51 | 52 | Raises: 53 | ValueError: If the dataset `name` is unknown. 54 | """ 55 | if name not in datasets_map: 56 | raise ValueError('Name of dataset unknown %s' % name) 57 | return datasets_map[name].get_split( 58 | split_name, 59 | dataset_dir, 60 | file_pattern, 61 | reader) 62 | -------------------------------------------------------------------------------- /nets/nasnet/nasnet_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.nasnet.nasnet_utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets.nasnet import nasnet_utils 24 | 25 | 26 | class NasnetUtilsTest(tf.test.TestCase): 27 | 28 | def testCalcReductionLayers(self): 29 | num_cells = 18 30 | num_reduction_layers = 2 31 | reduction_layers = nasnet_utils.calc_reduction_layers( 32 | num_cells, num_reduction_layers) 33 | self.assertEqual(len(reduction_layers), 2) 34 | self.assertEqual(reduction_layers[0], 6) 35 | self.assertEqual(reduction_layers[1], 12) 36 | 37 | def testGetChannelIndex(self): 38 | data_formats = ['NHWC', 'NCHW'] 39 | for data_format in data_formats: 40 | index = nasnet_utils.get_channel_index(data_format) 41 | correct_index = 3 if data_format == 'NHWC' else 1 42 | self.assertEqual(index, correct_index) 43 | 44 | def testGetChannelDim(self): 45 | data_formats = ['NHWC', 'NCHW'] 46 | shape = [10, 20, 30, 40] 47 | for data_format in data_formats: 48 | dim = nasnet_utils.get_channel_dim(shape, data_format) 49 | correct_dim = shape[3] if data_format == 'NHWC' else shape[1] 50 | self.assertEqual(dim, correct_dim) 51 | 52 | def testGlobalAvgPool(self): 53 | data_formats = ['NHWC', 'NCHW'] 54 | inputs = tf.placeholder(tf.float32, (5, 10, 20, 10)) 55 | for data_format in data_formats: 56 | output = nasnet_utils.global_avg_pool( 57 | inputs, data_format) 58 | self.assertEqual(output.shape, [5, 10]) 59 | 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /datasets/download_mscoco.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | # Script to download the COCO dataset. See 17 | # http://cocodataset.org/#overview for an overview of the dataset. 18 | # 19 | # usage: 20 | # bash datasets/download_mscoco.sh path-to-COCO-dataset 21 | # 22 | set -e 23 | 24 | if [ -z "$1" ]; then 25 | echo "usage download_mscoco.sh [data dir]" 26 | exit 27 | fi 28 | 29 | if [ "$(uname)" == "Darwin" ]; then 30 | UNZIP="tar -xf" 31 | else 32 | UNZIP="unzip -nq" 33 | fi 34 | 35 | # Create the output directories. 36 | OUTPUT_DIR="${1%/}" 37 | SCRATCH_DIR="${OUTPUT_DIR}/raw-data" 38 | mkdir -p "${OUTPUT_DIR}" 39 | mkdir -p "${SCRATCH_DIR}" 40 | CURRENT_DIR=$(pwd) 41 | 42 | # Helper function to download and unpack a .zip file. 43 | function download_and_unzip() { 44 | local BASE_URL=${1} 45 | local FILENAME=${2} 46 | 47 | if [ ! 
-f ${FILENAME} ]; then 48 | echo "Downloading ${FILENAME} to $(pwd)" 49 | wget -nd -c "${BASE_URL}/${FILENAME}" 50 | else 51 | echo "Skipping download of ${FILENAME}" 52 | fi 53 | echo "Unzipping ${FILENAME}" 54 | ${UNZIP} ${FILENAME} 55 | } 56 | 57 | cd ${SCRATCH_DIR} 58 | 59 | # Download the images. 60 | BASE_IMAGE_URL="http://images.cocodataset.org/zips" 61 | 62 | TRAIN_IMAGE_FILE="train2014.zip" 63 | download_and_unzip ${BASE_IMAGE_URL} ${TRAIN_IMAGE_FILE} 64 | TRAIN_IMAGE_DIR="${SCRATCH_DIR}/train2014" 65 | 66 | VAL_IMAGE_FILE="val2014.zip" 67 | download_and_unzip ${BASE_IMAGE_URL} ${VAL_IMAGE_FILE} 68 | VAL_IMAGE_DIR="${SCRATCH_DIR}/val2014" 69 | 70 | 71 | # Download the annotations. 72 | BASE_INSTANCES_URL="http://images.cocodataset.org/annotations" 73 | INSTANCES_FILE="annotations_trainval2014.zip" 74 | download_and_unzip ${BASE_INSTANCES_URL} ${INSTANCES_FILE} 75 | 76 | 77 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import warnings 6 | warnings.filterwarnings(action="ignore") 7 | 8 | import tensorflow as tf 9 | import os 10 | from time import time 11 | import sys 12 | reload(sys) 13 | sys.setdefaultencoding('utf8') 14 | 15 | _script_root_path = os.path.dirname(os.path.abspath(__file__)) 16 | _dataset_dir = os.path.join(_script_root_path, "data") 17 | _train_dir = os.path.join(_script_root_path, "train_dir") 18 | _image_data_dir = os.environ["IMAGE_TRAIN_INPUT_PATH"] 19 | _inference_dir = os.environ["MODEL_INFERENCE_PATH"] 20 | tf.app.flags.DEFINE_string('script_root_dir', _script_root_path, 'Directory where checkpoints and event logs are written to.') 21 | tf.app.flags.DEFINE_string('image_data_dir', _image_data_dir, 'Directory where image files are stored.') 22 | tf.app.flags.DEFINE_string('train_dir', _train_dir, 'Directory where checkpoints and event logs are written to.') 23 | tf.app.flags.DEFINE_string('dataset_dir', _dataset_dir, 'The directory where the dataset files are stored.') 24 | tf.app.flags.DEFINE_string('inference_dir', _inference_dir, 'The directory where the model files are stored.') 25 | tf.app.flags.DEFINE_integer('num_train', 5389, '_NUM_TRAIN') 26 | tf.app.flags.DEFINE_integer('num_validation', 600, '_NUM_VALIDATION') 27 | 28 | import convert_garbage_data 29 | import model1 30 | import model2 31 | import model3 32 | 33 | def main(_): 34 | FLAGS = tf.app.flags.FLAGS 35 | print("script root path: %s" % _script_root_path) 36 | print('IMAGE_TRAIN_INPUT_PATH: %s' % FLAGS.image_data_dir) 37 | print('MODEL_INFERENCE_PATH: %s' % FLAGS.inference_dir) 38 | print("dataset_dir: %s " % FLAGS.dataset_dir) 39 | print("train_dir: %s" % FLAGS.train_dir) 40 | print("num_train: %d" % FLAGS.num_train) 41 | print("num_validation: %d" % FLAGS.num_validation) 42 | 43 | print("Begin convert data ...") 44 | t1 = time() 45 | num_classes = convert_garbage_data.run(FLAGS.image_data_dir, FLAGS.dataset_dir, FLAGS.inference_dir, 0) 46 | print("num_classes: %s" % num_classes) 47 | t2 = time() 48 | print("End convert data %d s" % (t2-t1)) 49 | 50 | print("Begin model1.execute() ...") 51 | model1.execute() 52 | 53 | print("Begin model2.execute() ...") 54 | #num_classes = convert_garbage_data.run(FLAGS.image_data_dir, FLAGS.dataset_dir, FLAGS.inference_dir, 1) 55 | model2.execute() 56 | 57 | print("Begin model3.execute() ...") 58 | 
#num_classes = convert_garbage_data.run(FLAGS.image_data_dir, FLAGS.dataset_dir, FLAGS.inference_dir, 2) 59 | model3.execute() 60 | 61 | if __name__ == '__main__': 62 | tf.compat.v1.app.run() 63 | -------------------------------------------------------------------------------- /datasets/garbage.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import tensorflow as tf 8 | 9 | from datasets import dataset_utils 10 | 11 | #slim = tf.contrib.slim 12 | from tensorflow.contrib import slim 13 | 14 | FLAGS = tf.app.flags.FLAGS 15 | 16 | _FILE_PATTERN = 'garbage_%s_*.tfrecord' 17 | 18 | SPLITS_TO_SIZES = {'train': FLAGS.num_train, 'validation': FLAGS.num_validation} 19 | 20 | _ITEMS_TO_DESCRIPTIONS = { 21 | 'image': 'A color image of varying size.', 22 | 'label': 'A single integer between 0 and 99', 23 | } 24 | 25 | 26 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 27 | """Gets a dataset tuple with instructions for reading garbages. 28 | 29 | Args: 30 | split_name: A train/validation split name. 31 | dataset_dir: The base directory of the dataset sources. 32 | file_pattern: The file pattern to use when matching the dataset sources. 33 | It is assumed that the pattern contains a '%s' string so that the split 34 | name can be inserted. 35 | reader: The TensorFlow reader type. 36 | 37 | Returns: 38 | A `Dataset` namedtuple. 39 | 40 | Raises: 41 | ValueError: if `split_name` is not a valid train/validation split. 42 | """ 43 | if split_name not in SPLITS_TO_SIZES: 44 | raise ValueError('split name %s was not recognized.' % split_name) 45 | 46 | if not file_pattern: 47 | file_pattern = _FILE_PATTERN 48 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 49 | 50 | # Allowing None in the signature so that dataset_factory can use the default. 51 | if reader is None: 52 | reader = tf.TFRecordReader 53 | 54 | keys_to_features = { 55 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 56 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 57 | 'image/class/label': tf.FixedLenFeature( 58 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 59 | } 60 | 61 | items_to_handlers = { 62 | 'image': slim.tfexample_decoder.Image(), 63 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 64 | } 65 | 66 | decoder = slim.tfexample_decoder.TFExampleDecoder( 67 | keys_to_features, items_to_handlers) 68 | 69 | labels_to_names = None 70 | if dataset_utils.has_labels(dataset_dir): 71 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 72 | _NUM_CLASSES = len(labels_to_names) 73 | 74 | return slim.dataset.Dataset( 75 | data_sources=file_pattern, 76 | reader=reader, 77 | decoder=decoder, 78 | num_samples=SPLITS_TO_SIZES[split_name], 79 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 80 | num_classes=_NUM_CLASSES, 81 | labels_to_names=labels_to_names) 82 | -------------------------------------------------------------------------------- /nets/nasnet/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Slim NASNet-A Implementation/Checkpoints 2 | This directory contains the code for the NASNet-A model from the paper 3 | [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012) by Zoph et al. 
4 | In nasnet.py there are three different configurations of NASNet-A that are implemented. One of the models is the NASNet-A built for CIFAR-10 and the 5 | other two are variants of NASNet-A trained on ImageNet, which are listed below. 6 | 7 | # Pre-Trained Models 8 | Two NASNet-A checkpoints are available that have been trained on the 9 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 10 | image classification dataset. Accuracies were computed by evaluating with a single image crop. 11 | 12 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 13 | :----:|:------------:|:----------:|:-------:|:-------:| 14 | [NASNet-A_Mobile_224](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|564|5.3|74.0|91.6| 15 | [NASNet-A_Large_331](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|23800|88.9|82.7|96.2| 16 | 17 | 18 | Here is an example of how to download the NASNet-A_Mobile_224 checkpoint. The way to download the NASNet-A_Large_331 is the same. 19 | 20 | ```shell 21 | CHECKPOINT_DIR=/tmp/checkpoints 22 | mkdir ${CHECKPOINT_DIR} 23 | cd ${CHECKPOINT_DIR} 24 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz 25 | tar -xvf nasnet-a_mobile_04_10_2017.tar.gz 26 | rm nasnet-a_mobile_04_10_2017.tar.gz 27 | ``` 28 | More information on integrating NASNet Models into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 29 | 30 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 31 | 32 | ## Sample Commands for using NASNet-A Mobile and Large Checkpoints for Inference 33 | ------- 34 | Run eval with the NASNet-A mobile ImageNet model 35 | 36 | ```shell 37 | DATASET_DIR=/tmp/imagenet 38 | EVAL_DIR=/tmp/tfmodel/eval 39 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 40 | python tensorflow_models/research/slim/eval_image_classifier \ 41 | --checkpoint_path=${CHECKPOINT_DIR} \ 42 | --eval_dir=${EVAL_DIR} \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --dataset_name=imagenet \ 45 | --dataset_split_name=validation \ 46 | --model_name=nasnet_mobile \ 47 | --eval_image_size=224 48 | ``` 49 | 50 | Run eval with the NASNet-A large ImageNet model 51 | 52 | ```shell 53 | DATASET_DIR=/tmp/imagenet 54 | EVAL_DIR=/tmp/tfmodel/eval 55 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 56 | python tensorflow_models/research/slim/eval_image_classifier \ 57 | --checkpoint_path=${CHECKPOINT_DIR} \ 58 | --eval_dir=${EVAL_DIR} \ 59 | --dataset_dir=${DATASET_DIR} \ 60 | --dataset_name=imagenet \ 61 | --dataset_split_name=validation \ 62 | --model_name=nasnet_large \ 63 | --eval_image_size=331 64 | ``` 65 | -------------------------------------------------------------------------------- /datasets/preprocess_imagenet_validation_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | r"""Process the ImageNet Challenge bounding boxes for TensorFlow model training. 17 | 18 | Associate the ImageNet 2012 Challenge validation data set with labels. 19 | 20 | The raw ImageNet validation data set is expected to reside in JPEG files 21 | located in the following directory structure. 22 | 23 | data_dir/ILSVRC2012_val_00000001.JPEG 24 | data_dir/ILSVRC2012_val_00000002.JPEG 25 | ... 26 | data_dir/ILSVRC2012_val_00050000.JPEG 27 | 28 | This script moves the files into a directory structure like such: 29 | data_dir/n01440764/ILSVRC2012_val_00000293.JPEG 30 | data_dir/n01440764/ILSVRC2012_val_00000543.JPEG 31 | ... 32 | where 'n01440764' is the unique synset label associated with 33 | these images. 34 | 35 | This directory reorganization requires a mapping from validation image 36 | number (i.e. suffix of the original file) to the associated label. This 37 | is provided in the ImageNet development kit via a Matlab file. 38 | 39 | In order to make life easier and divorce ourselves from Matlab, we instead 40 | supply a custom text file that provides this mapping for us. 41 | 42 | Sample usage: 43 | ./preprocess_imagenet_validation_data.py ILSVRC2012_img_val \ 44 | imagenet_2012_validation_synset_labels.txt 45 | """ 46 | 47 | from __future__ import absolute_import 48 | from __future__ import division 49 | from __future__ import print_function 50 | 51 | import os 52 | import sys 53 | 54 | from six.moves import xrange # pylint: disable=redefined-builtin 55 | 56 | 57 | if __name__ == '__main__': 58 | if len(sys.argv) < 3: 59 | print('Invalid usage\n' 60 | 'usage: preprocess_imagenet_validation_data.py ' 61 | ' ') 62 | sys.exit(-1) 63 | data_dir = sys.argv[1] 64 | validation_labels_file = sys.argv[2] 65 | 66 | # Read in the 50000 synsets associated with the validation data set. 67 | labels = [l.strip() for l in open(validation_labels_file).readlines()] 68 | unique_labels = set(labels) 69 | 70 | # Make all sub-directories in the validation data dir. 71 | for label in unique_labels: 72 | labeled_data_dir = os.path.join(data_dir, label) 73 | os.makedirs(labeled_data_dir) 74 | 75 | # Move all of the image to the appropriate sub-directory. 76 | for i in xrange(len(labels)): 77 | basename = 'ILSVRC2012_val_000%.5d.JPEG' % (i + 1) 78 | original_filename = os.path.join(data_dir, basename) 79 | if not os.path.exists(original_filename): 80 | print('Failed to find: ', original_filename) 81 | sys.exit(-1) 82 | new_filename = os.path.join(data_dir, labels[i], basename) 83 | os.rename(original_filename, new_filename) 84 | -------------------------------------------------------------------------------- /nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001, 36 | activation_fn=tf.nn.relu, 37 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS, 38 | batch_norm_scale=False): 39 | """Defines the default arg scope for inception models. 40 | 41 | Args: 42 | weight_decay: The weight decay to use for regularizing the model. 43 | use_batch_norm: "If `True`, batch_norm is applied after each convolution. 44 | batch_norm_decay: Decay for batch norm moving average. 45 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 46 | in batch norm. 47 | activation_fn: Activation function for conv2d. 48 | batch_norm_updates_collections: Collection for the update ops for 49 | batch norm. 50 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 51 | activations in the batch normalization layer. 52 | 53 | Returns: 54 | An `arg_scope` to use for the inception models. 55 | """ 56 | batch_norm_params = { 57 | # Decay for the moving averages. 58 | 'decay': batch_norm_decay, 59 | # epsilon to prevent 0s in variance. 60 | 'epsilon': batch_norm_epsilon, 61 | # collection containing update_ops. 62 | 'updates_collections': batch_norm_updates_collections, 63 | # use fused batch norm if possible. 64 | 'fused': None, 65 | 'scale': batch_norm_scale, 66 | } 67 | if use_batch_norm: 68 | normalizer_fn = slim.batch_norm 69 | normalizer_params = batch_norm_params 70 | else: 71 | normalizer_fn = None 72 | normalizer_params = {} 73 | # Set weight_decay for weights in Conv and FC layers. 74 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 75 | weights_regularizer=slim.l2_regularizer(weight_decay)): 76 | with slim.arg_scope( 77 | [slim.conv2d], 78 | weights_initializer=slim.variance_scaling_initializer(), 79 | activation_fn=activation_fn, 80 | normalizer_fn=normalizer_fn, 81 | normalizer_params=normalizer_params) as sc: 82 | return sc 83 | -------------------------------------------------------------------------------- /preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import cifarnet_preprocessing 24 | from preprocessing import inception_preprocessing 25 | from preprocessing import lenet_preprocessing 26 | from preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocessing a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 43 | 44 | Raises: 45 | ValueError: If Preprocessing `name` is not recognized. 46 | """ 47 | preprocessing_fn_map = { 48 | 'cifarnet': cifarnet_preprocessing, 49 | 'inception': inception_preprocessing, 50 | 'inception_v1': inception_preprocessing, 51 | 'inception_v2': inception_preprocessing, 52 | 'inception_v3': inception_preprocessing, 53 | 'inception_v4': inception_preprocessing, 54 | 'inception_resnet_v2': inception_preprocessing, 55 | 'lenet': lenet_preprocessing, 56 | 'mobilenet_v1': inception_preprocessing, 57 | 'mobilenet_v2': inception_preprocessing, 58 | 'mobilenet_v2_035': inception_preprocessing, 59 | 'mobilenet_v2_140': inception_preprocessing, 60 | 'nasnet_mobile': inception_preprocessing, 61 | 'nasnet_large': inception_preprocessing, 62 | 'pnasnet_mobile': inception_preprocessing, 63 | 'pnasnet_large': inception_preprocessing, 64 | 'resnet_v1_50': vgg_preprocessing, 65 | 'resnet_v1_101': vgg_preprocessing, 66 | 'resnet_v1_152': vgg_preprocessing, 67 | 'resnet_v1_200': vgg_preprocessing, 68 | 'resnet_v2_50': vgg_preprocessing, 69 | 'resnet_v2_101': vgg_preprocessing, 70 | 'resnet_v2_152': vgg_preprocessing, 71 | 'resnet_v2_200': vgg_preprocessing, 72 | 'vgg': vgg_preprocessing, 73 | 'vgg_a': vgg_preprocessing, 74 | 'vgg_16': vgg_preprocessing, 75 | 'vgg_19': vgg_preprocessing, 76 | } 77 | 78 | if name not in preprocessing_fn_map: 79 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 80 | 81 | def preprocessing_fn(image, output_height, output_width, **kwargs): 82 | return preprocessing_fn_map[name].preprocess_image( 83 | image, output_height, output_width, is_training=is_training, **kwargs) 84 | 85 | return preprocessing_fn 86 | -------------------------------------------------------------------------------- /datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the Cifar10 dataset. 16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/research/slim/datasets/download_and_convert_cifar10.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'cifar10_%s.tfrecord' 33 | 34 | SPLITS_TO_SIZES = {'train': 50000, 'test': 10000} 35 | 36 | _NUM_CLASSES = 10 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A [32 x 32 x 3] color image.', 40 | 'label': 'A single integer between 0 and 9', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading cifar10. 46 | 47 | Args: 48 | split_name: A train/test split name. 49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/test split. 60 | """ 61 | if split_name not in SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 
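# (The default, tf.TFRecordReader, matches the cifar10_*.tfrecord sources produced by download_and_convert_cifar10.py.)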
69 | if not reader: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(shape=[32, 32, 3]), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=SPLITS_TO_SIZES[split_name], 96 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 97 | num_classes=_NUM_CLASSES, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /datasets/flowers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the flowers dataset. 16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/research/slim/datasets/download_and_convert_flowers.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'flowers_%s_*.tfrecord' 33 | 34 | SPLITS_TO_SIZES = {'train': 3320, 'validation': 350} 35 | 36 | _NUM_CLASSES = 5 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A color image of varying size.', 40 | 'label': 'A single integer between 0 and 4', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading flowers. 46 | 47 | Args: 48 | split_name: A train/validation split name. 49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/validation split. 60 | """ 61 | if split_name not in SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' 
% split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 69 | if reader is None: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=SPLITS_TO_SIZES[split_name], 96 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 97 | num_classes=_NUM_CLASSES, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the MNIST dataset. 16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/research/slim/datasets/download_and_convert_mnist.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'mnist_%s.tfrecord' 33 | 34 | _SPLITS_TO_SIZES = {'train': 60000, 'test': 10000} 35 | 36 | _NUM_CLASSES = 10 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A [28 x 28 x 1] grayscale image.', 40 | 'label': 'A single integer between 0 and 9', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading MNIST. 46 | 47 | Args: 48 | split_name: A train/test split name. 49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 
57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/test split. 60 | """ 61 | if split_name not in _SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 69 | if reader is None: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='raw'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(shape=[28, 28, 1], channels=1), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=_SPLITS_TO_SIZES[split_name], 96 | num_classes=_NUM_CLASSES, 97 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for nets_factory.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFnFirstHalf(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in list(nets_factory.networks_map.keys())[:10]: 34 | with tf.Graph().as_default() as g, self.test_session(g): 35 | net_fn = nets_factory.get_network_fn(net, num_classes=num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | if net not in ['i3d', 's3dg']: 39 | inputs = tf.random_uniform( 40 | (batch_size, image_size, image_size, 3)) 41 | logits, end_points = net_fn(inputs) 42 | self.assertTrue(isinstance(logits, tf.Tensor)) 43 | self.assertTrue(isinstance(end_points, dict)) 44 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 45 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 46 | 47 | def testGetNetworkFnSecondHalf(self): 48 | batch_size = 5 49 | num_classes = 1000 50 | for net in list(nets_factory.networks_map.keys())[10:]: 51 | with tf.Graph().as_default() as g, self.test_session(g): 52 | net_fn = nets_factory.get_network_fn(net, num_classes=num_classes) 53 | # Most networks use 224 as their default_image_size 54 | image_size = getattr(net_fn, 'default_image_size', 224) 55 | if net not in ['i3d', 's3dg']: 56 | inputs = tf.random_uniform( 57 | (batch_size, image_size, image_size, 3)) 58 | logits, end_points = net_fn(inputs) 59 | self.assertTrue(isinstance(logits, tf.Tensor)) 60 | self.assertTrue(isinstance(end_points, dict)) 61 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 62 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 63 | 64 | def testGetNetworkFnVideoModels(self): 65 | batch_size = 5 66 | num_classes = 400 67 | for net in ['i3d', 's3dg']: 68 | with tf.Graph().as_default() as g, self.test_session(g): 69 | net_fn = nets_factory.get_network_fn(net, num_classes=num_classes) 70 | # Most networks use 224 as their default_image_size 71 | image_size = getattr(net_fn, 'default_image_size', 224) // 2 72 | inputs = tf.random_uniform( 73 | (batch_size, 10, image_size, image_size, 3)) 74 | logits, end_points = net_fn(inputs) 75 | self.assertTrue(isinstance(logits, tf.Tensor)) 76 | self.assertTrue(isinstance(end_points, dict)) 77 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 78 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 79 | 80 | if __name__ == '__main__': 81 | tf.test.main() 82 | -------------------------------------------------------------------------------- /datasets/download_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # Script to download ImageNet Challenge 2012 training and validation data set. 18 | # 19 | # Downloads and decompresses raw images and bounding boxes. 20 | # 21 | # **IMPORTANT** 22 | # To download the raw images, the user must create an account with image-net.org 23 | # and generate a username and access_key. The latter two are required for 24 | # downloading the raw images. 25 | # 26 | # usage: 27 | # ./download_imagenet.sh [dirname] 28 | set -e 29 | 30 | if [ "x$IMAGENET_ACCESS_KEY" == x -o "x$IMAGENET_USERNAME" == x ]; then 31 | cat <"${BOUNDING_BOX_FILE}" 90 | echo "Finished downloading and preprocessing the ImageNet data." 91 | 92 | # Build the TFRecords version of the ImageNet data. 93 | BUILD_SCRIPT="${WORK_DIR}/build_imagenet_data" 94 | OUTPUT_DIRECTORY="${DATA_DIR}" 95 | IMAGENET_METADATA_FILE="${WORK_DIR}/datasets/imagenet_metadata.txt" 96 | 97 | "${BUILD_SCRIPT}" \ 98 | --train_directory="${TRAIN_DIRECTORY}" \ 99 | --validation_directory="${VALIDATION_DIRECTORY}" \ 100 | --output_directory="${OUTPUT_DIRECTORY}" \ 101 | --imagenet_metadata_file="${IMAGENET_METADATA_FILE}" \ 102 | --labels_file="${LABELS_FILE}" \ 103 | --bounding_box_file="${BOUNDING_BOX_FILE}" 104 | -------------------------------------------------------------------------------- /export_to_saved_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import os 7 | from datetime import datetime 8 | from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def 9 | from tensorflow.contrib import slim 10 | from nets import nets_factory 11 | from datasets import dataset_utils 12 | 13 | FLAGS = tf.app.flags.FLAGS 14 | 15 | def export(checkpoint_path, modelNo): 16 | 17 | print("Begin exporting %s" % datetime.now().isoformat()) 18 | 19 | saved_model_dir = "SavedModel" 20 | 21 | inference_dir = os.environ['MODEL_INFERENCE_PATH'] 22 | export_dir = os.path.join(inference_dir, saved_model_dir, modelNo, "SavedModel") 23 | 24 | print("The path of the saved model: %s" % export_dir) 25 | 26 | if tf.gfile.Exists(export_dir): 27 | print('Saved model folder already exists.
Deleting it first.') 28 | if export_dir.endswith(saved_model_dir): 29 | tf.gfile.DeleteRecursively(export_dir) 30 | 31 | if checkpoint_path is None: 32 | checkpoint_path = tf.train.latest_checkpoint(FLAGS.train_dir) 33 | 34 | print("checkpoint_path: %s" % checkpoint_path) 35 | 36 | with tf.Graph().as_default() as graph: 37 | tf_global_step = slim.get_or_create_global_step() 38 | labels_to_names = dataset_utils.read_label_file(FLAGS.dataset_dir) 39 | num_classes = len(labels_to_names) 40 | network_fn = nets_factory.get_network_fn( 41 | FLAGS.model_name, 42 | num_classes=(num_classes - FLAGS.labels_offset), 43 | weight_decay=FLAGS.weight_decay, 44 | is_training=False) 45 | 46 | input_shape = [None, FLAGS.train_image_size, FLAGS.train_image_size, 3] 47 | input_tensor = tf.placeholder(name='input_1', dtype=tf.float32, shape=input_shape) 48 | 49 | predictions_key = "Predictions" 50 | if FLAGS.model_name.startswith("resnet"): 51 | logits, endpoints = network_fn(input_tensor) 52 | predictions_key = "predictions" 53 | elif FLAGS.model_name.startswith("inception"): 54 | logits, endpoints = network_fn(input_tensor, create_aux_logits=False) 55 | elif FLAGS.model_name.startswith("nasnet_mobile"): 56 | logits, endpoints = network_fn(input_tensor, use_aux_head=0) 57 | 58 | predictions = endpoints[predictions_key] 59 | 60 | if FLAGS.moving_average_decay: 61 | variable_averages = tf.train.ExponentialMovingAverage( 62 | FLAGS.moving_average_decay, tf_global_step) 63 | variables_to_restore = variable_averages.variables_to_restore( 64 | slim.get_model_variables()) 65 | variables_to_restore[tf_global_step.op.name] = tf_global_step 66 | else: 67 | variables_to_restore = slim.get_variables_to_restore() 68 | 69 | saver = tf.train.Saver(var_list=variables_to_restore) # Same as slim.get_variables() 70 | 71 | init1 = tf.global_variables_initializer() 72 | init2 = tf.local_variables_initializer() 73 | with tf.Session() as sess: 74 | sess.run(init1) 75 | sess.run(init2) 76 | saver.restore(sess, checkpoint_path) 77 | 78 | #uninitialized_variables = [str(v, 'utf-8') for v in set(sess.run(tf.report_uninitialized_variables()))] 79 | #print(uninitialized_variables) 80 | #tf.graph_util.convert_variables_to_constants() 81 | 82 | print("Exporting saved model to: %s" % export_dir) 83 | 84 | prediction_signature = predict_signature_def( 85 | inputs={'input_1': input_tensor}, 86 | outputs={'output': predictions}) 87 | 88 | signature_def_map = { 89 | tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature 90 | } 91 | 92 | builder = tf.saved_model.builder.SavedModelBuilder(export_dir) 93 | 94 | builder.add_meta_graph_and_variables( 95 | sess, 96 | tags=[tf.saved_model.tag_constants.SERVING], 97 | signature_def_map=signature_def_map, 98 | clear_devices=True, 99 | main_op=None, # Suggest tf.tables_initializer()? 100 | strip_default_attrs=False) # Suggest True? 101 | builder.save() 102 | sess.close() 103 | print("Done exporting %s" % datetime.now().isoformat()) 104 | 105 | -------------------------------------------------------------------------------- /nets/cyclegan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for tensorflow.contrib.slim.nets.cyclegan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import cyclegan 24 | 25 | 26 | # TODO(joelshor): Add a test to check generator endpoints. 27 | class CycleganTest(tf.test.TestCase): 28 | 29 | def test_generator_inference(self): 30 | """Check one inference step.""" 31 | img_batch = tf.zeros([2, 32, 32, 3]) 32 | model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | sess.run(model_output) 36 | 37 | def _test_generator_graph_helper(self, shape): 38 | """Check that generator can take small and non-square inputs.""" 39 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape)) 40 | self.assertAllEqual(shape, output_imgs.shape.as_list()) 41 | 42 | def test_generator_graph_small(self): 43 | self._test_generator_graph_helper([4, 32, 32, 3]) 44 | 45 | def test_generator_graph_medium(self): 46 | self._test_generator_graph_helper([3, 128, 128, 3]) 47 | 48 | def test_generator_graph_nonsquare(self): 49 | self._test_generator_graph_helper([2, 80, 400, 3]) 50 | 51 | def test_generator_unknown_batch_dim(self): 52 | """Check that generator can take unknown batch dimension inputs.""" 53 | img = tf.placeholder(tf.float32, shape=[None, 32, None, 3]) 54 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(img) 55 | 56 | self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list()) 57 | 58 | def _input_and_output_same_shape_helper(self, kernel_size): 59 | img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3]) 60 | output_img_batch, _ = cyclegan.cyclegan_generator_resnet( 61 | img_batch, kernel_size=kernel_size) 62 | 63 | self.assertAllEqual(img_batch.shape.as_list(), 64 | output_img_batch.shape.as_list()) 65 | 66 | def test_input_and_output_same_shape_kernel3(self): 67 | self._input_and_output_same_shape_helper(3) 68 | 69 | def test_input_and_output_same_shape_kernel4(self): 70 | self._input_and_output_same_shape_helper(4) 71 | 72 | def test_input_and_output_same_shape_kernel5(self): 73 | self._input_and_output_same_shape_helper(5) 74 | 75 | def test_input_and_output_same_shape_kernel6(self): 76 | self._input_and_output_same_shape_helper(6) 77 | 78 | def _error_if_height_not_multiple_of_four_helper(self, height): 79 | self.assertRaisesRegexp( 80 | ValueError, 81 | 'The input height must be a multiple of 4.', 82 | cyclegan.cyclegan_generator_resnet, 83 | tf.placeholder(tf.float32, shape=[None, height, 32, 3])) 84 | 85 | def test_error_if_height_not_multiple_of_four_height29(self): 86 | self._error_if_height_not_multiple_of_four_helper(29) 87 | 88 | def test_error_if_height_not_multiple_of_four_height30(self): 89 | self._error_if_height_not_multiple_of_four_helper(30) 90 | 91 | def test_error_if_height_not_multiple_of_four_height31(self): 92 |
self._error_if_height_not_multiple_of_four_helper(31) 93 | 94 | def _error_if_width_not_multiple_of_four_helper(self, width): 95 | self.assertRaisesRegexp( 96 | ValueError, 97 | 'The input width must be a multiple of 4.', 98 | cyclegan.cyclegan_generator_resnet, 99 | tf.placeholder(tf.float32, shape=[None, 32, width, 3])) 100 | 101 | def test_error_if_width_not_multiple_of_four_width29(self): 102 | self._error_if_width_not_multiple_of_four_helper(29) 103 | 104 | def test_error_if_width_not_multiple_of_four_width30(self): 105 | self._error_if_width_not_multiple_of_four_helper(30) 106 | 107 | def test_error_if_width_not_multiple_of_four_width31(self): 108 | self._error_if_width_not_multiple_of_four_helper(31) 109 | 110 | 111 | if __name__ == '__main__': 112 | tf.test.main() 113 | -------------------------------------------------------------------------------- /nets/dcgan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for dcgan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from six.moves import xrange # pylint: disable=redefined-builtin 22 | import tensorflow as tf 23 | 24 | from nets import dcgan 25 | 26 | 27 | class DCGANTest(tf.test.TestCase): 28 | 29 | def test_generator_run(self): 30 | tf.set_random_seed(1234) 31 | noise = tf.random_normal([100, 64]) 32 | image, _ = dcgan.generator(noise) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | image.eval() 36 | 37 | def test_generator_graph(self): 38 | tf.set_random_seed(1234) 39 | # Check graph construction for a number of image size/depths and batch 40 | # sizes. 41 | for i, batch_size in zip(xrange(3, 7), xrange(3, 8)): 42 | tf.reset_default_graph() 43 | final_size = 2 ** i 44 | noise = tf.random_normal([batch_size, 64]) 45 | image, end_points = dcgan.generator( 46 | noise, 47 | depth=32, 48 | final_size=final_size) 49 | 50 | self.assertAllEqual([batch_size, final_size, final_size, 3], 51 | image.shape.as_list()) 52 | 53 | expected_names = ['deconv%i' % j for j in xrange(1, i)] + ['logits'] 54 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 55 | 56 | # Check layer depths. 
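      # (A sketch of the expected shapes, assuming depth=32 as passed above: each successive deconv halves the filter depth, so deconv j should carry 32 * 2**(i-j-1) output channels.)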
57 | for j in range(1, i): 58 | layer = end_points['deconv%i' % j] 59 | self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1]) 60 | 61 | def test_generator_invalid_input(self): 62 | wrong_dim_input = tf.zeros([5, 32, 32]) 63 | with self.assertRaises(ValueError): 64 | dcgan.generator(wrong_dim_input) 65 | 66 | correct_input = tf.zeros([3, 2]) 67 | with self.assertRaisesRegexp(ValueError, 'must be a power of 2'): 68 | dcgan.generator(correct_input, final_size=30) 69 | 70 | with self.assertRaisesRegexp(ValueError, 'must be greater than 8'): 71 | dcgan.generator(correct_input, final_size=4) 72 | 73 | def test_discriminator_run(self): 74 | image = tf.random_uniform([5, 32, 32, 3], -1, 1) 75 | output, _ = dcgan.discriminator(image) 76 | with self.test_session() as sess: 77 | sess.run(tf.global_variables_initializer()) 78 | output.eval() 79 | 80 | def test_discriminator_graph(self): 81 | # Check graph construction for a number of image size/depths and batch 82 | # sizes. 83 | for i, batch_size in zip(xrange(1, 6), xrange(3, 8)): 84 | tf.reset_default_graph() 85 | img_w = 2 ** i 86 | image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1) 87 | output, end_points = dcgan.discriminator( 88 | image, 89 | depth=32) 90 | 91 | self.assertAllEqual([batch_size, 1], output.get_shape().as_list()) 92 | 93 | expected_names = ['conv%i' % j for j in xrange(1, i+1)] + ['logits'] 94 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 95 | 96 | # Check layer depths. 97 | for j in range(1, i+1): 98 | layer = end_points['conv%i' % j] 99 | self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1]) 100 | 101 | def test_discriminator_invalid_input(self): 102 | wrong_dim_img = tf.zeros([5, 32, 32]) 103 | with self.assertRaises(ValueError): 104 | dcgan.discriminator(wrong_dim_img) 105 | 106 | spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3]) 107 | with self.assertRaises(ValueError): 108 | dcgan.discriminator(spatially_undefined_shape) 109 | 110 | not_square = tf.zeros([5, 32, 16, 3]) 111 | with self.assertRaisesRegexp(ValueError, 'not have equal width and height'): 112 | dcgan.discriminator(not_square) 113 | 114 | not_power_2 = tf.zeros([5, 30, 30, 3]) 115 | with self.assertRaisesRegexp(ValueError, 'not a power of 2'): 116 | dcgan.discriminator(not_power_2) 117 | 118 | 119 | if __name__ == '__main__': 120 | tf.test.main() 121 | -------------------------------------------------------------------------------- /nets/cifarnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a variant of the CIFAR-10 model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev) 26 | 27 | 28 | def cifarnet(images, num_classes=10, is_training=False, 29 | dropout_keep_prob=0.5, 30 | prediction_fn=slim.softmax, 31 | scope='CifarNet'): 32 | """Creates a variant of the CifarNet model. 33 | 34 | Note that since the output is a set of 'logits', the values fall in the 35 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 36 | probability distribution over the classes, one will need to convert them 37 | using the softmax function: 38 | 39 | logits = cifarnet.cifarnet(images, is_training=False) 40 | probabilities = tf.nn.softmax(logits) 41 | predictions = tf.argmax(logits, 1) 42 | 43 | Args: 44 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 45 | num_classes: the number of classes in the dataset. If 0 or None, the logits 46 | layer is omitted and the input features to the logits layer are returned 47 | instead. 48 | is_training: specifies whether or not we're currently training the model. 49 | This variable will determine the behaviour of the dropout layer. 50 | dropout_keep_prob: the percentage of activation values that are retained. 51 | prediction_fn: a function to get predictions out of logits. 52 | scope: Optional variable_scope. 53 | 54 | Returns: 55 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 56 | is a non-zero integer, or the input to the logits layer if num_classes 57 | is 0 or None. 58 | end_points: a dictionary from components of the network to the corresponding 59 | activation. 60 | """ 61 | end_points = {} 62 | 63 | with tf.variable_scope(scope, 'CifarNet', [images]): 64 | net = slim.conv2d(images, 64, [5, 5], scope='conv1') 65 | end_points['conv1'] = net 66 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 67 | end_points['pool1'] = net 68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1') 69 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 70 | end_points['conv2'] = net 71 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2') 72 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 73 | end_points['pool2'] = net 74 | net = slim.flatten(net) 75 | end_points['Flatten'] = net 76 | net = slim.fully_connected(net, 384, scope='fc3') 77 | end_points['fc3'] = net 78 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 79 | scope='dropout3') 80 | net = slim.fully_connected(net, 192, scope='fc4') 81 | end_points['fc4'] = net 82 | if not num_classes: 83 | return net, end_points 84 | logits = slim.fully_connected(net, num_classes, 85 | biases_initializer=tf.zeros_initializer(), 86 | weights_initializer=trunc_normal(1/192.0), 87 | weights_regularizer=None, 88 | activation_fn=None, 89 | scope='logits') 90 | 91 | end_points['Logits'] = logits 92 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 93 | 94 | return logits, end_points 95 | cifarnet.default_image_size = 32 96 | 97 | 98 | def cifarnet_arg_scope(weight_decay=0.004): 99 | """Defines the default cifarnet argument scope.
100 | 101 | Args: 102 | weight_decay: The weight decay to use for regularizing the model. 103 | 104 | Returns: 105 | An `arg_scope` to use for the CifarNet model. 106 | """ 107 | with slim.arg_scope( 108 | [slim.conv2d], 109 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2), 110 | activation_fn=tf.nn.relu): 111 | with slim.arg_scope( 112 | [slim.fully_connected], 113 | biases_initializer=tf.constant_initializer(0.1), 114 | weights_initializer=trunc_normal(0.04), 115 | weights_regularizer=slim.l2_regularizer(weight_decay), 116 | activation_fn=tf.nn.relu) as sc: 117 | return sc 118 | -------------------------------------------------------------------------------- /datasets/visualwakewords.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for Visual WakeWords Dataset with images+labels. 16 | 17 | Visual WakeWords Dataset derives from the COCO dataset to design tiny models 18 | classifying two classes, such as person/not-person. The COCO annotations 19 | are filtered to two classes: person and not-person (or another user-defined 20 | category). Bounding boxes for small objects with area less than 5% of the image 21 | area are filtered out. 22 | See build_visualwakewords_data.py which generates the Visual WakeWords dataset 23 | annotations from the raw COCO dataset and converts them to TFRecord. 24 | 25 | """ 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | import os 31 | import tensorflow as tf 32 | 33 | from datasets import dataset_utils 34 | 35 | 36 | slim = tf.contrib.slim 37 | 38 | _FILE_PATTERN = '%s.record-*' 39 | 40 | _SPLITS_TO_SIZES = { 41 | 'train': 82783, 42 | 'validation': 40504, 43 | } 44 | 45 | 46 | _ITEMS_TO_DESCRIPTIONS = { 47 | 'image': 'A color image of varying height and width.', 48 | 'label': 'The label id of the image, an integer in {0, 1}', 49 | 'object/bbox': 'A list of bounding boxes.', 50 | 'object/label': 'A list of labels, all objects belong to the same class.', 51 | } 52 | 53 | _NUM_CLASSES = 2 54 | 55 | # labels file 56 | LABELS_FILENAME = 'labels.txt' 57 | 58 | 59 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 60 | """Gets a dataset tuple with instructions for reading the Visual WakeWords dataset. 61 | 62 | Args: 63 | split_name: A train/test split name. 64 | dataset_dir: The base directory of the dataset sources. 65 | file_pattern: The file pattern to use when matching the dataset sources. It 66 | is assumed that the pattern contains a '%s' string so that the split name 67 | can be inserted. 68 | reader: The TensorFlow reader type. 69 | 70 | Returns: 71 | A `Dataset` namedtuple. 72 | 73 | Raises: 74 | ValueError: if `split_name` is not a valid train/test split.
75 | """ 76 | if split_name not in _SPLITS_TO_SIZES: 77 | raise ValueError('split name %s was not recognized.' % split_name) 78 | 79 | if not file_pattern: 80 | file_pattern = _FILE_PATTERN 81 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 82 | 83 | # Allowing None in the signature so that dataset_factory can use the default. 84 | if reader is None: 85 | reader = tf.TFRecordReader 86 | 87 | keys_to_features = { 88 | 'image/encoded': 89 | tf.FixedLenFeature((), tf.string, default_value=''), 90 | 'image/format': 91 | tf.FixedLenFeature((), tf.string, default_value='jpeg'), 92 | 'image/class/label': 93 | tf.FixedLenFeature([], dtype=tf.int64, default_value=-1), 94 | 'image/object/bbox/xmin': 95 | tf.VarLenFeature(dtype=tf.float32), 96 | 'image/object/bbox/ymin': 97 | tf.VarLenFeature(dtype=tf.float32), 98 | 'image/object/bbox/xmax': 99 | tf.VarLenFeature(dtype=tf.float32), 100 | 'image/object/bbox/ymax': 101 | tf.VarLenFeature(dtype=tf.float32), 102 | 'image/object/class/label': 103 | tf.VarLenFeature(dtype=tf.int64), 104 | } 105 | 106 | items_to_handlers = { 107 | 'image': 108 | slim.tfexample_decoder.Image('image/encoded', 'image/format'), 109 | 'label': 110 | slim.tfexample_decoder.Tensor('image/class/label'), 111 | 'object/bbox': 112 | slim.tfexample_decoder.BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'], 113 | 'image/object/bbox/'), 114 | 'object/label': 115 | slim.tfexample_decoder.Tensor('image/object/class/label'), 116 | } 117 | 118 | decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, 119 | items_to_handlers) 120 | 121 | labels_to_names = None 122 | labels_file = os.path.join(dataset_dir, LABELS_FILENAME) 123 | if tf.gfile.Exists(labels_file): 124 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 125 | 126 | return slim.dataset.Dataset( 127 | data_sources=file_pattern, 128 | reader=reader, 129 | decoder=decoder, 130 | num_samples=_SPLITS_TO_SIZES[split_name], 131 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 132 | num_classes=_NUM_CLASSES, 133 | labels_to_names=labels_to_names) 134 | -------------------------------------------------------------------------------- /preprocessing/cifarnet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images in CIFAR-10. 16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf 24 | 25 | _PADDING = 4 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def preprocess_for_train(image, 31 | output_height, 32 | output_width, 33 | padding=_PADDING, 34 | add_image_summaries=True): 35 | """Preprocesses the given image for training. 
36 | 37 | Note that if `padding` is greater than zero, the image is first zero-padded 38 | and then randomly cropped back to [`output_height`, `output_width`]. 39 | 40 | Args: 41 | image: A `Tensor` representing an image of arbitrary size. 42 | output_height: The height of the image after preprocessing. 43 | output_width: The width of the image after preprocessing. 44 | padding: The amount of padding before and after each dimension of the image. 45 | add_image_summaries: Enable image summaries. 46 | 47 | Returns: 48 | A preprocessed image. 49 | """ 50 | if add_image_summaries: 51 | tf.summary.image('image', tf.expand_dims(image, 0)) 52 | 53 | # Transform the image to floats. 54 | image = tf.to_float(image) 55 | if padding > 0: 56 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]) 57 | # Randomly crop a [height, width] section of the image. 58 | distorted_image = tf.random_crop(image, 59 | [output_height, output_width, 3]) 60 | 61 | # Randomly flip the image horizontally. 62 | distorted_image = tf.image.random_flip_left_right(distorted_image) 63 | 64 | if add_image_summaries: 65 | tf.summary.image('distorted_image', tf.expand_dims(distorted_image, 0)) 66 | 67 | # Because these operations are not commutative, consider randomizing 68 | # the order of their operation. 69 | distorted_image = tf.image.random_brightness(distorted_image, 70 | max_delta=63) 71 | distorted_image = tf.image.random_contrast(distorted_image, 72 | lower=0.2, upper=1.8) 73 | # Subtract off the mean and divide by the variance of the pixels. 74 | return tf.image.per_image_standardization(distorted_image) 75 | 76 | 77 | def preprocess_for_eval(image, output_height, output_width, 78 | add_image_summaries=True): 79 | """Preprocesses the given image for evaluation. 80 | 81 | Args: 82 | image: A `Tensor` representing an image of arbitrary size. 83 | output_height: The height of the image after preprocessing. 84 | output_width: The width of the image after preprocessing. 85 | add_image_summaries: Enable image summaries. 86 | 87 | Returns: 88 | A preprocessed image. 89 | """ 90 | if add_image_summaries: 91 | tf.summary.image('image', tf.expand_dims(image, 0)) 92 | # Transform the image to floats. 93 | image = tf.to_float(image) 94 | 95 | # Resize and crop if needed. Note the (height, width) argument order of 96 | # tf.image.resize_image_with_crop_or_pad. 96 | resized_image = tf.image.resize_image_with_crop_or_pad(image, 97 | output_height, 98 | output_width) 99 | if add_image_summaries: 100 | tf.summary.image('resized_image', tf.expand_dims(resized_image, 0)) 101 | 102 | # Subtract off the mean and divide by the variance of the pixels. 103 | return tf.image.per_image_standardization(resized_image) 104 | 105 | 106 | def preprocess_image(image, output_height, output_width, is_training=False, 107 | add_image_summaries=True): 108 | """Preprocesses the given image. 109 | 110 | Args: 111 | image: A `Tensor` representing an image of arbitrary size. 112 | output_height: The height of the image after preprocessing. 113 | output_width: The width of the image after preprocessing. 114 | is_training: `True` if we're preprocessing the image for training and 115 | `False` otherwise. 116 | add_image_summaries: Enable image summaries. 117 | 118 | Returns: 119 | A preprocessed image.
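    For example (call shown for illustration only): `preprocess_image(img, 32, 32, is_training=True)` pads and randomly crops `img` to 32x32, randomly flips it and distorts brightness/contrast, then standardizes it per image; with `is_training=False` it instead center-crops or pads and standardizes.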
120 | """ 121 | if is_training: 122 | return preprocess_for_train( 123 | image, output_height, output_width, 124 | add_image_summaries=add_image_summaries) 125 | else: 126 | return preprocess_for_eval( 127 | image, output_height, output_width, 128 | add_image_summaries=add_image_summaries) 129 | -------------------------------------------------------------------------------- /nets/mobilenet/README.md: -------------------------------------------------------------------------------- 1 | # MobileNetV2 2 | This folder contains building code for MobileNetV2, based on 3 | [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) 4 | 5 | # Performance 6 | ## Latency 7 | This is the timing of [MobileNetV1](../mobilenet_v1.md) vs MobileNetV2 using 8 | TF-Lite on the large core of Pixel 1 phone. 9 | 10 | ![mnet_v1_vs_v2_pixel1_latency.png](mnet_v1_vs_v2_pixel1_latency.png) 11 | 12 | ## MACs 13 | MACs, also sometimes known as MADDs - the number of multiply-accumulates needed 14 | to compute an inference on a single image is a common metric to measure the efficiency of the model. 15 | 16 | Below is the graph comparing V2 vs a few selected networks. The size 17 | of each blob represents the number of parameters. Note for [ShuffleNet](https://arxiv.org/abs/1707.01083) there 18 | are no published size numbers. We estimate it to be comparable to MobileNetV2 numbers. 19 | 20 | ![madds_top1_accuracy](madds_top1_accuracy.png) 21 | 22 | # Pretrained models 23 | ## Imagenet Checkpoints 24 | 25 | Classification Checkpoint | MACs (M)| Parameters (M)| Top 1 Accuracy| Top 5 Accuracy | Mobile CPU (ms) Pixel 1 26 | ---------------------------|---------|---------------|---------|----|------------- 27 | | [mobilenet_v2_1.4_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) | 582 | 6.06 | 75.0 | 92.5 | 138.0 28 | | [mobilenet_v2_1.3_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.3_224.tgz) | 509 | 5.34 | 74.4 | 92.1 | 123.0 29 | | [mobilenet_v2_1.0_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz) | 300 | 3.47 | 71.8 | 91.0 | 73.8 30 | | [mobilenet_v2_1.0_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_192.tgz) | 221 | 3.47 | 70.7 | 90.1 | 55.1 31 | | [mobilenet_v2_1.0_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_160.tgz) | 154 | 3.47 | 68.8 | 89.0 | 40.2 32 | | [mobilenet_v2_1.0_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_128.tgz) | 99 | 3.47 | 65.3 | 86.9 | 27.6 33 | | [mobilenet_v2_1.0_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz) | 56 | 3.47 | 60.3 | 83.2 | 17.6 34 | | [mobilenet_v2_0.75_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_224.tgz) | 209 | 2.61 | 69.8 | 89.6 | 55.8 35 | | [mobilenet_v2_0.75_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_192.tgz) | 153 | 2.61 | 68.7 | 88.9 | 41.6 36 | | [mobilenet_v2_0.75_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_160.tgz) | 107 | 2.61 | 66.4 | 87.3 | 30.4 37 | | [mobilenet_v2_0.75_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_128.tgz) | 69 | 2.61 | 63.2 | 85.3 | 21.9 38 | | [mobilenet_v2_0.75_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_96.tgz) | 39 | 2.61 | 58.8 | 81.6 | 14.2 39 | | 
[mobilenet_v2_0.5_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_224.tgz) | 97 | 1.95 | 65.4 | 86.4 | 28.7 40 | | [mobilenet_v2_0.5_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_192.tgz) | 71 | 1.95 | 63.9 | 85.4 | 21.1 41 | | [mobilenet_v2_0.5_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_160.tgz) | 50 | 1.95 | 61.0 | 83.2 | 14.9 42 | | [mobilenet_v2_0.5_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_128.tgz) | 32 | 1.95 | 57.7 | 80.8 | 9.9 43 | | [mobilenet_v2_0.5_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_96.tgz) | 18 | 1.95 | 51.2 | 75.8 | 6.4 44 | | [mobilenet_v2_0.35_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_224.tgz) | 59 | 1.66 | 60.3 | 82.9 | 19.7 45 | | [mobilenet_v2_0.35_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_192.tgz) | 43 | 1.66 | 58.2 | 81.2 | 14.6 46 | | [mobilenet_v2_0.35_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_160.tgz) | 30 | 1.66 | 55.7 | 79.1 | 10.5 47 | | [mobilenet_v2_0.35_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_128.tgz) | 20 | 1.66 | 50.8 | 75.0 | 6.9 48 | | [mobilenet_v2_0.35_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_96.tgz) | 11 | 1.66 | 45.5 | 70.4 | 4.5 49 | 50 | # Training 51 | The numbers above can be reproduced using slim's `train_image_classifier`. 52 | Below is the set of parameters that achieves 72.0% for full-size MobileNetV2 after about 700K steps when trained on 8 GPUs. 53 | If trained on a single GPU, full convergence takes about 5.5M steps. Also note that the learning rate and 54 | num_epochs_per_decay both need to be adjusted depending on how many GPUs are being 55 | used, due to slim's internal averaging. 56 | 57 | ```bash 58 | --model_name="mobilenet_v2" 59 | --learning_rate=0.045*NUM_GPUS # slim internally averages clones, so we compensate 60 | --preprocessing_name="inception_v2" 61 | --label_smoothing=0.1 62 | --moving_average_decay=0.9999 63 | --batch_size=96 64 | --num_clones=NUM_GPUS # any number between 1 and 8, depending on your hardware setup 65 | --learning_rate_decay_factor=0.98 66 | --num_epochs_per_decay=2.5/NUM_GPUS # train_image_classifier does per-clone epochs 67 | ``` 68 | 69 | # Example 70 | 71 | 72 | See this [ipython notebook](mobilenet_example.ipynb) or open and run the network directly in [Colaboratory](https://colab.research.google.com/github/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_example.ipynb). 73 | 74 | -------------------------------------------------------------------------------- /nets/mobilenet_v1_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Validate mobilenet_v1 with options for quantization.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import math 22 | import tensorflow as tf 23 | 24 | from datasets import dataset_factory 25 | from nets import mobilenet_v1 26 | from preprocessing import preprocessing_factory 27 | 28 | slim = tf.contrib.slim 29 | 30 | flags = tf.app.flags 31 | 32 | flags.DEFINE_string('master', '', 'Session master') 33 | flags.DEFINE_integer('batch_size', 250, 'Batch size') 34 | flags.DEFINE_integer('num_classes', 1001, 'Number of classes to distinguish') 35 | flags.DEFINE_integer('num_examples', 50000, 'Number of examples to evaluate') 36 | flags.DEFINE_integer('image_size', 224, 'Input image resolution') 37 | flags.DEFINE_float('depth_multiplier', 1.0, 'Depth multiplier for mobilenet') 38 | flags.DEFINE_bool('quantize', False, 'Quantize training') 39 | flags.DEFINE_string('checkpoint_dir', '', 'The directory for checkpoints') 40 | flags.DEFINE_string('eval_dir', '', 'Directory for writing eval event logs') 41 | flags.DEFINE_string('dataset_dir', '', 'Location of dataset') 42 | 43 | FLAGS = flags.FLAGS 44 | 45 | 46 | def imagenet_input(is_training): 47 | """Data reader for imagenet. 48 | 49 | Reads in imagenet data and performs pre-processing on the images. 50 | 51 | Args: 52 | is_training: bool specifying if train or validation dataset is needed. 53 | Returns: 54 | A batch of images and labels. 55 | """ 56 | if is_training: 57 | dataset = dataset_factory.get_dataset('imagenet', 'train', 58 | FLAGS.dataset_dir) 59 | else: 60 | dataset = dataset_factory.get_dataset('imagenet', 'validation', 61 | FLAGS.dataset_dir) 62 | 63 | provider = slim.dataset_data_provider.DatasetDataProvider( 64 | dataset, 65 | shuffle=is_training, 66 | common_queue_capacity=2 * FLAGS.batch_size, 67 | common_queue_min=FLAGS.batch_size) 68 | [image, label] = provider.get(['image', 'label']) 69 | 70 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 71 | 'mobilenet_v1', is_training=is_training) 72 | 73 | image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size) 74 | 75 | images, labels = tf.train.batch( 76 | tensors=[image, label], 77 | batch_size=FLAGS.batch_size, 78 | num_threads=4, 79 | capacity=5 * FLAGS.batch_size) 80 | return images, labels 81 | 82 | 83 | def metrics(logits, labels): 84 | """Specify the metrics for eval. 85 | 86 | Args: 87 | logits: Logits output from the graph. 88 | labels: Ground truth labels for inputs. 89 | 90 | Returns: 91 | Eval Op for the graph. 92 | """ 93 | labels = tf.squeeze(labels) 94 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({ 95 | 'Accuracy': tf.metrics.accuracy(tf.argmax(logits, 1), labels), 96 | 'Recall_5': tf.metrics.recall_at_k(labels, logits, 5), 97 | }) 98 | for name, value in names_to_values.items(): 99 | slim.summaries.add_scalar_summary( 100 | value, name, prefix='eval', print_summary=True) 101 | return names_to_updates.values() 102 | 103 | 104 | def build_model(): 105 | """Build the mobilenet_v1 model for evaluation. 106 | 107 | Returns: 108 | g: graph with rewrites after insertion of quantization ops and batch norm 109 | folding. 110 | eval_ops: eval ops for inference. 111 |
112 | """ 113 | g = tf.Graph() 114 | with g.as_default(): 115 | inputs, labels = imagenet_input(is_training=False) 116 | 117 | scope = mobilenet_v1.mobilenet_v1_arg_scope( 118 | is_training=False, weight_decay=0.0) 119 | with slim.arg_scope(scope): 120 | logits, _ = mobilenet_v1.mobilenet_v1( 121 | inputs, 122 | is_training=False, 123 | depth_multiplier=FLAGS.depth_multiplier, 124 | num_classes=FLAGS.num_classes) 125 | 126 | if FLAGS.quantize: 127 | tf.contrib.quantize.create_eval_graph() 128 | 129 | eval_ops = metrics(logits, labels) 130 | 131 | return g, eval_ops 132 | 133 | 134 | def eval_model(): 135 | """Evaluates mobilenet_v1.""" 136 | g, eval_ops = build_model() 137 | with g.as_default(): 138 | num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)) 139 | slim.evaluation.evaluate_once( 140 | FLAGS.master, 141 | FLAGS.checkpoint_dir, 142 | logdir=FLAGS.eval_dir, 143 | num_evals=num_batches, 144 | eval_op=eval_ops) 145 | 146 | 147 | def main(unused_arg): 148 | eval_model() 149 | 150 | 151 | if __name__ == '__main__': 152 | tf.app.run(main) 153 | -------------------------------------------------------------------------------- /eval_image_classifier_original.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import os 7 | 8 | from time import time 9 | from datetime import datetime 10 | from nets import nets_factory 11 | from datasets import dataset_factory 12 | from preprocessing import preprocessing_factory 13 | from tensorflow.contrib import slim 14 | import math 15 | import sys 16 | 17 | FLAGS = tf.app.flags.FLAGS 18 | 19 | batch_size = 100 20 | max_num_batches = FLAGS.num_validation // batch_size 21 | dataset_split_name = "train" 22 | master = "" 23 | quantize = False 24 | 25 | def print_train_acc(): 26 | if not FLAGS.dataset_dir: 27 | raise ValueError('You must supply the dataset directory with --dataset_dir') 28 | 29 | print("Begin evaluate train_accuracy %s" % format(datetime.now().isoformat())) 30 | t1 = time() 31 | 32 | tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO) 33 | with tf.Graph().as_default(): 34 | tf_global_step = slim.get_or_create_global_step() 35 | 36 | ###################### 37 | # Select the dataset # 38 | ###################### 39 | dataset = dataset_factory.get_dataset( 40 | FLAGS.dataset_name, dataset_split_name, FLAGS.dataset_dir) 41 | 42 | #################### 43 | # Select the model # 44 | #################### 45 | network_fn = nets_factory.get_network_fn( 46 | FLAGS.model_name, 47 | num_classes=(dataset.num_classes - FLAGS.labels_offset), 48 | is_training=False) 49 | 50 | ############################################################## 51 | # Create a dataset provider that loads data from the dataset # 52 | ############################################################## 53 | provider = slim.dataset_data_provider.DatasetDataProvider( 54 | dataset, 55 | shuffle=False, 56 | common_queue_capacity=2 * batch_size, 57 | common_queue_min=batch_size) 58 | [image, label] = provider.get(['image', 'label']) 59 | label -= FLAGS.labels_offset 60 | 61 | ##################################### 62 | # Select the preprocessing function # 63 | ##################################### 64 | preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name 65 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 66 | preprocessing_name, 67 | 
is_training=False) 68 | 69 | eval_image_size = FLAGS.train_image_size 70 | 71 | image = image_preprocessing_fn(image, eval_image_size, eval_image_size) 72 | 73 | images, labels = tf.train.batch( 74 | [image, label], 75 | batch_size=batch_size, 76 | num_threads=FLAGS.num_preprocessing_threads, 77 | capacity=5 * batch_size) 78 | 79 | #################### 80 | # Define the model # 81 | #################### 82 | logits, _ = network_fn(images) 83 | 84 | if quantize: 85 | tf.contrib.quantize.create_eval_graph() 86 | 87 | if FLAGS.moving_average_decay: 88 | variable_averages = tf.train.ExponentialMovingAverage( 89 | FLAGS.moving_average_decay, tf_global_step) 90 | variables_to_restore = variable_averages.variables_to_restore( 91 | slim.get_model_variables()) 92 | variables_to_restore[tf_global_step.op.name] = tf_global_step 93 | else: 94 | variables_to_restore = slim.get_variables_to_restore() 95 | 96 | predictions = tf.argmax(logits, 1) 97 | labels = tf.squeeze(labels) 98 | 99 | # Define the metrics: 100 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({ 101 | 'Train_Accuracy': slim.metrics.streaming_accuracy(predictions, labels), 102 | 'Train_Recall_5': slim.metrics.streaming_recall_at_k( 103 | logits, labels, 5), 104 | }) 105 | 106 | # Print the summaries to screen. 107 | for name, value in names_to_values.items(): 108 | summary_name = 'eval/%s' % name 109 | op = tf.summary.scalar(summary_name, value, collections=[]) 110 | op = tf.Print(op, [value], summary_name) 111 | tf.add_to_collection(tf.GraphKeys.SUMMARIES, op) 112 | 113 | if max_num_batches: 114 | num_batches = max_num_batches 115 | else: 116 | # This ensures that we make a single pass over all of the data. 117 | num_batches = math.ceil(dataset.num_samples / float(batch_size)) 118 | 119 | if tf.gfile.IsDirectory(FLAGS.train_dir): 120 | checkpoint_path = tf.train.latest_checkpoint(FLAGS.train_dir) 121 | else: 122 | checkpoint_path = FLAGS.train_dir 123 | 124 | tf.logging.info('Evaluating on train_dir: %s' % FLAGS.train_dir) 125 | 126 | dirlist = os.listdir(FLAGS.train_dir) 127 | file_numbers = [] 128 | for file_name in dirlist: 129 | if file_name.startswith("model.ckpt-") and file_name.endswith(".index"): 130 | idx = file_name.replace("model.ckpt-", "").replace(".index", "") 131 | file_numbers.append(int(idx)) 132 | file_numbers.sort() 133 | for file_number in file_numbers[-1:]: 134 | file_name = "model.ckpt-%d" % file_number 135 | checkpoint_path = os.path.join(FLAGS.train_dir, file_name) 136 | tf.logging.info('Evaluating %s' % checkpoint_path) 137 | slim.evaluation.evaluate_once( 138 | master=master, 139 | checkpoint_path=checkpoint_path, 140 | logdir=FLAGS.train_dir, 141 | num_evals=num_batches, 142 | eval_op=list(names_to_updates.values()), 143 | variables_to_restore=variables_to_restore) 144 | 145 | t2 = time() 146 | print("End train_accuracy %d s" % (t2 - t1)) 147 | sys.stdout.flush() 148 | 149 | -------------------------------------------------------------------------------- /nets/overfeat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the model definition for the OverFeat network. 16 | 17 | The definition for the network was obtained from: 18 | OverFeat: Integrated Recognition, Localization and Detection using 19 | Convolutional Networks 20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 21 | Yann LeCun, 2014 22 | http://arxiv.org/abs/1312.6229 23 | 24 | Usage: 25 | with slim.arg_scope(overfeat.overfeat_arg_scope()): 26 | outputs, end_points = overfeat.overfeat(inputs) 27 | 28 | @@overfeat 29 | """ 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | 36 | slim = tf.contrib.slim 37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 38 | 39 | 40 | def overfeat_arg_scope(weight_decay=0.0005): 41 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 42 | activation_fn=tf.nn.relu, 43 | weights_regularizer=slim.l2_regularizer(weight_decay), 44 | biases_initializer=tf.zeros_initializer()): 45 | with slim.arg_scope([slim.conv2d], padding='SAME'): 46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 47 | return arg_sc 48 | 49 | 50 | def overfeat(inputs, 51 | num_classes=1000, 52 | is_training=True, 53 | dropout_keep_prob=0.5, 54 | spatial_squeeze=True, 55 | scope='overfeat', 56 | global_pool=False): 57 | """Contains the model definition for the OverFeat network. 58 | 59 | The definition for the network was obtained from: 60 | OverFeat: Integrated Recognition, Localization and Detection using 61 | Convolutional Networks 62 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 63 | Yann LeCun, 2014 64 | http://arxiv.org/abs/1312.6229 65 | 66 | Note: All the fully_connected layers have been transformed to conv2d layers. 67 | To use in classification mode, resize input to 231x231. To use in fully 68 | convolutional mode, set spatial_squeeze to false. 69 | 70 | Args: 71 | inputs: a tensor of size [batch_size, height, width, channels]. 72 | num_classes: number of predicted classes. If 0 or None, the logits layer is 73 | omitted and the input features to the logits layer are returned instead. 74 | is_training: whether or not the model is being trained. 75 | dropout_keep_prob: the probability that activations are kept in the dropout 76 | layers during training. 77 | spatial_squeeze: whether or not the spatial dimensions of the outputs should 78 | be squeezed. Useful for removing unnecessary dimensions in classification mode. 79 | scope: Optional scope for the variables. 80 | global_pool: Optional boolean flag. If True, the input to the classification 81 | layer is avgpooled to size 1x1, for any input size. (This is not part 82 | of the original OverFeat.) 83 | 84 | Returns: 85 | net: the output of the logits layer (if num_classes is a non-zero integer), 86 | or the non-dropped-out input to the logits layer (if num_classes is 0 or 87 | None).
88 | end_points: a dict of tensors with intermediate activations. 89 | """ 90 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc: 91 | end_points_collection = sc.original_name_scope + '_end_points' 92 | # Collect outputs for conv2d, fully_connected and max_pool2d 93 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 94 | outputs_collections=end_points_collection): 95 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 96 | scope='conv1') 97 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 98 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2') 99 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 100 | net = slim.conv2d(net, 512, [3, 3], scope='conv3') 101 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4') 102 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5') 103 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 104 | 105 | # Use conv2d instead of fully_connected layers. 106 | with slim.arg_scope([slim.conv2d], 107 | weights_initializer=trunc_normal(0.005), 108 | biases_initializer=tf.constant_initializer(0.1)): 109 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6') 110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 111 | scope='dropout6') 112 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 113 | # Convert end_points_collection into an end_point dict. 114 | end_points = slim.utils.convert_collection_to_dict( 115 | end_points_collection) 116 | if global_pool: 117 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 118 | end_points['global_pool'] = net 119 | if num_classes: 120 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 121 | scope='dropout7') 122 | net = slim.conv2d(net, num_classes, [1, 1], 123 | activation_fn=None, 124 | normalizer_fn=None, 125 | biases_initializer=tf.zeros_initializer(), 126 | scope='fc8') 127 | if spatial_squeeze: 128 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 129 | end_points[sc.name + '/fc8'] = net 130 | return net, end_points 131 | overfeat.default_image_size = 231 132 | -------------------------------------------------------------------------------- /model1_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import os 7 | 8 | def init_params(): 9 | 10 | FLAGS = tf.app.flags.FLAGS 11 | _checkpoint_path = os.path.join(FLAGS.script_root_dir, "pre_train", "resnet_v1_101.ckpt") 12 | 13 | ##################### 14 | # Basic Flags # 15 | ##################### 16 | tf.app.flags.DEFINE_string('checkpoint_path', _checkpoint_path, 'The path to a checkpoint from which to fine-tune.') 17 | tf.app.flags.DEFINE_integer('max_number_of_steps', 3500, 'The maximum number of training steps.') 18 | # 3680: 20 epochs, 4800: 26 epochs 19 | tf.app.flags.DEFINE_integer('log_every_n_steps', 30, 'The frequency with which logs are printed.') 20 | tf.app.flags.DEFINE_integer('save_summaries_secs', 600, 'The frequency with which summaries are saved, in seconds.') 21 | tf.app.flags.DEFINE_integer('save_interval_secs', 68, 'The frequency with which the model is saved, in seconds.') #68 22 | tf.app.flags.DEFINE_integer('max_to_keep', 12, 'Max number of checkpoints to keep.') 23 | tf.app.flags.DEFINE_integer('num_readers', 4, 'The number of parallel readers that read data from the dataset.') 24 |
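    # A rule of thumb rather than a tuned value: 4 readers plus 4 preprocessing threads (below) are usually enough to keep one GPU/CPU clone fed; raise both if the input pipeline becomes the bottleneck.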
24 | tf.app.flags.DEFINE_integer('num_preprocessing_threads', 4, 'The number of threads used to create the batches.') 25 | tf.app.flags.DEFINE_integer('num_clones', 1, 'Number of model clones to deploy.') 26 | tf.app.flags.DEFINE_boolean('clone_on_cpu', True, 'Use CPUs to deploy clones.') 27 | 28 | tf.app.flags.DEFINE_string('dataset_name', 'garbage', 'The name of the dataset to load.') 29 | tf.app.flags.DEFINE_string('dataset_split_name', 'train', 'The name of the train/test split.') 30 | tf.app.flags.DEFINE_integer('batch_size', 32, 'The number of samples in each batch.') 31 | tf.app.flags.DEFINE_integer('train_image_size', 224, 'Train image size.') 32 | tf.app.flags.DEFINE_string('model_name', 'resnet_v1_101', 'The name of the architecture to train.') 33 | tf.app.flags.DEFINE_string('preprocessing_name', None, 'The name of the preprocessing to use. If None, then the model_name flag is used.') 34 | 35 | tf.app.flags.DEFINE_integer('labels_offset', 0, 'An offset for the labels in the dataset.') 36 | 37 | ##################### 38 | # Fine-Tuning Flags # 39 | ##################### 40 | tf.app.flags.DEFINE_string('checkpoint_exclude_scopes', 'resnet_v1_101/logits', 'Comma-separated list of scopes of variables to exclude when restoring from a checkpoint.') 41 | tf.app.flags.DEFINE_string('trainable_scopes', 'resnet_v1_101/logits', 'Comma-separated list of scopes of variables to train.') 42 | tf.app.flags.DEFINE_boolean('ignore_missing_vars', False, 'Whether to ignore missing variables when restoring a checkpoint.') 43 | 44 | ############################### 45 | # Distribution and Clone Flags # 46 | ############################### 47 | tf.app.flags.DEFINE_string('master', '', 'The address of the TensorFlow master to use.') 48 | tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.') 49 | tf.app.flags.DEFINE_integer('num_ps_tasks', 0, 'The number of parameter servers.') 50 | tf.app.flags.DEFINE_integer('task', 0, 'Task id of the replica running the training.') 51 | tf.app.flags.DEFINE_bool('sync_replicas', False, 'Whether or not to synchronize the replicas during training.') 52 | tf.app.flags.DEFINE_integer('replicas_to_aggregate', 1, 'The number of gradients to collect before updating params.') 53 | 54 | ###################### 55 | # Optimization Flags # 56 | ###################### 57 | tf.app.flags.DEFINE_string('optimizer', 'adam', '"adadelta", "adagrad", "adam", "ftrl", "momentum", "sgd" or "rmsprop".') 58 | tf.app.flags.DEFINE_float('weight_decay', 0.00004, 'The weight decay on the model weights.') #0.00004 59 | tf.app.flags.DEFINE_float('adadelta_rho', 0.95, 'The decay rate for adadelta.') 60 | tf.app.flags.DEFINE_float('adagrad_initial_accumulator_value', 0.1, 'Starting value for the AdaGrad accumulators.') 61 | tf.app.flags.DEFINE_float('adam_beta1', 0.9, 'The exponential decay rate for the 1st moment estimates.') 62 | tf.app.flags.DEFINE_float('adam_beta2', 0.999, 'The exponential decay rate for the 2nd moment estimates.') 63 | tf.app.flags.DEFINE_float('opt_epsilon', 1e-8, 'Epsilon term for the optimizer.') #1.0 64 | tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5, 'The learning rate power.') 65 | tf.app.flags.DEFINE_float('ftrl_initial_accumulator_value', 0.1, 'Starting value for the FTRL accumulators.') 66 | tf.app.flags.DEFINE_float('ftrl_l1', 0.0, 'The FTRL l1 regularization strength.') 67 | tf.app.flags.DEFINE_float('ftrl_l2', 0.0, 'The FTRL l2 regularization strength.') 68 | tf.app.flags.DEFINE_float('momentum', 0.9, 'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
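    # Note (assumption about downstream wiring, presumably in model.py): only the
    # hyper-parameter group matching the selected optimizer ('adam' here) is consumed
    # when the optimizer is built; the adadelta/adagrad/ftrl/momentum/rmsprop groups
    # stay inert unless the 'optimizer' flag is changed.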
69 | tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.') 70 | tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.') 71 | tf.app.flags.DEFINE_integer('quantize_delay', -1, 'Number of steps after which to start quantized training. Setting it to -1 disables quantized training.') 72 | 73 | ####################### 74 | # Learning Rate Flags # 75 | ####################### 76 | tf.app.flags.DEFINE_string( 77 | 'learning_rate_decay_type', 78 | 'exponential', 79 | 'Specifies how the learning rate is decayed. One of "fixed", "exponential",' 80 | ' or "polynomial".') 81 | 82 | tf.app.flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.') #0.01, 0.001 83 | tf.app.flags.DEFINE_float('end_learning_rate', 0.000001, 'The minimal end learning rate used by a polynomial decay learning rate.') #0.0001 84 | tf.app.flags.DEFINE_float('label_smoothing', 0.0, 'The amount of label smoothing.') 85 | tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.3, 'Learning rate decay factor.') #0.94, 0.3 86 | 87 | tf.app.flags.DEFINE_float( 88 | 'num_steps_per_decay', 1000.0, 89 | 'Number of training steps after which the learning rate decays ' 90 | '(used by the "exponential" and "polynomial" decay types). ' 91 | 'Unlike the upstream num_epochs_per_decay flag, this counts ' 92 | 'steps, not epochs.') 93 | 94 | tf.app.flags.DEFINE_float('moving_average_decay', None, 'The decay to use for the moving average. ' 95 | 'If left as None, then moving averages are not used.') #None, 0.99 96 | -------------------------------------------------------------------------------- /eval_image_classifier.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import os 7 | 8 | from time import time 9 | from datetime import datetime 10 | from nets import nets_factory 11 | from datasets import dataset_utils 12 | from preprocessing import preprocessing_factory 13 | from tensorflow.contrib import slim 14 | #from tensorflow.python.training import checkpoint_management 15 | import sys 16 | 17 | FLAGS = tf.app.flags.FLAGS 18 | 19 | batch_size = FLAGS.num_validation 20 | max_num_batches = 1 21 | dataset_split_name = "validation" 22 | 23 | def find_max_accuracy_checkpoint(): 24 | 25 | if FLAGS.num_validation == 0: 26 | print("FLAGS.num_validation is 0, no need to validate") 27 | return None 28 | 29 | print("Begin evaluate val_accuracy %s" % datetime.now().isoformat()) 30 | t1 = time() 31 | 32 | preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name 33 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 34 | preprocessing_name, is_training=False) 35 | 36 | labels_to_names = dataset_utils.read_label_file(FLAGS.dataset_dir) 37 | num_classes = len(labels_to_names) 38 | 39 | def decode(serialized_example): 40 | feature = { 41 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 42 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'), 43 | 'image/class/label': tf.FixedLenFeature( 44 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 45 | } 46 | features = tf.parse_single_example(serialized_example, features=feature) 47 | # image 48 | image_string = features['image/encoded'] 49 | image = tf.image.decode_jpeg(image_string, channels=3) 50 | image = image_preprocessing_fn(image, FLAGS.train_image_size, FLAGS.train_image_size) 51 | # label 52 | label =
features['image/class/label'] 53 | label = tf.one_hot(label, num_classes) 54 | return image, label 55 | 56 | def input_iter(filenames, batch_size, num_epochs): 57 | if not num_epochs: 58 | num_epochs = 1 59 | dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=FLAGS.num_readers) 60 | dataset = dataset.map(decode) 61 | dataset = dataset.repeat(num_epochs) 62 | dataset = dataset.batch(batch_size) 63 | # dataset = dataset.shuffle(buffer_size=NUM_IMAGES) 64 | iterator = dataset.make_one_shot_iterator() 65 | return iterator 66 | 67 | with tf.Graph().as_default() as graph: 68 | tf_global_step = slim.get_or_create_global_step() 69 | network_fn = nets_factory.get_network_fn( 70 | FLAGS.model_name, 71 | num_classes=(num_classes - FLAGS.labels_offset), 72 | weight_decay=FLAGS.weight_decay, 73 | is_training=False) 74 | 75 | eval_image_size = FLAGS.train_image_size 76 | 77 | x = tf.placeholder(tf.float32, [None, eval_image_size, eval_image_size, 3]) 78 | y_ = tf.placeholder(tf.float32, [None, num_classes]) 79 | 80 | logits, endpoints = network_fn(x) 81 | 82 | predictions_key = "Predictions" 83 | if FLAGS.model_name.startswith("resnet"): 84 | predictions_key = "predictions" 85 | t_prediction = endpoints[predictions_key] 86 | 87 | if FLAGS.moving_average_decay: 88 | variable_averages = tf.train.ExponentialMovingAverage( 89 | FLAGS.moving_average_decay, tf_global_step) 90 | variables_to_restore = variable_averages.variables_to_restore( 91 | slim.get_model_variables()) 92 | variables_to_restore[tf_global_step.op.name] = tf_global_step 93 | else: 94 | variables_to_restore = slim.get_variables_to_restore() 95 | 96 | predictions = tf.argmax(t_prediction, 1, name="prediction") 97 | test_labels = tf.argmax(y_, 1, name="label") 98 | correct_prediction = tf.equal(predictions, test_labels) 99 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy") 100 | 101 | input_dir = [] 102 | for i in range(5) : 103 | data_file = os.path.join(FLAGS.dataset_dir, "garbage_validation_0000%d-of-00005.tfrecord")%i 104 | input_dir.append(data_file) 105 | 106 | iter = input_iter(input_dir, batch_size, 1) 107 | next_batch = iter.get_next() 108 | 109 | saver = tf.train.Saver(var_list=variables_to_restore) #Same as slim.get_variables() 110 | init1 = tf.global_variables_initializer() 111 | init2 = tf.local_variables_initializer() 112 | with tf.Session() as sess: 113 | sess.run(init1) 114 | sess.run(init2) 115 | images, labels = sess.run(next_batch) 116 | 117 | dirlist = os.listdir(FLAGS.train_dir) 118 | file_numbers = [] 119 | for file_name in dirlist: 120 | if file_name.startswith("model.ckpt-") & file_name.endswith(".index"): 121 | idx = file_name.replace("model.ckpt-","").replace(".index","") 122 | file_numbers.append(int(idx)) 123 | file_numbers.sort() 124 | maxAccuracy = 0.0 125 | maxAccuracyCheckPoint = "" 126 | for file_number in file_numbers: 127 | if file_number<=0: 128 | continue 129 | file_name = "model.ckpt-%d"%file_number 130 | checkpoint_path = os.path.join(FLAGS.train_dir, file_name) 131 | print('Evaluate val_accuracy on %s' % checkpoint_path) 132 | saver.restore(sess, checkpoint_path) 133 | train_accuracy = sess.run(fetches=accuracy, feed_dict={x: images, y_: labels}) 134 | print("Val_accuracy: {0}".format(train_accuracy)) 135 | if train_accuracy >= (maxAccuracy + 0.0000): 136 | maxAccuracy = train_accuracy 137 | maxAccuracyCheckPoint = checkpoint_path 138 | print("Max val_accuracy: %f"%maxAccuracy) 139 | print("maxAccuracyCheckPoint: %s"%maxAccuracyCheckPoint) 140 | sess.close() 
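        # Note: the ">= (maxAccuracy + 0.0000)" test above means ties are resolved in
        # favor of the later (higher-step) checkpoint, which simply overwrites
        # maxAccuracyCheckPoint when accuracies are equal.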
141 | t2 = time() 142 | print("End val_accuracy %d s" %(t2 - t1)) 143 | sys.stdout.flush() 144 | return maxAccuracyCheckPoint 145 | 146 | 147 | -------------------------------------------------------------------------------- /nets/pix2pix_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | """Tests for pix2pix.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import pix2pix 23 | 24 | 25 | class GeneratorTest(tf.test.TestCase): 26 | 27 | def _reduced_default_blocks(self): 28 | """Returns the default blocks, scaled down to make test run faster.""" 29 | return [pix2pix.Block(b.num_filters // 32, b.decoder_keep_prob) 30 | for b in pix2pix._default_generator_blocks()] 31 | 32 | def test_output_size_nn_upsample_conv(self): 33 | batch_size = 2 34 | height, width = 256, 256 35 | num_outputs = 4 36 | 37 | images = tf.ones((batch_size, height, width, 3)) 38 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 39 | logits, _ = pix2pix.pix2pix_generator( 40 | images, num_outputs, blocks=self._reduced_default_blocks(), 41 | upsample_method='nn_upsample_conv') 42 | 43 | with self.test_session() as session: 44 | session.run(tf.global_variables_initializer()) 45 | np_outputs = session.run(logits) 46 | self.assertListEqual([batch_size, height, width, num_outputs], 47 | list(np_outputs.shape)) 48 | 49 | def test_output_size_conv2d_transpose(self): 50 | batch_size = 2 51 | height, width = 256, 256 52 | num_outputs = 4 53 | 54 | images = tf.ones((batch_size, height, width, 3)) 55 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 56 | logits, _ = pix2pix.pix2pix_generator( 57 | images, num_outputs, blocks=self._reduced_default_blocks(), 58 | upsample_method='conv2d_transpose') 59 | 60 | with self.test_session() as session: 61 | session.run(tf.global_variables_initializer()) 62 | np_outputs = session.run(logits) 63 | self.assertListEqual([batch_size, height, width, num_outputs], 64 | list(np_outputs.shape)) 65 | 66 | def test_block_number_dictates_number_of_layers(self): 67 | batch_size = 2 68 | height, width = 256, 256 69 | num_outputs = 4 70 | 71 | images = tf.ones((batch_size, height, width, 3)) 72 | blocks = [ 73 | pix2pix.Block(64, 0.5), 74 | pix2pix.Block(128, 0), 75 | ] 76 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 77 | _, end_points = pix2pix.pix2pix_generator( 78 | images, num_outputs, blocks) 79 | 80 | num_encoder_layers = 0 81 | num_decoder_layers = 0 82 | for end_point in end_points: 83 | if end_point.startswith('encoder'): 84 | num_encoder_layers += 1 85 | elif end_point.startswith('decoder'): 86 | num_decoder_layers += 1 87 | 88 | 
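    # Each pix2pix Block contributes exactly one encoder layer and one mirrored
    # decoder layer, so both counters should equal len(blocks).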
self.assertEqual(num_encoder_layers, len(blocks)) 89 | self.assertEqual(num_decoder_layers, len(blocks)) 90 | 91 | 92 | class DiscriminatorTest(tf.test.TestCase): 93 | 94 | def _layer_output_size(self, input_size, kernel_size=4, stride=2, pad=2): 95 | return (input_size + pad * 2 - kernel_size) // stride + 1 96 | 97 | def test_four_layers(self): 98 | batch_size = 2 99 | input_size = 256 100 | 101 | output_size = self._layer_output_size(input_size) 102 | output_size = self._layer_output_size(output_size) 103 | output_size = self._layer_output_size(output_size) 104 | output_size = self._layer_output_size(output_size, stride=1) 105 | output_size = self._layer_output_size(output_size, stride=1) 106 | 107 | images = tf.ones((batch_size, input_size, input_size, 3)) 108 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 109 | logits, end_points = pix2pix.pix2pix_discriminator( 110 | images, num_filters=[64, 128, 256, 512]) 111 | self.assertListEqual([batch_size, output_size, output_size, 1], 112 | logits.shape.as_list()) 113 | self.assertListEqual([batch_size, output_size, output_size, 1], 114 | end_points['predictions'].shape.as_list()) 115 | 116 | def test_four_layers_no_padding(self): 117 | batch_size = 2 118 | input_size = 256 119 | 120 | output_size = self._layer_output_size(input_size, pad=0) 121 | output_size = self._layer_output_size(output_size, pad=0) 122 | output_size = self._layer_output_size(output_size, pad=0) 123 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 124 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 125 | 126 | images = tf.ones((batch_size, input_size, input_size, 3)) 127 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 128 | logits, end_points = pix2pix.pix2pix_discriminator( 129 | images, num_filters=[64, 128, 256, 512], padding=0) 130 | self.assertListEqual([batch_size, output_size, output_size, 1], 131 | logits.shape.as_list()) 132 | self.assertListEqual([batch_size, output_size, output_size, 1], 133 | end_points['predictions'].shape.as_list()) 134 | 135 | def test_four_layers_wrong_padding(self): 136 | batch_size = 2 137 | input_size = 256 138 | 139 | images = tf.ones((batch_size, input_size, input_size, 3)) 140 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 141 | with self.assertRaises(TypeError): 142 | pix2pix.pix2pix_discriminator( 143 | images, num_filters=[64, 128, 256, 512], padding=1.5) 144 | 145 | def test_four_layers_negative_padding(self): 146 | batch_size = 2 147 | input_size = 256 148 | 149 | images = tf.ones((batch_size, input_size, input_size, 3)) 150 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 151 | with self.assertRaises(ValueError): 152 | pix2pix.pix2pix_discriminator( 153 | images, num_filters=[64, 128, 256, 512], padding=-1) 154 | 155 | if __name__ == '__main__': 156 | tf.test.main() 157 | -------------------------------------------------------------------------------- /nets/alexnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a model definition for AlexNet. 16 | 17 | This work was first described in: 18 | ImageNet Classification with Deep Convolutional Neural Networks 19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton 20 | 21 | and later refined in: 22 | One weird trick for parallelizing convolutional neural networks 23 | Alex Krizhevsky, 2014 24 | 25 | Here we provide the implementation proposed in "One weird trick" and not 26 | "ImageNet Classification", as per the paper, the LRN layers have been removed. 27 | 28 | Usage: 29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): 30 | outputs, end_points = alexnet.alexnet_v2(inputs) 31 | 32 | @@alexnet_v2 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import tensorflow as tf 40 | 41 | slim = tf.contrib.slim 42 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 43 | 44 | 45 | def alexnet_v2_arg_scope(weight_decay=0.0005): 46 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 47 | activation_fn=tf.nn.relu, 48 | biases_initializer=tf.constant_initializer(0.1), 49 | weights_regularizer=slim.l2_regularizer(weight_decay)): 50 | with slim.arg_scope([slim.conv2d], padding='SAME'): 51 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 52 | return arg_sc 53 | 54 | 55 | def alexnet_v2(inputs, 56 | num_classes=1000, 57 | is_training=True, 58 | dropout_keep_prob=0.5, 59 | spatial_squeeze=True, 60 | scope='alexnet_v2', 61 | global_pool=False): 62 | """AlexNet version 2. 63 | 64 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf 65 | Parameters from: 66 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/ 67 | layers-imagenet-1gpu.cfg 68 | 69 | Note: All the fully_connected layers have been transformed to conv2d layers. 70 | To use in classification mode, resize input to 224x224 or set 71 | global_pool=True. To use in fully convolutional mode, set 72 | spatial_squeeze to false. 73 | The LRN layers have been removed and change the initializers from 74 | random_normal_initializer to xavier_initializer. 75 | 76 | Args: 77 | inputs: a tensor of size [batch_size, height, width, channels]. 78 | num_classes: the number of predicted classes. If 0 or None, the logits layer 79 | is omitted and the input features to the logits layer are returned instead. 80 | is_training: whether or not the model is being trained. 81 | dropout_keep_prob: the probability that activations are kept in the dropout 82 | layers during training. 83 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 84 | logits. Useful to remove unnecessary dimensions for classification. 85 | scope: Optional scope for the variables. 86 | global_pool: Optional boolean flag. If True, the input to the classification 87 | layer is avgpooled to size 1x1, for any input size. (This is not part 88 | of the original AlexNet.) 
89 | 90 | Returns: 91 | net: the output of the logits layer (if num_classes is a non-zero integer), 92 | or the non-dropped-out input to the logits layer (if num_classes is 0 93 | or None). 94 | end_points: a dict of tensors with intermediate activations. 95 | """ 96 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc: 97 | end_points_collection = sc.original_name_scope + '_end_points' 98 | # Collect outputs for conv2d, fully_connected and max_pool2d. 99 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 100 | outputs_collections=[end_points_collection]): 101 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 102 | scope='conv1') 103 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 104 | net = slim.conv2d(net, 192, [5, 5], scope='conv2') 105 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 106 | net = slim.conv2d(net, 384, [3, 3], scope='conv3') 107 | net = slim.conv2d(net, 384, [3, 3], scope='conv4') 108 | net = slim.conv2d(net, 256, [3, 3], scope='conv5') 109 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 110 | 111 | # Use conv2d instead of fully_connected layers. 112 | with slim.arg_scope([slim.conv2d], 113 | weights_initializer=trunc_normal(0.005), 114 | biases_initializer=tf.constant_initializer(0.1)): 115 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 116 | scope='fc6') 117 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 118 | scope='dropout6') 119 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 120 | # Convert end_points_collection into a end_point dict. 121 | end_points = slim.utils.convert_collection_to_dict( 122 | end_points_collection) 123 | if global_pool: 124 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 125 | end_points['global_pool'] = net 126 | if num_classes: 127 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 128 | scope='dropout7') 129 | net = slim.conv2d(net, num_classes, [1, 1], 130 | activation_fn=None, 131 | normalizer_fn=None, 132 | biases_initializer=tf.zeros_initializer(), 133 | scope='fc8') 134 | if spatial_squeeze: 135 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 136 | end_points[sc.name + '/fc8'] = net 137 | return net, end_points 138 | alexnet_v2.default_image_size = 224 139 | -------------------------------------------------------------------------------- /nets/i3d_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for networks.i3d.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import i3d 24 | 25 | 26 | class I3DTest(tf.test.TestCase): 27 | 28 | def testBuildClassificationNetwork(self): 29 | batch_size = 5 30 | num_frames = 64 31 | height, width = 224, 224 32 | num_classes = 1000 33 | 34 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 35 | logits, end_points = i3d.i3d(inputs, num_classes) 36 | self.assertTrue(logits.op.name.startswith('InceptionV1/Logits')) 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | self.assertTrue('Predictions' in end_points) 40 | self.assertListEqual(end_points['Predictions'].get_shape().as_list(), 41 | [batch_size, num_classes]) 42 | 43 | def testBuildBaseNetwork(self): 44 | batch_size = 5 45 | num_frames = 64 46 | height, width = 224, 224 47 | 48 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 49 | mixed_6c, end_points = i3d.i3d_base(inputs) 50 | self.assertTrue(mixed_6c.op.name.startswith('InceptionV1/Mixed_5c')) 51 | self.assertListEqual(mixed_6c.get_shape().as_list(), 52 | [batch_size, 8, 7, 7, 1024]) 53 | expected_endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 54 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 55 | 'Mixed_3c', 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 56 | 'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 57 | 'Mixed_5b', 'Mixed_5c'] 58 | self.assertItemsEqual(end_points.keys(), expected_endpoints) 59 | 60 | def testBuildOnlyUptoFinalEndpoint(self): 61 | batch_size = 5 62 | num_frames = 64 63 | height, width = 224, 224 64 | endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 65 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 66 | 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 67 | 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 68 | 'Mixed_5c'] 69 | for index, endpoint in enumerate(endpoints): 70 | with tf.Graph().as_default(): 71 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 72 | out_tensor, end_points = i3d.i3d_base( 73 | inputs, final_endpoint=endpoint) 74 | self.assertTrue(out_tensor.op.name.startswith( 75 | 'InceptionV1/' + endpoint)) 76 | self.assertItemsEqual(endpoints[:index+1], end_points) 77 | 78 | def testBuildAndCheckAllEndPointsUptoMixed5c(self): 79 | batch_size = 5 80 | num_frames = 64 81 | height, width = 224, 224 82 | 83 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 84 | _, end_points = i3d.i3d_base(inputs, 85 | final_endpoint='Mixed_5c') 86 | endpoints_shapes = {'Conv2d_1a_7x7': [5, 32, 112, 112, 64], 87 | 'MaxPool_2a_3x3': [5, 32, 56, 56, 64], 88 | 'Conv2d_2b_1x1': [5, 32, 56, 56, 64], 89 | 'Conv2d_2c_3x3': [5, 32, 56, 56, 192], 90 | 'MaxPool_3a_3x3': [5, 32, 28, 28, 192], 91 | 'Mixed_3b': [5, 32, 28, 28, 256], 92 | 'Mixed_3c': [5, 32, 28, 28, 480], 93 | 'MaxPool_4a_3x3': [5, 16, 14, 14, 480], 94 | 'Mixed_4b': [5, 16, 14, 14, 512], 95 | 'Mixed_4c': [5, 16, 14, 14, 512], 96 | 'Mixed_4d': [5, 16, 14, 14, 512], 97 | 'Mixed_4e': [5, 16, 14, 14, 528], 98 | 'Mixed_4f': [5, 16, 14, 14, 832], 99 | 'MaxPool_5a_2x2': [5, 8, 7, 7, 832], 100 | 'Mixed_5b': [5, 8, 7, 7, 832], 101 | 'Mixed_5c': [5, 8, 7, 7, 1024]} 102 | 103 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) 104 | for 
endpoint_name, expected_shape in endpoints_shapes.items(): 105 | self.assertTrue(endpoint_name in end_points) 106 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), 107 | expected_shape) 108 | 109 | def testHalfSizeImages(self): 110 | batch_size = 5 111 | num_frames = 64 112 | height, width = 112, 112 113 | 114 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 115 | mixed_5c, _ = i3d.i3d_base(inputs) 116 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c')) 117 | self.assertListEqual(mixed_5c.get_shape().as_list(), 118 | [batch_size, 8, 4, 4, 1024]) 119 | 120 | def testTenFrames(self): 121 | batch_size = 5 122 | num_frames = 10 123 | height, width = 224, 224 124 | 125 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 126 | mixed_5c, _ = i3d.i3d_base(inputs) 127 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c')) 128 | self.assertListEqual(mixed_5c.get_shape().as_list(), 129 | [batch_size, 2, 7, 7, 1024]) 130 | 131 | def testEvaluation(self): 132 | batch_size = 2 133 | num_frames = 64 134 | height, width = 224, 224 135 | num_classes = 1000 136 | 137 | eval_inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 138 | logits, _ = i3d.i3d(eval_inputs, num_classes, 139 | is_training=False) 140 | predictions = tf.argmax(logits, 1) 141 | 142 | with self.test_session() as sess: 143 | sess.run(tf.global_variables_initializer()) 144 | output = sess.run(predictions) 145 | self.assertEqual(output.shape, (batch_size,)) 146 | 147 | 148 | if __name__ == '__main__': 149 | tf.test.main() 150 | -------------------------------------------------------------------------------- /validation_confusion_matrix.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | import os 7 | 8 | from time import time 9 | from datetime import datetime 10 | from nets import nets_factory 11 | from datasets import dataset_utils 12 | from preprocessing import preprocessing_factory 13 | from tensorflow.contrib import slim 14 | import numpy as np 15 | import sys 16 | 17 | FLAGS = tf.app.flags.FLAGS 18 | 19 | batch_size = FLAGS.num_validation 20 | max_num_batches = 1 21 | dataset_split_name = "validation" 22 | 23 | def execute(checkpoint_path, model_no): 24 | if FLAGS.num_validation == 0: 25 | print("FLAGS.num_validation is 0, no need to validate") 26 | return None 27 | 28 | if checkpoint_path is None: 29 | checkpoint_path = tf.train.latest_checkpoint(FLAGS.train_dir) 30 | 31 | print("Begin validation_confusion_matrix %s" % datetime.now().isoformat()) 32 | t1 = time() 33 | 34 | preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name 35 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 36 | preprocessing_name, is_training=False) 37 | 38 | labels_to_names = dataset_utils.read_label_file(FLAGS.dataset_dir) 39 | num_classes = len(labels_to_names) 40 | 41 | def decode(serialized_example): 42 | feature = { 43 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 44 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'), 45 | 'image/class/label': tf.FixedLenFeature( 46 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 47 | } 48 | features = tf.parse_single_example(serialized_example, features=feature) 49 | # image 50 | image_string =
features['image/encoded'] 51 | image = tf.image.decode_jpeg(image_string, channels=3) 52 | image = image_preprocessing_fn(image, FLAGS.train_image_size, FLAGS.train_image_size) 53 | # label 54 | label = features['image/class/label'] 55 | label = tf.one_hot(label, num_classes) 56 | return image, label 57 | 58 | def input_iter(filenames, batch_size, num_epochs): 59 | if not num_epochs: 60 | num_epochs = 1 61 | dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=FLAGS.num_readers) 62 | dataset = dataset.map(decode) 63 | dataset = dataset.repeat(num_epochs) 64 | dataset = dataset.batch(batch_size) 65 | # dataset = dataset.shuffle(buffer_size=NUM_IMAGES) 66 | iterator = dataset.make_one_shot_iterator() 67 | return iterator 68 | 69 | with tf.Graph().as_default() as graph: 70 | tf_global_step = slim.get_or_create_global_step() 71 | 72 | network_fn = nets_factory.get_network_fn( 73 | FLAGS.model_name, 74 | num_classes=(num_classes - FLAGS.labels_offset), 75 | weight_decay=FLAGS.weight_decay, 76 | is_training=False) 77 | 78 | eval_image_size = FLAGS.train_image_size 79 | 80 | x = tf.placeholder(tf.float32, [None, eval_image_size, eval_image_size, 3]) 81 | y_ = tf.placeholder(tf.float32, [None, num_classes]) 82 | 83 | logits, endpoints = network_fn(x) 84 | 85 | predictions_key = "Predictions" 86 | if FLAGS.model_name.startswith("resnet"): 87 | predictions_key = "predictions" 88 | y = endpoints[predictions_key] 89 | test_labels = tf.argmax(y_, 1, name="label") 90 | 91 | if FLAGS.moving_average_decay: 92 | variable_averages = tf.train.ExponentialMovingAverage( 93 | FLAGS.moving_average_decay, tf_global_step) 94 | variables_to_restore = variable_averages.variables_to_restore( 95 | slim.get_model_variables()) 96 | variables_to_restore[tf_global_step.op.name] = tf_global_step 97 | else: 98 | variables_to_restore = slim.get_variables_to_restore() 99 | 100 | input_dir = [] 101 | for i in range(5) : 102 | data_file = os.path.join(FLAGS.dataset_dir, "garbage_validation_0000%d-of-00005.tfrecord")%i 103 | input_dir.append(data_file) 104 | iter = input_iter(input_dir, batch_size, 1) 105 | next_batch = iter.get_next() 106 | 107 | saver = tf.train.Saver(var_list=variables_to_restore) #Same as slim.get_variables() 108 | init1 = tf.global_variables_initializer() 109 | init2 = tf.local_variables_initializer() 110 | with tf.Session() as sess: 111 | sess.run(init1) 112 | sess.run(init2) 113 | saver.restore(sess, checkpoint_path) 114 | images, labels = sess.run(next_batch) 115 | 116 | predictions = sess.run(fetches=y, feed_dict={x: images}) 117 | predictions = np.squeeze(predictions) 118 | ids = sess.run(test_labels, feed_dict={y_: labels}) 119 | errorList = [] 120 | v_records = [] 121 | for i in range(batch_size): 122 | prediction = predictions[i] 123 | top_k = prediction.argsort()[-5:][::-1] 124 | if ids[i] != top_k[0]: 125 | errorList.append(str(ids[i]) + ":" + str(top_k[0])) 126 | v_record = str(ids[i]) + " " + labels_to_names[ids[i]] + " => " 127 | #print(ids[i], labels_to_names[ids[i]], "=> ", end='') 128 | for id in top_k: 129 | human_string = labels_to_names[id] 130 | score = prediction[id] 131 | v_record = v_record + str(id) + ":" + human_string + "(P=" + str(score) + "), " 132 | #print('%d:%s(P=%.5f), ' % (id, human_string, score), end='') 133 | print(v_record) 134 | v_records.append(v_record) 135 | print(errorList) 136 | errorid_filename = os.path.join(FLAGS.inference_dir, model_no + "_error.csv") 137 | print("Write file: %s ..."%errorid_filename) 138 | with tf.gfile.Open(errorid_filename, 'w') 
as f: 139 | for idmap in errorList: 140 | f.write('%s\n' % (idmap)) 141 | validation_record_filename = os.path.join(FLAGS.inference_dir, model_no + "_validation_record.txt") 142 | print("Write file: %s ..." % validation_record_filename) 143 | with tf.gfile.Open(validation_record_filename, 'w') as f: 144 | for v_rec in v_records: 145 | f.write('%s\n' % (v_rec)) 146 | sess.close() 147 | t2 = time() 148 | print("End validation_confusion_matrix %d s" %(t2 - t1)) 149 | sys.stdout.flush() 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains utilities for downloading and converting datasets.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | import sys 22 | import tarfile 23 | 24 | from six.moves import urllib 25 | import tensorflow as tf 26 | 27 | LABELS_FILENAME = 'labels.txt' 28 | 29 | 30 | def int64_feature(values): 31 | """Returns a TF-Feature of int64s. 32 | 33 | Args: 34 | values: A scalar or list of values. 35 | 36 | Returns: 37 | A TF-Feature. 38 | """ 39 | if not isinstance(values, (tuple, list)): 40 | values = [values] 41 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 42 | 43 | 44 | def bytes_list_feature(values): 45 | """Returns a TF-Feature of list of bytes. 46 | 47 | Args: 48 | values: A string or list of strings. 49 | 50 | Returns: 51 | A TF-Feature. 52 | """ 53 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) 54 | 55 | 56 | def float_list_feature(values): 57 | """Returns a TF-Feature of list of floats. 58 | 59 | Args: 60 | values: A float or list of floats. 61 | 62 | Returns: 63 | A TF-Feature. 64 | """ 65 | return tf.train.Feature(float_list=tf.train.FloatList(value=values)) 66 | 67 | 68 | def bytes_feature(values): 69 | """Returns a TF-Feature of bytes. 70 | 71 | Args: 72 | values: A string. 73 | 74 | Returns: 75 | A TF-Feature. 76 | """ 77 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) 78 | 79 | 80 | def float_feature(values): 81 | """Returns a TF-Feature of floats. 82 | 83 | Args: 84 | values: A scalar of list of values. 85 | 86 | Returns: 87 | A TF-Feature. 
88 | """ 89 | if not isinstance(values, (tuple, list)): 90 | values = [values] 91 | return tf.train.Feature(float_list=tf.train.FloatList(value=values)) 92 | 93 | 94 | def image_to_tfexample(image_data, image_format, height, width, class_id): 95 | return tf.train.Example(features=tf.train.Features(feature={ 96 | 'image/encoded': bytes_feature(image_data), 97 | 'image/format': bytes_feature(image_format), 98 | 'image/class/label': int64_feature(class_id), 99 | 'image/height': int64_feature(height), 100 | 'image/width': int64_feature(width), 101 | })) 102 | 103 | 104 | def download_and_uncompress_tarball(tarball_url, dataset_dir): 105 | """Downloads the `tarball_url` and uncompresses it locally. 106 | 107 | Args: 108 | tarball_url: The URL of a tarball file. 109 | dataset_dir: The directory where the temporary files are stored. 110 | """ 111 | filename = tarball_url.split('/')[-1] 112 | filepath = os.path.join(dataset_dir, filename) 113 | 114 | def _progress(count, block_size, total_size): 115 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 116 | filename, float(count * block_size) / float(total_size) * 100.0)) 117 | sys.stdout.flush() 118 | filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress) 119 | print() 120 | statinfo = os.stat(filepath) 121 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 122 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir) 123 | 124 | 125 | def write_label_file(labels_to_class_names, dataset_dir, 126 | filename=LABELS_FILENAME): 127 | """Writes a file with the list of class names. 128 | 129 | Args: 130 | labels_to_class_names: A map of (integer) labels to class names. 131 | dataset_dir: The directory in which the labels file should be written. 132 | filename: The filename where the class names are written. 133 | """ 134 | labels_filename = os.path.join(dataset_dir, filename) 135 | with tf.gfile.Open(labels_filename, 'w') as f: 136 | for label in labels_to_class_names: 137 | class_name = labels_to_class_names[label] 138 | f.write('%d:%s\n' % (label, class_name)) 139 | 140 | 141 | def has_labels(dataset_dir, filename=LABELS_FILENAME): 142 | """Specifies whether or not the dataset directory contains a label map file. 143 | 144 | Args: 145 | dataset_dir: The directory in which the labels file is found. 146 | filename: The filename where the class names are written. 147 | 148 | Returns: 149 | `True` if the labels file exists and `False` otherwise. 150 | """ 151 | return tf.gfile.Exists(os.path.join(dataset_dir, filename)) 152 | 153 | 154 | def read_label_file(dataset_dir, filename=LABELS_FILENAME): 155 | """Reads the labels file and returns a mapping from ID to class name. 156 | 157 | Args: 158 | dataset_dir: The directory in which the labels file is found. 159 | filename: The filename where the class names are written. 160 | 161 | Returns: 162 | A map from a label (integer) to class name. 163 | """ 164 | labels_filename = os.path.join(dataset_dir, filename) 165 | with tf.gfile.Open(labels_filename, 'rb') as f: 166 | lines = f.read() 167 | lines = lines.split('\n') 168 | lines = filter(None, lines) 169 | 170 | labels_to_class_names = {} 171 | for line in lines: 172 | index = line.index(':') 173 | labels_to_class_names[int(line[:index])] = line[index+1:] 174 | return labels_to_class_names 175 | 176 | 177 | def open_sharded_output_tfrecords(exit_stack, base_path, num_shards): 178 | """Opens all TFRecord shards for writing and adds them to an exit stack. 
179 | 180 | Args: 181 | exit_stack: A context2.ExitStack used to automatically closed the TFRecords 182 | opened in this function. 183 | base_path: The base path for all shards 184 | num_shards: The number of shards 185 | 186 | Returns: 187 | The list of opened TFRecords. Position k in the list corresponds to shard k. 188 | """ 189 | tf_record_output_filenames = [ 190 | '{}-{:05d}-of-{:05d}'.format(base_path, idx, num_shards) 191 | for idx in range(num_shards) 192 | ] 193 | 194 | tfrecords = [ 195 | exit_stack.enter_context(tf.python_io.TFRecordWriter(file_name)) 196 | for file_name in tf_record_output_filenames 197 | ] 198 | 199 | return tfrecords 200 | -------------------------------------------------------------------------------- /nets/s3dg_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for networks.s3dg.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import s3dg 24 | 25 | 26 | class S3DGTest(tf.test.TestCase): 27 | 28 | def testBuildClassificationNetwork(self): 29 | batch_size = 5 30 | num_frames = 64 31 | height, width = 224, 224 32 | num_classes = 1000 33 | 34 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 35 | logits, end_points = s3dg.s3dg(inputs, num_classes) 36 | self.assertTrue(logits.op.name.startswith('InceptionV1/Logits')) 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | self.assertTrue('Predictions' in end_points) 40 | self.assertListEqual(end_points['Predictions'].get_shape().as_list(), 41 | [batch_size, num_classes]) 42 | 43 | def testBuildBaseNetwork(self): 44 | batch_size = 5 45 | num_frames = 64 46 | height, width = 224, 224 47 | 48 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 49 | mixed_6c, end_points = s3dg.s3dg_base(inputs) 50 | self.assertTrue(mixed_6c.op.name.startswith('InceptionV1/Mixed_5c')) 51 | self.assertListEqual(mixed_6c.get_shape().as_list(), 52 | [batch_size, 8, 7, 7, 1024]) 53 | expected_endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 54 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 55 | 'Mixed_3c', 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 56 | 'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 57 | 'Mixed_5b', 'Mixed_5c'] 58 | self.assertItemsEqual(end_points.keys(), expected_endpoints) 59 | 60 | def testBuildOnlyUptoFinalEndpointNoGating(self): 61 | batch_size = 5 62 | num_frames = 64 63 | height, width = 224, 224 64 | endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 65 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 66 | 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 67 | 
'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 68 | 'Mixed_5c'] 69 | for index, endpoint in enumerate(endpoints): 70 | with tf.Graph().as_default(): 71 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 72 | out_tensor, end_points = s3dg.s3dg_base( 73 | inputs, final_endpoint=endpoint, gating_startat=None) 74 | print(endpoint, out_tensor.op.name) 75 | self.assertTrue(out_tensor.op.name.startswith( 76 | 'InceptionV1/' + endpoint)) 77 | self.assertItemsEqual(endpoints[:index+1], end_points) 78 | 79 | def testBuildAndCheckAllEndPointsUptoMixed5c(self): 80 | batch_size = 5 81 | num_frames = 64 82 | height, width = 224, 224 83 | 84 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 85 | _, end_points = s3dg.s3dg_base(inputs, 86 | final_endpoint='Mixed_5c') 87 | endpoints_shapes = {'Conv2d_1a_7x7': [5, 32, 112, 112, 64], 88 | 'MaxPool_2a_3x3': [5, 32, 56, 56, 64], 89 | 'Conv2d_2b_1x1': [5, 32, 56, 56, 64], 90 | 'Conv2d_2c_3x3': [5, 32, 56, 56, 192], 91 | 'MaxPool_3a_3x3': [5, 32, 28, 28, 192], 92 | 'Mixed_3b': [5, 32, 28, 28, 256], 93 | 'Mixed_3c': [5, 32, 28, 28, 480], 94 | 'MaxPool_4a_3x3': [5, 16, 14, 14, 480], 95 | 'Mixed_4b': [5, 16, 14, 14, 512], 96 | 'Mixed_4c': [5, 16, 14, 14, 512], 97 | 'Mixed_4d': [5, 16, 14, 14, 512], 98 | 'Mixed_4e': [5, 16, 14, 14, 528], 99 | 'Mixed_4f': [5, 16, 14, 14, 832], 100 | 'MaxPool_5a_2x2': [5, 8, 7, 7, 832], 101 | 'Mixed_5b': [5, 8, 7, 7, 832], 102 | 'Mixed_5c': [5, 8, 7, 7, 1024]} 103 | 104 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) 105 | for endpoint_name, expected_shape in endpoints_shapes.items(): 106 | self.assertTrue(endpoint_name in end_points) 107 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), 108 | expected_shape) 109 | 110 | def testHalfSizeImages(self): 111 | batch_size = 5 112 | num_frames = 64 113 | height, width = 112, 112 114 | 115 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 116 | mixed_5c, _ = s3dg.s3dg_base(inputs) 117 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c')) 118 | self.assertListEqual(mixed_5c.get_shape().as_list(), 119 | [batch_size, 8, 4, 4, 1024]) 120 | 121 | def testTenFrames(self): 122 | batch_size = 5 123 | num_frames = 10 124 | height, width = 224, 224 125 | 126 | inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 127 | mixed_5c, _ = s3dg.s3dg_base(inputs) 128 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c')) 129 | self.assertListEqual(mixed_5c.get_shape().as_list(), 130 | [batch_size, 2, 7, 7, 1024]) 131 | 132 | def testEvaluation(self): 133 | batch_size = 2 134 | num_frames = 64 135 | height, width = 224, 224 136 | num_classes = 1000 137 | 138 | eval_inputs = tf.random_uniform((batch_size, num_frames, height, width, 3)) 139 | logits, _ = s3dg.s3dg(eval_inputs, num_classes, 140 | is_training=False) 141 | predictions = tf.argmax(logits, 1) 142 | 143 | with self.test_session() as sess: 144 | sess.run(tf.global_variables_initializer()) 145 | output = sess.run(predictions) 146 | self.assertEqual(output.shape, (batch_size,)) 147 | 148 | 149 | if __name__ == '__main__': 150 | tf.test.main() 151 | -------------------------------------------------------------------------------- /datasets/build_visualwakewords_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Build Visual WakeWords Dataset with images and labels for person/not-person. 16 | 17 | This script generates the Visual WakeWords dataset annotations from 18 | the raw COCO dataset and converts them to TFRecord. 19 | Visual WakeWords Dataset derives from the COCO dataset to design tiny models 20 | classifying two classes, such as person/not-person. The COCO annotations 21 | are filtered to two classes: foreground_class_of_interest and background 22 | ( for e.g. person and not-person). Bounding boxes for small objects 23 | with area less than 5% of the image area are filtered out. 24 | 25 | The resulting annotations file has the following fields, where 26 | the image and categories fields are same as COCO dataset, while the annotation 27 | field corresponds to the foreground_class_of_interest/background class and 28 | bounding boxes for the foreground_class_of_interest class. 29 | 30 | images{"id", "width", "height", "file_name", "license", "flickr_url", 31 | "coco_url", "date_captured",} 32 | 33 | annotations{ 34 | "image_id", object[{"category_id", "area", "bbox" : [x,y,width,height],}] 35 | "count", 36 | "label" 37 | } 38 | 39 | categories[{ 40 | "id", "name", "supercategory", 41 | }] 42 | 43 | 44 | The TFRecord file contains the following features: 45 | { image/height, image/width, image/source_id, image/encoded, 46 | image/class/label_text, image/class/label, 47 | image/object/class/text, 48 | image/object/bbox/ymin, image/object/bbox/xmin, image/object/bbox/ymax, 49 | image/object/bbox/xmax, image/object/area 50 | image/filename, image/format, image/key/sha256} 51 | For classification models, you need the image/encoded and image/class/label. 52 | Please note that this tool creates sharded output files. 
53 | 54 | Example usage: 55 | Add folder tensorflow/models/research/slim to your PYTHONPATH, 56 | and from this folder, run the following commands: 57 | 58 | bash download_mscoco.sh path-to-mscoco-dataset 59 | TRAIN_IMAGE_DIR="path-to-mscoco-dataset/train2014" 60 | VAL_IMAGE_DIR="path-to-mscoco-dataset/val2014" 61 | 62 | TRAIN_ANNOTATIONS_FILE="path-to-mscoco-dataset/annotations/instances_train2014.json" 63 | VAL_ANNOTATIONS_FILE="path-to-mscoco-dataset/annotations/instances_val2014.json" 64 | 65 | python datasets/build_visualwakewords_data.py --logtostderr \ 66 | --train_image_dir="${TRAIN_IMAGE_DIR}" \ 67 | --val_image_dir="${VAL_IMAGE_DIR}" \ 68 | --train_annotations_file="${TRAIN_ANNOTATIONS_FILE}" \ 69 | --val_annotations_file="${VAL_ANNOTATIONS_FILE}" \ 70 | --output_dir="${OUTPUT_DIR}" \ 71 | --small_object_area_threshold=0.005 \ 72 | --foreground_class_of_interest='person' 73 | """ 74 | 75 | from __future__ import absolute_import 76 | from __future__ import division 77 | from __future__ import print_function 78 | 79 | import os 80 | import tensorflow as tf 81 | from datasets import build_visualwakewords_data_lib 82 | 83 | flags = tf.app.flags 84 | tf.flags.DEFINE_string('train_image_dir', '', 'Training image directory.') 85 | tf.flags.DEFINE_string('val_image_dir', '', 'Validation image directory.') 86 | tf.flags.DEFINE_string('train_annotations_file', '', 87 | 'Training annotations JSON file.') 88 | tf.flags.DEFINE_string('val_annotations_file', '', 89 | 'Validation annotations JSON file.') 90 | tf.flags.DEFINE_string('output_dir', '/tmp/', 'Output data directory.') 91 | tf.flags.DEFINE_float( 92 | 'small_object_area_threshold', 0.005, 93 | 'Threshold of fraction of image area below which small' 94 | 'objects are filtered') 95 | tf.flags.DEFINE_string( 96 | 'foreground_class_of_interest', 'person', 97 | 'Build a binary classifier based on the presence or absence' 98 | 'of this object in the scene (default is person/not-person)') 99 | 100 | FLAGS = flags.FLAGS 101 | 102 | tf.logging.set_verbosity(tf.logging.INFO) 103 | 104 | 105 | def main(unused_argv): 106 | # Path to COCO dataset images and annotations 107 | assert FLAGS.train_image_dir, '`train_image_dir` missing.' 108 | assert FLAGS.val_image_dir, '`val_image_dir` missing.' 109 | assert FLAGS.train_annotations_file, '`train_annotations_file` missing.' 110 | assert FLAGS.val_annotations_file, '`val_annotations_file` missing.' 
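  # Derived outputs: the filtered annotation JSONs and labels.txt are written under
  # FLAGS.output_dir, alongside the sharded train/val TFRecords created further below.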
111 | visualwakewords_annotations_train = os.path.join( 112 | FLAGS.output_dir, 'instances_visualwakewords_train2014.json') 113 | visualwakewords_annotations_val = os.path.join( 114 | FLAGS.output_dir, 'instances_visualwakewords_val2014.json') 115 | visualwakewords_labels_filename = os.path.join(FLAGS.output_dir, 116 | 'labels.txt') 117 | small_object_area_threshold = FLAGS.small_object_area_threshold 118 | foreground_class_of_interest = FLAGS.foreground_class_of_interest 119 | # Create the Visual WakeWords annotations from COCO annotations 120 | if not tf.gfile.IsDirectory(FLAGS.output_dir): 121 | tf.gfile.MakeDirs(FLAGS.output_dir) 122 | build_visualwakewords_data_lib.create_visual_wakeword_annotations( 123 | FLAGS.train_annotations_file, visualwakewords_annotations_train, 124 | small_object_area_threshold, foreground_class_of_interest, 125 | visualwakewords_labels_filename) 126 | build_visualwakewords_data_lib.create_visual_wakeword_annotations( 127 | FLAGS.val_annotations_file, visualwakewords_annotations_val, 128 | small_object_area_threshold, foreground_class_of_interest, 129 | visualwakewords_labels_filename) 130 | 131 | # Create the TF Records for Visual WakeWords Dataset 132 | if not tf.gfile.IsDirectory(FLAGS.output_dir): 133 | tf.gfile.MakeDirs(FLAGS.output_dir) 134 | train_output_path = os.path.join(FLAGS.output_dir, 'train.record') 135 | val_output_path = os.path.join(FLAGS.output_dir, 'val.record') 136 | build_visualwakewords_data_lib.create_tf_record_for_visualwakewords_dataset( 137 | visualwakewords_annotations_train, 138 | FLAGS.train_image_dir, 139 | train_output_path, 140 | num_shards=100) 141 | build_visualwakewords_data_lib.create_tf_record_for_visualwakewords_dataset( 142 | visualwakewords_annotations_val, 143 | FLAGS.val_image_dir, 144 | val_output_path, 145 | num_shards=10) 146 | 147 | 148 | if __name__ == '__main__': 149 | tf.app.run() 150 | -------------------------------------------------------------------------------- /nets/i3d.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the definition for Inflated 3D Inception V1 (I3D). 16 | 17 | The network architecture is proposed by: 18 | Joao Carreira and Andrew Zisserman, 19 | Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset. 
20 | https://arxiv.org/abs/1705.07750 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | from nets import i3d_utils 30 | from nets import s3dg 31 | 32 | slim = tf.contrib.slim 33 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 34 | conv3d_spatiotemporal = i3d_utils.conv3d_spatiotemporal 35 | 36 | 37 | def i3d_arg_scope(weight_decay=1e-7, 38 | batch_norm_decay=0.999, 39 | batch_norm_epsilon=0.001, 40 | use_renorm=False, 41 | separable_conv3d=False): 42 | """Defines default arg_scope for I3D. 43 | 44 | Args: 45 | weight_decay: The weight decay to use for regularizing the model. 46 | batch_norm_decay: Decay for batch norm moving average. 47 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 48 | in batch norm. 49 | use_renorm: Whether to use batch renormalization or not. 50 | separable_conv3d: Whether to use separable 3d Convs. 51 | 52 | Returns: 53 | sc: An arg_scope to use for the models. 54 | """ 55 | batch_norm_params = { 56 | # Decay for the moving averages. 57 | 'decay': batch_norm_decay, 58 | # epsilon to prevent 0s in variance. 59 | 'epsilon': batch_norm_epsilon, 60 | # Turns off fused batch norm. 61 | 'fused': False, 62 | 'renorm': use_renorm, 63 | # collection containing the moving mean and moving variance. 64 | 'variables_collections': { 65 | 'beta': None, 66 | 'gamma': None, 67 | 'moving_mean': ['moving_vars'], 68 | 'moving_variance': ['moving_vars'], 69 | } 70 | } 71 | 72 | with slim.arg_scope( 73 | [slim.conv3d, conv3d_spatiotemporal], 74 | weights_regularizer=slim.l2_regularizer(weight_decay), 75 | activation_fn=tf.nn.relu, 76 | normalizer_fn=slim.batch_norm, 77 | normalizer_params=batch_norm_params): 78 | with slim.arg_scope( 79 | [conv3d_spatiotemporal], separable=separable_conv3d) as sc: 80 | return sc 81 | 82 | 83 | def i3d_base(inputs, final_endpoint='Mixed_5c', 84 | scope='InceptionV1'): 85 | """Defines the I3D base architecture. 86 | 87 | Note that we use the names as defined in Inception V1 to facilitate checkpoint 88 | conversion from an image-trained Inception V1 checkpoint to I3D checkpoint. 89 | 90 | Args: 91 | inputs: A 5-D float tensor of size [batch_size, num_frames, height, width, 92 | channels]. 93 | final_endpoint: Specifies the endpoint to construct the network up to. It 94 | can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 95 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 96 | 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 97 | 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c'] 98 | scope: Optional variable_scope. 99 | 100 | Returns: 101 | A dictionary from components of the network to the corresponding activation. 102 | 103 | Raises: 104 | ValueError: if final_endpoint is not set to one of the predefined values. 105 | """ 106 | 107 | return s3dg.s3dg_base( 108 | inputs, 109 | first_temporal_kernel_size=7, 110 | temporal_conv_startat='Conv2d_2c_3x3', 111 | gating_startat=None, 112 | final_endpoint=final_endpoint, 113 | min_depth=16, 114 | depth_multiplier=1.0, 115 | data_format='NDHWC', 116 | scope=scope) 117 | 118 | 119 | def i3d(inputs, 120 | num_classes=1000, 121 | dropout_keep_prob=0.8, 122 | is_training=True, 123 | prediction_fn=slim.softmax, 124 | spatial_squeeze=True, 125 | reuse=None, 126 | scope='InceptionV1'): 127 | """Defines the I3D architecture. 
128 | 129 | The default image size used to train this network is 224x224. 130 | 131 | Args: 132 | inputs: A 5-D float tensor of size [batch_size, num_frames, height, width, 133 | channels]. 134 | num_classes: number of predicted classes. 135 | dropout_keep_prob: the percentage of activation values that are retained. 136 | is_training: whether is training or not. 137 | prediction_fn: a function to get predictions out of logits. 138 | spatial_squeeze: if True, logits is of shape is [B, C], if false logits is 139 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 140 | reuse: whether or not the network and its variables should be reused. To be 141 | able to reuse 'scope' must be given. 142 | scope: Optional variable_scope. 143 | 144 | Returns: 145 | logits: the pre-softmax activations, a tensor of size 146 | [batch_size, num_classes] 147 | end_points: a dictionary from components of the network to the corresponding 148 | activation. 149 | """ 150 | # Final pooling and prediction 151 | with tf.variable_scope( 152 | scope, 'InceptionV1', [inputs, num_classes], reuse=reuse) as scope: 153 | with slim.arg_scope( 154 | [slim.batch_norm, slim.dropout], is_training=is_training): 155 | net, end_points = i3d_base(inputs, scope=scope) 156 | with tf.variable_scope('Logits'): 157 | kernel_size = i3d_utils.reduced_kernel_size_3d(net, [2, 7, 7]) 158 | net = slim.avg_pool3d( 159 | net, kernel_size, stride=1, scope='AvgPool_0a_7x7') 160 | net = slim.dropout(net, dropout_keep_prob, scope='Dropout_0b') 161 | logits = slim.conv3d( 162 | net, 163 | num_classes, [1, 1, 1], 164 | activation_fn=None, 165 | normalizer_fn=None, 166 | scope='Conv2d_0c_1x1') 167 | # Temporal average pooling. 168 | logits = tf.reduce_mean(logits, axis=1) 169 | if spatial_squeeze: 170 | logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze') 171 | 172 | end_points['Logits'] = logits 173 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 174 | return logits, end_points 175 | 176 | 177 | i3d.default_image_size = 224 178 | -------------------------------------------------------------------------------- /datasets/download_and_convert_cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Downloads and converts cifar10 data to TFRecords of TF-Example protos. 16 | 17 | This module downloads the cifar10 data, uncompresses it, reads the files 18 | that make up the cifar10 data and creates two TFRecord datasets: one for train 19 | and one for test. Each TFRecord dataset is comprised of a set of TF-Example 20 | protocol buffers, each of which contain a single image and label. 21 | 22 | The script should take several minutes to run. 
23 | 24 | """ 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import os 30 | import sys 31 | import tarfile 32 | 33 | import numpy as np 34 | from six.moves import cPickle 35 | from six.moves import urllib 36 | import tensorflow as tf 37 | 38 | from datasets import dataset_utils 39 | 40 | # The URL where the CIFAR data can be downloaded. 41 | _DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 42 | 43 | # The number of training files. 44 | _NUM_TRAIN_FILES = 5 45 | 46 | # The height and width of each image. 47 | _IMAGE_SIZE = 32 48 | 49 | # The names of the classes. 50 | _CLASS_NAMES = [ 51 | 'airplane', 52 | 'automobile', 53 | 'bird', 54 | 'cat', 55 | 'deer', 56 | 'dog', 57 | 'frog', 58 | 'horse', 59 | 'ship', 60 | 'truck', 61 | ] 62 | 63 | 64 | def _add_to_tfrecord(filename, tfrecord_writer, offset=0): 65 | """Loads data from the cifar10 pickle files and writes files to a TFRecord. 66 | 67 | Args: 68 | filename: The filename of the cifar10 pickle file. 69 | tfrecord_writer: The TFRecord writer to use for writing. 70 | offset: An offset into the absolute number of images previously written. 71 | 72 | Returns: 73 | The new offset. 74 | """ 75 | with tf.gfile.Open(filename, 'rb') as f: 76 | if sys.version_info < (3,): 77 | data = cPickle.load(f) 78 | else: 79 | data = cPickle.load(f, encoding='bytes') 80 | 81 | images = data[b'data'] 82 | num_images = images.shape[0] 83 | 84 | images = images.reshape((num_images, 3, 32, 32)) 85 | labels = data[b'labels'] 86 | 87 | with tf.Graph().as_default(): 88 | image_placeholder = tf.placeholder(dtype=tf.uint8) 89 | encoded_image = tf.image.encode_png(image_placeholder) 90 | 91 | with tf.Session('') as sess: 92 | 93 | for j in range(num_images): 94 | sys.stdout.write('\r>> Reading file [%s] image %d/%d' % ( 95 | filename, offset + j + 1, offset + num_images)) 96 | sys.stdout.flush() 97 | 98 | image = np.squeeze(images[j]).transpose((1, 2, 0)) 99 | label = labels[j] 100 | 101 | png_string = sess.run(encoded_image, 102 | feed_dict={image_placeholder: image}) 103 | 104 | example = dataset_utils.image_to_tfexample( 105 | png_string, b'png', _IMAGE_SIZE, _IMAGE_SIZE, label) 106 | tfrecord_writer.write(example.SerializeToString()) 107 | 108 | return offset + num_images 109 | 110 | 111 | def _get_output_filename(dataset_dir, split_name): 112 | """Creates the output filename. 113 | 114 | Args: 115 | dataset_dir: The dataset directory where the dataset is stored. 116 | split_name: The name of the train/test split. 117 | 118 | Returns: 119 | An absolute file path. 120 | """ 121 | return '%s/cifar10_%s.tfrecord' % (dataset_dir, split_name) 122 | 123 | 124 | def _download_and_uncompress_dataset(dataset_dir): 125 | """Downloads cifar10 and uncompresses it locally. 126 | 127 | Args: 128 | dataset_dir: The directory where the temporary files are stored. 
129 | """ 130 | filename = _DATA_URL.split('/')[-1] 131 | filepath = os.path.join(dataset_dir, filename) 132 | 133 | if not os.path.exists(filepath): 134 | def _progress(count, block_size, total_size): 135 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 136 | filename, float(count * block_size) / float(total_size) * 100.0)) 137 | sys.stdout.flush() 138 | filepath, _ = urllib.request.urlretrieve(_DATA_URL, filepath, _progress) 139 | print() 140 | statinfo = os.stat(filepath) 141 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 142 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir) 143 | 144 | 145 | def _clean_up_temporary_files(dataset_dir): 146 | """Removes temporary files used to create the dataset. 147 | 148 | Args: 149 | dataset_dir: The directory where the temporary files are stored. 150 | """ 151 | filename = _DATA_URL.split('/')[-1] 152 | filepath = os.path.join(dataset_dir, filename) 153 | tf.gfile.Remove(filepath) 154 | 155 | tmp_dir = os.path.join(dataset_dir, 'cifar-10-batches-py') 156 | tf.gfile.DeleteRecursively(tmp_dir) 157 | 158 | 159 | def run(dataset_dir): 160 | """Runs the download and conversion operation. 161 | 162 | Args: 163 | dataset_dir: The dataset directory where the dataset is stored. 164 | """ 165 | if not tf.gfile.Exists(dataset_dir): 166 | tf.gfile.MakeDirs(dataset_dir) 167 | 168 | training_filename = _get_output_filename(dataset_dir, 'train') 169 | testing_filename = _get_output_filename(dataset_dir, 'test') 170 | 171 | if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename): 172 | print('Dataset files already exist. Exiting without re-creating them.') 173 | return 174 | 175 | dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 176 | 177 | # First, process the training data: 178 | with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer: 179 | offset = 0 180 | for i in range(_NUM_TRAIN_FILES): 181 | filename = os.path.join(dataset_dir, 182 | 'cifar-10-batches-py', 183 | 'data_batch_%d' % (i + 1)) # 1-indexed. 184 | offset = _add_to_tfrecord(filename, tfrecord_writer, offset) 185 | 186 | # Next, process the testing data: 187 | with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer: 188 | filename = os.path.join(dataset_dir, 189 | 'cifar-10-batches-py', 190 | 'test_batch') 191 | _add_to_tfrecord(filename, tfrecord_writer) 192 | 193 | # Finally, write the labels file: 194 | labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES)) 195 | dataset_utils.write_label_file(labels_to_class_names, dataset_dir) 196 | 197 | _clean_up_temporary_files(dataset_dir) 198 | print('\nFinished converting the Cifar10 dataset!') 199 | -------------------------------------------------------------------------------- /nets/overfeat_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.overfeat.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import overfeat 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class OverFeatTest(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 231, 231 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = overfeat.overfeat(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 281, 281 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 2, 2, num_classes]) 50 | 51 | def testGlobalPool(self): 52 | batch_size = 1 53 | height, width = 281, 281 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False, 58 | global_pool=True) 59 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 60 | self.assertListEqual(logits.get_shape().as_list(), 61 | [batch_size, 1, 1, num_classes]) 62 | 63 | def testEndPoints(self): 64 | batch_size = 5 65 | height, width = 231, 231 66 | num_classes = 1000 67 | with self.test_session(): 68 | inputs = tf.random_uniform((batch_size, height, width, 3)) 69 | _, end_points = overfeat.overfeat(inputs, num_classes) 70 | expected_names = ['overfeat/conv1', 71 | 'overfeat/pool1', 72 | 'overfeat/conv2', 73 | 'overfeat/pool2', 74 | 'overfeat/conv3', 75 | 'overfeat/conv4', 76 | 'overfeat/conv5', 77 | 'overfeat/pool5', 78 | 'overfeat/fc6', 79 | 'overfeat/fc7', 80 | 'overfeat/fc8' 81 | ] 82 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 83 | 84 | def testNoClasses(self): 85 | batch_size = 5 86 | height, width = 231, 231 87 | num_classes = None 88 | with self.test_session(): 89 | inputs = tf.random_uniform((batch_size, height, width, 3)) 90 | net, end_points = overfeat.overfeat(inputs, num_classes) 91 | expected_names = ['overfeat/conv1', 92 | 'overfeat/pool1', 93 | 'overfeat/conv2', 94 | 'overfeat/pool2', 95 | 'overfeat/conv3', 96 | 'overfeat/conv4', 97 | 'overfeat/conv5', 98 | 'overfeat/pool5', 99 | 'overfeat/fc6', 100 | 'overfeat/fc7' 101 | ] 102 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 103 | self.assertTrue(net.op.name.startswith('overfeat/fc7')) 104 | 105 | def testModelVariables(self): 106 | batch_size = 5 107 | height, width = 231, 231 108 | num_classes = 1000 109 | with self.test_session(): 110 | inputs = tf.random_uniform((batch_size, height, width, 3)) 111 | overfeat.overfeat(inputs, num_classes) 112 | expected_names = ['overfeat/conv1/weights', 113 | 'overfeat/conv1/biases', 114 | 'overfeat/conv2/weights', 115 | 'overfeat/conv2/biases', 116 | 'overfeat/conv3/weights', 117 | 'overfeat/conv3/biases', 118 | 'overfeat/conv4/weights', 119 | 'overfeat/conv4/biases', 120 | 'overfeat/conv5/weights', 121 | 
'overfeat/conv5/biases', 122 | 'overfeat/fc6/weights', 123 | 'overfeat/fc6/biases', 124 | 'overfeat/fc7/weights', 125 | 'overfeat/fc7/biases', 126 | 'overfeat/fc8/weights', 127 | 'overfeat/fc8/biases', 128 | ] 129 | model_variables = [v.op.name for v in slim.get_model_variables()] 130 | self.assertSetEqual(set(model_variables), set(expected_names)) 131 | 132 | def testEvaluation(self): 133 | batch_size = 2 134 | height, width = 231, 231 135 | num_classes = 1000 136 | with self.test_session(): 137 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 138 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False) 139 | self.assertListEqual(logits.get_shape().as_list(), 140 | [batch_size, num_classes]) 141 | predictions = tf.argmax(logits, 1) 142 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 143 | 144 | def testTrainEvalWithReuse(self): 145 | train_batch_size = 2 146 | eval_batch_size = 1 147 | train_height, train_width = 231, 231 148 | eval_height, eval_width = 281, 281 149 | num_classes = 1000 150 | with self.test_session(): 151 | train_inputs = tf.random_uniform( 152 | (train_batch_size, train_height, train_width, 3)) 153 | logits, _ = overfeat.overfeat(train_inputs) 154 | self.assertListEqual(logits.get_shape().as_list(), 155 | [train_batch_size, num_classes]) 156 | tf.get_variable_scope().reuse_variables() 157 | eval_inputs = tf.random_uniform( 158 | (eval_batch_size, eval_height, eval_width, 3)) 159 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False, 160 | spatial_squeeze=False) 161 | self.assertListEqual(logits.get_shape().as_list(), 162 | [eval_batch_size, 2, 2, num_classes]) 163 | logits = tf.reduce_mean(logits, [1, 2]) 164 | predictions = tf.argmax(logits, 1) 165 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 166 | 167 | def testForward(self): 168 | batch_size = 1 169 | height, width = 231, 231 170 | with self.test_session() as sess: 171 | inputs = tf.random_uniform((batch_size, height, width, 3)) 172 | logits, _ = overfeat.overfeat(inputs) 173 | sess.run(tf.global_variables_initializer()) 174 | output = sess.run(logits) 175 | self.assertTrue(output.any()) 176 | 177 | if __name__ == '__main__': 178 | tf.test.main() 179 | -------------------------------------------------------------------------------- /convert_garbage_data.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import math 7 | import os 8 | import random 9 | import sys 10 | import tensorflow as tf 11 | 12 | from datasets import dataset_utils 13 | 14 | FLAGS = tf.app.flags.FLAGS 15 | 16 | # The number of images in the validation set. 17 | _NUM_VALIDATION = FLAGS.num_validation 18 | 19 | # Seed for repeatability. 20 | #_RANDOM_SEED = 0 21 | 22 | # The number of shards per dataset split. 23 | _NUM_SHARDS = 5 24 | 25 | 26 | class ImageReader(object): 27 | """Helper class that provides TensorFlow image coding utilities.""" 28 | 29 | def __init__(self): 30 | # Initializes function that decodes RGB JPEG data. 
31 |     self._decode_jpeg_data = tf.compat.v1.placeholder(dtype=tf.string)
32 |     self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
33 | 
34 |     self._decode_png_data = tf.compat.v1.placeholder(dtype=tf.string)
35 |     self._decode_png = tf.image.decode_png(self._decode_png_data, channels=3)
36 | 
37 | 
38 |   def read_jpeg_dims(self, sess, image_data):
39 |     image = self.decode_jpeg(sess, image_data)
40 |     return image.shape[0], image.shape[1]
41 | 
42 |   def decode_jpeg(self, sess, image_data):
43 |     image = sess.run(self._decode_jpeg,
44 |                      feed_dict={self._decode_jpeg_data: image_data})
45 |     assert len(image.shape) == 3
46 |     assert image.shape[2] == 3
47 |     return image
48 | 
49 |   def read_png_dims(self, sess, image_data):
50 |     image = self.decode_png(sess, image_data)
51 |     return image.shape[0], image.shape[1]
52 | 
53 |   def decode_png(self, sess, image_data):
54 |     image = sess.run(self._decode_png,
55 |                      feed_dict={self._decode_png_data: image_data})
56 |     assert len(image.shape) == 3
57 |     assert image.shape[2] == 3
58 |     return image
59 | 
60 | def _get_filenames_and_classes(images_dir):
61 |   """Returns a list of image file paths and the sorted class (subdirectory) names."""
62 |   images_root = images_dir
63 |   directories = []
64 |   class_names = []
65 |   for filename in os.listdir(images_root):
66 |     path = os.path.join(images_root, filename)
67 |     if os.path.isdir(path):
68 |       directories.append(path)
69 |       class_names.append(filename)
70 | 
71 |   photo_filenames = []
72 | 
73 |   total = 0
74 |   for directory in directories:
75 |     i = 0
76 |     for filename in os.listdir(directory):
77 |       path = os.path.join(directory, filename)
78 |       if path.endswith(("jpeg", "jpg", "png")):
79 |         photo_filenames.append(path)
80 |       else:
81 |         continue  # skip non-image files without counting them
82 |       i += 1
83 |       total += 1
84 |       #if i >= 65:
85 |       #  break
86 |     #print(directory[directory.rindex("/")+1:], i)
87 | 
88 |   print("Total images: %d (written as %d shards per split)" % (total, _NUM_SHARDS))
89 | 
90 |   return photo_filenames, sorted(class_names)
91 | 
92 | 
93 | def _get_dataset_filename(dataset_dir, split_name, shard_id):
94 |   output_filename = 'garbage_%s_%05d-of-%05d.tfrecord' % (
95 |       split_name, shard_id, _NUM_SHARDS)
96 |   return os.path.join(dataset_dir, output_filename)
97 | 
98 | 
99 | def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
100 |   """Converts the given filenames to a TFRecord dataset.
101 | 
102 |   Args:
103 |     split_name: The name of the dataset, either 'train' or 'validation'.
104 |     filenames: A list of absolute paths to png or jpg images.
105 |     class_names_to_ids: A dictionary from class names (strings) to ids
106 |       (integers).
107 |     dataset_dir: The directory where the converted datasets are stored.
108 | """ 109 | assert split_name in ['train', 'validation'] 110 | 111 | num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS))) 112 | 113 | with tf.Graph().as_default(): 114 | image_reader = ImageReader() 115 | 116 | with tf.compat.v1.Session('') as sess: 117 | 118 | for shard_id in range(_NUM_SHARDS): 119 | output_filename = _get_dataset_filename( 120 | dataset_dir, split_name, shard_id) 121 | 122 | with tf.io.TFRecordWriter(output_filename) as tfrecord_writer: 123 | start_ndx = shard_id * num_per_shard 124 | end_ndx = min((shard_id+1) * num_per_shard, len(filenames)) 125 | print("start_ndx: %d"%start_ndx) 126 | print("end_ndx: %d"%end_ndx) 127 | for i in range(start_ndx, end_ndx): 128 | #sys.stdout.write('\r>> Converting image %d/%d shard %d' % ( 129 | # i+1, len(filenames), shard_id)) 130 | #sys.stdout.flush() 131 | 132 | if i % 100 == 0: 133 | print("Convert dataset %d, shard %d" % (i, shard_id)) 134 | 135 | # Read the filename: 136 | image_data = tf.io.gfile.GFile(filenames[i], 'rb').read() 137 | 138 | if filenames[i].endswith(".jpg") | filenames[i].endswith(".jpeg"): 139 | height, width = image_reader.read_jpeg_dims(sess, image_data) 140 | elif filenames[i].endswith(".png"): 141 | height, width = image_reader.read_png_dims(sess, image_data) 142 | 143 | class_name = os.path.basename(os.path.dirname(filenames[i])) 144 | class_id = class_names_to_ids[class_name] 145 | 146 | example = dataset_utils.image_to_tfexample( 147 | image_data, b'jpg', height, width, class_id) 148 | tfrecord_writer.write(example.SerializeToString()) 149 | 150 | sys.stdout.write('\n') 151 | sys.stdout.flush() 152 | 153 | def _dataset_exists(dataset_dir): 154 | for split_name in ['train', 'validation']: 155 | for shard_id in range(_NUM_SHARDS): 156 | output_filename = _get_dataset_filename( 157 | dataset_dir, split_name, shard_id) 158 | if not tf.gfile.Exists(output_filename): 159 | return False 160 | return True 161 | 162 | 163 | def run(images_dir, dataset_dir, inference_dir, model_random_seed): 164 | """ 165 | Args: 166 | images_dir: The images directory where the jpeg is read. 167 | dataset_dir: The dataset directory where the dataset is stored. 168 | """ 169 | if (dataset_dir.endswith("data")) != True: 170 | print("Wrong dataset_dir name! dataset_dir must end with data.") 171 | exit(1) 172 | 173 | if _dataset_exists(dataset_dir): 174 | print('Dataset files already exist. Delete them firstly.') 175 | tf.gfile.DeleteRecursively(dataset_dir) 176 | #return 177 | 178 | if not tf.gfile.Exists(dataset_dir): 179 | tf.gfile.MakeDirs(dataset_dir) 180 | 181 | photo_filenames, class_names = _get_filenames_and_classes(images_dir) 182 | 183 | class_names_to_ids = dict(zip(class_names, range(len(class_names)))) 184 | ii = 0 185 | for fname in photo_filenames : 186 | if(fname.endswith("png")): 187 | ii = ii + 1 188 | print("Found png files: %d" % ii) 189 | 190 | #for class_name in class_names: 191 | # print(class_name) 192 | 193 | num_classes = len(class_names_to_ids) 194 | 195 | # Divide into train and test: 196 | random.seed(model_random_seed) 197 | random.shuffle(photo_filenames) 198 | training_filenames = photo_filenames[_NUM_VALIDATION:] 199 | validation_filenames = photo_filenames[:_NUM_VALIDATION] 200 | 201 | # First, convert the training and validation sets. 
202 | _convert_dataset('train', training_filenames, class_names_to_ids, 203 | dataset_dir) 204 | _convert_dataset('validation', validation_filenames, class_names_to_ids, 205 | dataset_dir) 206 | 207 | # Finally, write the labels file: 208 | labels_to_class_names = dict(zip(range(len(class_names)), class_names)) 209 | dataset_utils.write_label_file(labels_to_class_names, dataset_dir) 210 | dataset_utils.write_label_file(labels_to_class_names, inference_dir) 211 | 212 | print('\nFinished converting the garbage dataset!') 213 | return num_classes 214 | -------------------------------------------------------------------------------- /nets/alexnet_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for slim.nets.alexnet.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import alexnet 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class AlexnetV2Test(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 224, 224 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = alexnet.alexnet_v2(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 300, 400 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 4, 7, num_classes]) 50 | 51 | def testGlobalPool(self): 52 | batch_size = 1 53 | height, width = 256, 256 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False, 58 | global_pool=True) 59 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 60 | self.assertListEqual(logits.get_shape().as_list(), 61 | [batch_size, 1, 1, num_classes]) 62 | 63 | def testEndPoints(self): 64 | batch_size = 5 65 | height, width = 224, 224 66 | num_classes = 1000 67 | with self.test_session(): 68 | inputs = tf.random_uniform((batch_size, height, width, 3)) 69 | _, end_points = alexnet.alexnet_v2(inputs, num_classes) 70 | expected_names = ['alexnet_v2/conv1', 71 | 'alexnet_v2/pool1', 72 | 'alexnet_v2/conv2', 73 | 'alexnet_v2/pool2', 74 | 'alexnet_v2/conv3', 75 | 
'alexnet_v2/conv4', 76 | 'alexnet_v2/conv5', 77 | 'alexnet_v2/pool5', 78 | 'alexnet_v2/fc6', 79 | 'alexnet_v2/fc7', 80 | 'alexnet_v2/fc8' 81 | ] 82 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 83 | 84 | def testNoClasses(self): 85 | batch_size = 5 86 | height, width = 224, 224 87 | num_classes = None 88 | with self.test_session(): 89 | inputs = tf.random_uniform((batch_size, height, width, 3)) 90 | net, end_points = alexnet.alexnet_v2(inputs, num_classes) 91 | expected_names = ['alexnet_v2/conv1', 92 | 'alexnet_v2/pool1', 93 | 'alexnet_v2/conv2', 94 | 'alexnet_v2/pool2', 95 | 'alexnet_v2/conv3', 96 | 'alexnet_v2/conv4', 97 | 'alexnet_v2/conv5', 98 | 'alexnet_v2/pool5', 99 | 'alexnet_v2/fc6', 100 | 'alexnet_v2/fc7' 101 | ] 102 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 103 | self.assertTrue(net.op.name.startswith('alexnet_v2/fc7')) 104 | self.assertListEqual(net.get_shape().as_list(), 105 | [batch_size, 1, 1, 4096]) 106 | 107 | def testModelVariables(self): 108 | batch_size = 5 109 | height, width = 224, 224 110 | num_classes = 1000 111 | with self.test_session(): 112 | inputs = tf.random_uniform((batch_size, height, width, 3)) 113 | alexnet.alexnet_v2(inputs, num_classes) 114 | expected_names = ['alexnet_v2/conv1/weights', 115 | 'alexnet_v2/conv1/biases', 116 | 'alexnet_v2/conv2/weights', 117 | 'alexnet_v2/conv2/biases', 118 | 'alexnet_v2/conv3/weights', 119 | 'alexnet_v2/conv3/biases', 120 | 'alexnet_v2/conv4/weights', 121 | 'alexnet_v2/conv4/biases', 122 | 'alexnet_v2/conv5/weights', 123 | 'alexnet_v2/conv5/biases', 124 | 'alexnet_v2/fc6/weights', 125 | 'alexnet_v2/fc6/biases', 126 | 'alexnet_v2/fc7/weights', 127 | 'alexnet_v2/fc7/biases', 128 | 'alexnet_v2/fc8/weights', 129 | 'alexnet_v2/fc8/biases', 130 | ] 131 | model_variables = [v.op.name for v in slim.get_model_variables()] 132 | self.assertSetEqual(set(model_variables), set(expected_names)) 133 | 134 | def testEvaluation(self): 135 | batch_size = 2 136 | height, width = 224, 224 137 | num_classes = 1000 138 | with self.test_session(): 139 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 140 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False) 141 | self.assertListEqual(logits.get_shape().as_list(), 142 | [batch_size, num_classes]) 143 | predictions = tf.argmax(logits, 1) 144 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 145 | 146 | def testTrainEvalWithReuse(self): 147 | train_batch_size = 2 148 | eval_batch_size = 1 149 | train_height, train_width = 224, 224 150 | eval_height, eval_width = 300, 400 151 | num_classes = 1000 152 | with self.test_session(): 153 | train_inputs = tf.random_uniform( 154 | (train_batch_size, train_height, train_width, 3)) 155 | logits, _ = alexnet.alexnet_v2(train_inputs) 156 | self.assertListEqual(logits.get_shape().as_list(), 157 | [train_batch_size, num_classes]) 158 | tf.get_variable_scope().reuse_variables() 159 | eval_inputs = tf.random_uniform( 160 | (eval_batch_size, eval_height, eval_width, 3)) 161 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False, 162 | spatial_squeeze=False) 163 | self.assertListEqual(logits.get_shape().as_list(), 164 | [eval_batch_size, 4, 7, num_classes]) 165 | logits = tf.reduce_mean(logits, [1, 2]) 166 | predictions = tf.argmax(logits, 1) 167 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 168 | 169 | def testForward(self): 170 | batch_size = 1 171 | height, width = 224, 224 172 | with self.test_session() as sess: 173 | inputs = 
tf.random_uniform((batch_size, height, width, 3)) 174 | logits, _ = alexnet.alexnet_v2(inputs) 175 | sess.run(tf.global_variables_initializer()) 176 | output = sess.run(logits) 177 | self.assertTrue(output.any()) 178 | 179 | if __name__ == '__main__': 180 | tf.test.main() 181 | -------------------------------------------------------------------------------- /nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import functools 21 | 22 | import tensorflow as tf 23 | 24 | from nets import alexnet 25 | from nets import cifarnet 26 | from nets import i3d 27 | from nets import inception 28 | from nets import lenet 29 | from nets import mobilenet_v1 30 | from nets import overfeat 31 | from nets import resnet_v1 32 | from nets import resnet_v2 33 | from nets import s3dg 34 | from nets import vgg 35 | from nets.mobilenet import mobilenet_v2 36 | from nets.nasnet import nasnet 37 | from nets.nasnet import pnasnet 38 | 39 | 40 | slim = tf.contrib.slim 41 | 42 | networks_map = {'alexnet_v2': alexnet.alexnet_v2, 43 | 'cifarnet': cifarnet.cifarnet, 44 | 'overfeat': overfeat.overfeat, 45 | 'vgg_a': vgg.vgg_a, 46 | 'vgg_16': vgg.vgg_16, 47 | 'vgg_19': vgg.vgg_19, 48 | 'inception_v1': inception.inception_v1, 49 | 'inception_v2': inception.inception_v2, 50 | 'inception_v3': inception.inception_v3, 51 | 'inception_v4': inception.inception_v4, 52 | 'inception_resnet_v2': inception.inception_resnet_v2, 53 | 'i3d': i3d.i3d, 54 | 's3dg': s3dg.s3dg, 55 | 'lenet': lenet.lenet, 56 | 'resnet_v1_50': resnet_v1.resnet_v1_50, 57 | 'resnet_v1_101': resnet_v1.resnet_v1_101, 58 | 'resnet_v1_152': resnet_v1.resnet_v1_152, 59 | 'resnet_v1_200': resnet_v1.resnet_v1_200, 60 | 'resnet_v2_50': resnet_v2.resnet_v2_50, 61 | 'resnet_v2_101': resnet_v2.resnet_v2_101, 62 | 'resnet_v2_152': resnet_v2.resnet_v2_152, 63 | 'resnet_v2_200': resnet_v2.resnet_v2_200, 64 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1, 65 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_075, 66 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050, 67 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025, 68 | 'mobilenet_v2': mobilenet_v2.mobilenet, 69 | 'mobilenet_v2_140': mobilenet_v2.mobilenet_v2_140, 70 | 'mobilenet_v2_035': mobilenet_v2.mobilenet_v2_035, 71 | 'nasnet_cifar': nasnet.build_nasnet_cifar, 72 | 'nasnet_mobile': nasnet.build_nasnet_mobile, 73 | 'nasnet_large': nasnet.build_nasnet_large, 74 | 'pnasnet_large': pnasnet.build_pnasnet_large, 75 | 'pnasnet_mobile': pnasnet.build_pnasnet_mobile, 76 | } 77 | 78 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope, 79 | 
'cifarnet': cifarnet.cifarnet_arg_scope, 80 | 'overfeat': overfeat.overfeat_arg_scope, 81 | 'vgg_a': vgg.vgg_arg_scope, 82 | 'vgg_16': vgg.vgg_arg_scope, 83 | 'vgg_19': vgg.vgg_arg_scope, 84 | 'inception_v1': inception.inception_v3_arg_scope, 85 | 'inception_v2': inception.inception_v3_arg_scope, 86 | 'inception_v3': inception.inception_v3_arg_scope, 87 | 'inception_v4': inception.inception_v4_arg_scope, 88 | 'inception_resnet_v2': 89 | inception.inception_resnet_v2_arg_scope, 90 | 'i3d': i3d.i3d_arg_scope, 91 | 's3dg': s3dg.s3dg_arg_scope, 92 | 'lenet': lenet.lenet_arg_scope, 93 | 'resnet_v1_50': resnet_v1.resnet_arg_scope, 94 | 'resnet_v1_101': resnet_v1.resnet_arg_scope, 95 | 'resnet_v1_152': resnet_v1.resnet_arg_scope, 96 | 'resnet_v1_200': resnet_v1.resnet_arg_scope, 97 | 'resnet_v2_50': resnet_v2.resnet_arg_scope, 98 | 'resnet_v2_101': resnet_v2.resnet_arg_scope, 99 | 'resnet_v2_152': resnet_v2.resnet_arg_scope, 100 | 'resnet_v2_200': resnet_v2.resnet_arg_scope, 101 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope, 102 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_arg_scope, 103 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_arg_scope, 104 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_arg_scope, 105 | 'mobilenet_v2': mobilenet_v2.training_scope, 106 | 'mobilenet_v2_035': mobilenet_v2.training_scope, 107 | 'mobilenet_v2_140': mobilenet_v2.training_scope, 108 | 'nasnet_cifar': nasnet.nasnet_cifar_arg_scope, 109 | 'nasnet_mobile': nasnet.nasnet_mobile_arg_scope, 110 | 'nasnet_large': nasnet.nasnet_large_arg_scope, 111 | 'pnasnet_large': pnasnet.pnasnet_large_arg_scope, 112 | 'pnasnet_mobile': pnasnet.pnasnet_mobile_arg_scope, 113 | } 114 | 115 | 116 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): 117 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 118 | 119 | Args: 120 | name: The name of the network. 121 | num_classes: The number of classes to use for classification. If 0 or None, 122 | the logits layer is omitted and its input features are returned instead. 123 | weight_decay: The l2 coefficient for the model weights. 124 | is_training: `True` if the model is being used for training and `False` 125 | otherwise. 126 | 127 | Returns: 128 | network_fn: A function that applies the model to a batch of images. It has 129 | the following signature: 130 | net, end_points = network_fn(images) 131 | The `images` input is a tensor of shape [batch_size, height, width, 3] 132 | with height = width = network_fn.default_image_size. (The permissibility 133 | and treatment of other sizes depends on the network_fn.) 134 | The returned `end_points` are a dictionary of intermediate activations. 135 | The returned `net` is the topmost layer, depending on `num_classes`: 136 | If `num_classes` was a non-zero integer, `net` is a logits tensor 137 | of shape [batch_size, num_classes]. 138 | If `num_classes` was 0 or `None`, `net` is a tensor with the input 139 | to the logits layer of shape [batch_size, 1, 1, num_features] or 140 | [batch_size, num_features]. Dropout has not been applied to this 141 | (even if the network's original classification does); it remains for 142 | the caller to do this or not. 143 | 144 | Raises: 145 | ValueError: If network `name` is not recognized. 
146 | """ 147 | if name not in networks_map: 148 | raise ValueError('Name of network unknown %s' % name) 149 | func = networks_map[name] 150 | @functools.wraps(func) 151 | def network_fn(images, **kwargs): 152 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 153 | with slim.arg_scope(arg_scope): 154 | return func(images, num_classes=num_classes, is_training=is_training, 155 | **kwargs) 156 | if hasattr(func, 'default_image_size'): 157 | network_fn.default_image_size = func.default_image_size 158 | 159 | return network_fn 160 | --------------------------------------------------------------------------------