├── BUILD ├── README.md ├── WORKSPACE ├── __init__.py ├── classify.py ├── datasets ├── __init__.py ├── build_imagenet_data.py ├── cifar10.py ├── cifar10 │ ├── cifar10.py │ └── tfrecord_cifar10_image_read.py ├── dataset_factory.py ├── dataset_utils.py ├── download_and_convert_cifar10.py ├── download_and_convert_flowers.py ├── download_and_convert_imagenet.sh ├── download_and_convert_mnist.py ├── download_and_convert_mydata.py ├── download_imagenet.sh ├── faces.py ├── flowers.py ├── imagenet.py ├── imagenet_2012_validation_synset_labels.txt ├── imagenet_lsvrc_2015_synsets.txt ├── imagenet_metadata.txt ├── mnist.py ├── mydata.py ├── preprocess_imagenet_validation_data.py └── process_bounding_boxes.py ├── deployment ├── __init__.py ├── __init__.pyc ├── model_deploy.py └── model_deploy_test.py ├── download_and_convert_data.py ├── eval_image_classifier.py ├── export_inference_graph.py ├── export_inference_graph_test.py ├── fine-tune.py ├── nets ├── Robin_network │ ├── densenet.py │ ├── dsod.py │ ├── mobilenetv1_version1.py │ ├── mobilenetv2_version1.py │ ├── mobilenetv2_version2.py │ ├── mobilenetv2_version3.py │ ├── pelee.py │ ├── resnet_v1_robin.py │ ├── resnetxt.py │ ├── se-resne_v1.py │ ├── shufflenet_v1.py │ ├── shufflenet_v1_version1.py │ ├── shufflenet_v1_version2.py │ ├── shufflenet_v2.py │ ├── shufflenet_v2_version1.py │ ├── shufflenet_v2_version2.py │ ├── squeezenet.py │ ├── tiny_dsod.py │ └── xception.py ├── __init__.py ├── alexnet.py ├── alexnet_test.py ├── cifarnet.py ├── cyclegan.py ├── cyclegan_test.py ├── dcgan.py ├── dcgan_test.py ├── deeplab_v1-3 │ ├── __init__.py │ ├── deeplabv2-resnet.py │ ├── deeplabv2-vgg-lfov.py │ ├── deeplabv3-resnet-plus.py │ ├── deeplabv3-resnet.py │ ├── resnet_utils.py │ └── resnet_v1.py ├── facenet_backbone │ ├── MobileFaceNet.py │ ├── ShuffleFaceNet.py │ ├── SqueezeFaceNet.py │ ├── inception_resnet_v1.py │ ├── inception_resnet_v2.py │ └── squeezenet.py ├── inception.py ├── inception_resnet_v2.py ├── inception_resnet_v2_test.py ├── inception_utils.py ├── inception_v1.py ├── inception_v1_test.py ├── inception_v2.py ├── inception_v2_test.py ├── inception_v3.py ├── inception_v3_test.py ├── inception_v4.py ├── inception_v4_test.py ├── lenet.py ├── mobilenet │ ├── README.md │ ├── __init__.py │ ├── conv_blocks.py │ ├── madds_top1_accuracy.png │ ├── mnet_v1_vs_v2_pixel1_latency.png │ ├── mobilenet.py │ ├── mobilenet_example.ipynb │ ├── mobilenet_v2.py │ └── mobilenet_v2_test.py ├── mobilenet_v1.md ├── mobilenet_v1.png ├── mobilenet_v1.py ├── mobilenet_v1_eval.py ├── mobilenet_v1_test.py ├── mobilenet_v1_train.py ├── nasnet │ ├── README.md │ ├── __init__.py │ ├── __init__.pyc │ ├── nasnet.py │ ├── nasnet.pyc │ ├── nasnet_test.py │ ├── nasnet_utils.py │ ├── nasnet_utils.pyc │ ├── nasnet_utils_test.py │ ├── pnasnet.py │ └── pnasnet_test.py ├── nets_factory.py ├── nets_factory_test.py ├── overfeat.py ├── overfeat_test.py ├── pix2pix.py ├── pix2pix_test.py ├── resnet_utils.py ├── resnet_v1.py ├── resnet_v1_test.py ├── resnet_v2.py ├── resnet_v2_test.py ├── shufflenet.py ├── shufflenet_test.py ├── squeezenet.py ├── squeezenet_test.py ├── vgg.py ├── vgg_test.py └── xception.py ├── preprocessing ├── __init__.py ├── cifarnet_preprocessing.py ├── inception_preprocessing.py ├── lenet_preprocessing.py ├── preprocessing_factory.py └── vgg_preprocessing.py ├── scripts ├── README.md ├── create_tfrecord.sh ├── download_COCO.sh ├── download_VOC.sh ├── download_classification_Pre-trained_Models.sh ├── download_classification_mobile.sh ├── download_detection_model.sh ├── 
download_imagenet-data.sh ├── download_kitti.sh ├── download_pet_dataset.sh ├── export_interface_graph.sh ├── export_interface_graph_for_mobilenet.sh ├── export_interface_graph_for_squeezenet.sh ├── export_mobilenet.sh ├── finetune_inception_resnet_v2_on_flowers.sh ├── finetune_inception_v1_on_flowers.sh ├── finetune_inception_v3_on_flowers.sh ├── finetune_mobilenet_v1_on_flowers.sh ├── finetune_resnet_v1_50_on_flowers.sh ├── finetune_shufflenet_on_flowers.sh ├── finetune_shufflenet_on_imagenet.sh ├── finetune_squeezenet_on_flowers.sh ├── finetune_squeezenet_on_imagenet.sh ├── generate_validation.sh ├── run.sh ├── run_create_kitti_tf_record.sh ├── run_create_pascal_tf_record.sh ├── run_download_and_preprocess_mscoco.sh ├── train_cifarnet_on_cifar10.sh └── train_lenet_on_mnist.sh ├── setup.py ├── slim_models_demo ├── First_Student_IC_school_bus_202076.jpg ├── Inception_v1_demo.py ├── Inception_v1_demo_locally.py ├── frozen_graph.py ├── label_flower_from_pb.py ├── print_ops_from_pb.py ├── resnet_demo.py ├── test_image_classifier.py ├── vgg_demo.py ├── vgg_demo_Segmentation.py └── vgg_demo_locally.py ├── slim_walkthrough.ipynb ├── test_image_classifier.py ├── test_image_classifier_batch.py ├── tfrecord_fine_tune_model_for_other_set_demo.py ├── tfrecord_image_decode.py ├── tfrecord_image_decode_cifar10.py ├── tfrecord_image_decode_flowers.py ├── tfrecord_image_decode_imagenet.py ├── tfrecord_image_decode_mnist.py ├── tfrecord_image_decode_mydata.py ├── tfrecord_image_read.py ├── tfrecord_inference_model_for_images_demo.py ├── tools ├── BUILD ├── classify_image_on_imagenet.py ├── freeze_graph.py ├── freeze_graph_test.py ├── imagenet_2012_challenge_label_map_proto.pbtxt ├── imagenet_synset_to_human_label_map.txt ├── import_pb_to_tensorboard.py ├── insert_placeholder.py ├── inspect_checkpoint.py ├── label_image │ ├── BUILD │ ├── README.md │ ├── data │ │ └── grace_hopper.jpg │ ├── label_image.py │ └── main.cc ├── label_load_freeze_graph.py ├── optimize_for_inference.py ├── optimize_for_inference_lib.py ├── optimize_for_inference_test.py ├── print_ops_from_pb.py ├── print_selective_registration_header.py ├── print_selective_registration_header_test.py ├── quantization │ ├── BUILD │ ├── graph_to_dot.py │ ├── quantize_graph.py │ └── quantize_graph_test.py ├── saved_model_cli.py ├── saved_model_cli_test.py ├── saved_model_utils.py ├── selective_registration_header_lib.py ├── strip_unused.py ├── strip_unused_lib.py ├── strip_unused_test.py └── summary │ ├── README.md │ ├── __init__.py │ ├── plugin_asset.py │ ├── plugin_asset_test.py │ ├── summary.py │ ├── summary_iterator.py │ ├── summary_test.py │ ├── text_summary.py │ ├── text_summary_test.py │ └── writer │ ├── event_file_writer.py │ ├── writer.py │ ├── writer_cache.py │ └── writer_test.py ├── train_and_eval.py └── train_image_classifier.py /README.md: -------------------------------------------------------------------------------- 1 | # Tensorflow_Model_Slim_Classify 2 | Train/eval popular networks from the TF-slim model library, including mobilenet, shufflenet, squeezenet, resnet, inception, vgg, and alexnet. 3 | 4 | 5 | ## References 6 | 7 | 1. [squeezenet](https://github.com/Robinatp/Tensorflow_Model_Slim_Classify/blob/master/nets/Robin_network/squeezenet.py): [SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5MB model size](https://arxiv.org/abs/1602.07360) 8 | 9 | 2. 
[mobilenetv1](https://github.com/Robinatp/Tensorflow_Model_Slim_Classify/blob/master/nets/Robin_network/mobilenetv1_version1.py): [MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications](https://arxiv.org/abs/1704.04861) 10 | 11 | 3. [shufflenetv1](https://github.com/Robinatp/Tensorflow_Model_Slim_Classify/blob/master/nets/Robin_network/shufflenet_v1_version1.py): [ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices](https://arxiv.org/abs/1707.01083) 12 | 13 | 4. [mobilenetv2](https://github.com/Robinatp/Tensorflow_Model_Slim_Classify/blob/master/nets/Robin_network/mobilenetv2_version1.py): [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) 14 | 15 | 5. [resnet](https://github.com/Robinatp/Tensorflow_Model_Slim_Classify/blob/master/nets/Robin_network/resnet_v1_robin.py): [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) 16 | 17 | 6. [xception](https://github.com/Robinatp/Tensorflow_Model_Slim_Classify/blob/master/nets/Robin_network/xception.py): [Xception: Deep Learning with Depthwise Separable Convolutions](https://arxiv.org/abs/1610.02357) 18 | 19 | 7. [MobileFaceNets](https://github.com/Robinatp/Tensorflow_Model_Slim_Classify/blob/master/nets/facenet_backbone/MobileFaceNet.py): [MobileFaceNets: Efficient CNNs for Accurate Real-time Face Verification on Mobile Devices](https://arxiv.org/abs/1804.07573) 20 | 21 | 8. [shufflenetv2](https://github.com/Robinatp/Tensorflow_Model_Slim_Classify/blob/master/nets/Robin_network/shufflenet_v2_version1.py): [ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design](https://arxiv.org/abs/1807.11164) 22 | -------------------------------------------------------------------------------- /WORKSPACE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/WORKSPACE -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/__init__.py -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the Cifar10 dataset. 
16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/research/slim/datasets/download_and_convert_cifar10.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'cifar10_%s.tfrecord' 33 | 34 | SPLITS_TO_SIZES = {'train': 50000, 'test': 10000} 35 | 36 | _NUM_CLASSES = 10 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A [32 x 32 x 3] color image.', 40 | 'label': 'A single integer between 0 and 9', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading cifar10. 46 | 47 | Args: 48 | split_name: A train/test split name. 49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/test split. 60 | """ 61 | if split_name not in SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 69 | if not reader: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(shape=[32, 32, 3]), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=SPLITS_TO_SIZES[split_name], 96 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 97 | num_classes=_NUM_CLASSES, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /datasets/cifar10/tfrecord_cifar10_image_read.py: -------------------------------------------------------------------------------- 1 | 2 | import tensorflow as tf 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | from cifar10 import Cifar10DataSet 6 | 7 | import os 8 | import sys 9 | 10 | # This is needed since the notebook is stored in the object_detection folder. 
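# NOTE: TF_API below is a machine-specific path; point it at your own checkout
# of the TF-models "research/slim" directory so that the `datasets` and
# `preprocessing` imports further down resolve.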
11 | TF_API="/home/robin/eclipse-workspace-python/TF_models/models/research/slim" 12 | sys.path.append(os.path.split(TF_API)[0]) 13 | sys.path.append(TF_API) 14 | 15 | from datasets import cifar10 16 | from datasets import flowers 17 | from datasets import imagenet 18 | from datasets import mnist 19 | from datasets import mydata 20 | from preprocessing import cifarnet_preprocessing 21 | from preprocessing import inception_preprocessing 22 | from preprocessing import lenet_preprocessing 23 | from preprocessing import vgg_preprocessing 24 | 25 | 26 | from tensorflow.contrib import slim 27 | 28 | cifar10_data_dir = "/home/robin/Dataset/cifar10" 29 | 30 | 31 | if __name__ == "__main__": 32 | with tf.Graph().as_default(): 33 | subset = 'train' 34 | dataset = Cifar10DataSet(cifar10_data_dir, subset, False) 35 | image_batch, label_batch = dataset.make_batch(2) 36 | print(image_batch, label_batch) 37 | 38 | 39 | 40 | 41 | with tf.Session() as sess: 42 | with slim.queues.QueueRunners(sess): 43 | for i in range(10): 44 | np_image, np_labels = sess.run([image_batch, label_batch]) 45 | 46 | 47 | # plt.figure() 48 | # plt.imshow(np_image[0]) 49 | # plt.title('%s, %d x %d' % (name, height, width)) 50 | # plt.axis('off') 51 | # plt.show() 52 | 53 | print("labels :", dataset.labels_to_names[np_labels[0]]) 54 | print("--------------") 55 | cv2.imshow('image:',np_image[0]/255 ) 56 | 57 | cv2.waitKey(0) 58 | 59 | 60 | 61 | 62 | 63 | 64 | -------------------------------------------------------------------------------- /datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A factory-pattern class which returns classification image/label pairs.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from datasets import cifar10 22 | from datasets import flowers 23 | from datasets import imagenet 24 | from datasets import mnist 25 | from datasets import mydata 26 | 27 | datasets_map = { 28 | 'cifar10': cifar10, 29 | 'flowers': flowers, 30 | 'imagenet': imagenet, 31 | 'mnist': mnist, 32 | "mydata":mydata 33 | } 34 | 35 | 36 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None): 37 | """Given a dataset name and a split_name returns a Dataset. 38 | 39 | Args: 40 | name: String, the name of the dataset. 41 | split_name: A train/test split name. 42 | dataset_dir: The directory where the dataset files are stored. 43 | file_pattern: The file pattern to use for matching the dataset source files. 44 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default 45 | reader defined by each dataset is used. 46 | 47 | Returns: 48 | A `Dataset` class. 
49 | 50 | Raises: 51 | ValueError: If the dataset `name` is unknown. 52 | """ 53 | if name not in datasets_map: 54 | raise ValueError('Name of dataset unknown %s' % name) 55 | return datasets_map[name].get_split( 56 | split_name, 57 | dataset_dir, 58 | file_pattern, 59 | reader) 60 | -------------------------------------------------------------------------------- /datasets/download_and_convert_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # Script to download and preprocess ImageNet Challenge 2012 18 | # training and validation data set. 19 | # 20 | # The final output of this script are sharded TFRecord files containing 21 | # serialized Example protocol buffers. See build_imagenet_data.py for 22 | # details of how the Example protocol buffers contain the ImageNet data. 23 | # 24 | # The final output of this script appears as such: 25 | # 26 | # data_dir/train-00000-of-01024 27 | # data_dir/train-00001-of-01024 28 | # ... 29 | # data_dir/train-00127-of-01024 30 | # 31 | # and 32 | # 33 | # data_dir/validation-00000-of-00128 34 | # data_dir/validation-00001-of-00128 35 | # ... 36 | # data_dir/validation-00127-of-00128 37 | # 38 | # Note that this script may take several hours to run to completion. The 39 | # conversion of the ImageNet data to TFRecords alone takes 2-3 hours depending 40 | # on the speed of your machine. Please be patient. 41 | # 42 | # **IMPORTANT** 43 | # To download the raw images, the user must create an account with image-net.org 44 | # and generate a username and access_key. The latter two are required for 45 | # downloading the raw images. 46 | # 47 | # usage: 48 | # cd research/slim 49 | # bazel build :download_and_convert_imagenet 50 | # ./bazel-bin/download_and_convert_imagenet.sh [data-dir] 51 | set -e 52 | 53 | if [ -z "$1" ]; then 54 | echo "usage download_and_convert_imagenet.sh [data dir]" 55 | exit 56 | fi 57 | 58 | # Create the output and temporary directories. 59 | DATA_DIR="${1%/}" 60 | SCRATCH_DIR="${DATA_DIR}/raw-data/" 61 | mkdir -p "${DATA_DIR}" 62 | mkdir -p "${SCRATCH_DIR}" 63 | WORK_DIR="$0.runfiles/__main__" 64 | 65 | # Download the ImageNet data. 66 | LABELS_FILE="${WORK_DIR}/datasets/imagenet_lsvrc_2015_synsets.txt" 67 | DOWNLOAD_SCRIPT="${WORK_DIR}/datasets/download_imagenet.sh" 68 | "${DOWNLOAD_SCRIPT}" "${SCRATCH_DIR}" "${LABELS_FILE}" 69 | 70 | # Note the locations of the train and validation data. 71 | TRAIN_DIRECTORY="${SCRATCH_DIR}train/" 72 | VALIDATION_DIRECTORY="${SCRATCH_DIR}validation/" 73 | 74 | # Preprocess the validation data by moving the images into the appropriate 75 | # sub-directory based on the label (synset) of the image. 76 | echo "Organizing the validation data into sub-directories." 
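# A sketch of the layout this step produces; the synset IDs and file names
# below are illustrative, not actual label mappings:
#   raw-data/validation/n01440764/ILSVRC2012_val_00000293.JPEG
#   raw-data/validation/n01443537/ILSVRC2012_val_00000236.JPEG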
77 | PREPROCESS_VAL_SCRIPT="${WORK_DIR}/datasets/preprocess_imagenet_validation_data.py" 78 | VAL_LABELS_FILE="${WORK_DIR}/datasets/imagenet_2012_validation_synset_labels.txt" 79 | 80 | "${PREPROCESS_VAL_SCRIPT}" "${VALIDATION_DIRECTORY}" "${VAL_LABELS_FILE}" 81 | 82 | # Convert the XML files for bounding box annotations into a single CSV. 83 | echo "Extracting bounding box information from XML." 84 | BOUNDING_BOX_SCRIPT="${WORK_DIR}/datasets/process_bounding_boxes.py" 85 | BOUNDING_BOX_FILE="${SCRATCH_DIR}/imagenet_2012_bounding_boxes.csv" 86 | BOUNDING_BOX_DIR="${SCRATCH_DIR}bounding_boxes/" 87 | 88 | "${BOUNDING_BOX_SCRIPT}" "${BOUNDING_BOX_DIR}" "${LABELS_FILE}" \ 89 | | sort >"${BOUNDING_BOX_FILE}" 90 | echo "Finished downloading and preprocessing the ImageNet data." 91 | 92 | # Build the TFRecords version of the ImageNet data. 93 | BUILD_SCRIPT="${WORK_DIR}/build_imagenet_data" 94 | OUTPUT_DIRECTORY="${DATA_DIR}" 95 | IMAGENET_METADATA_FILE="${WORK_DIR}/datasets/imagenet_metadata.txt" 96 | 97 | "${BUILD_SCRIPT}" \ 98 | --train_directory="${TRAIN_DIRECTORY}" \ 99 | --validation_directory="${VALIDATION_DIRECTORY}" \ 100 | --output_directory="${OUTPUT_DIRECTORY}" \ 101 | --imagenet_metadata_file="${IMAGENET_METADATA_FILE}" \ 102 | --labels_file="${LABELS_FILE}" \ 103 | --bounding_box_file="${BOUNDING_BOX_FILE}" 104 | -------------------------------------------------------------------------------- /datasets/download_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # Script to download ImageNet Challenge 2012 training and validation data set. 18 | # 19 | # Downloads and decompresses raw images and bounding boxes. 20 | # 21 | # **IMPORTANT** 22 | # To download the raw images, the user must create an account with image-net.org 23 | # and generate a username and access_key. The latter two are required for 24 | # downloading the raw images. 25 | # 26 | # usage: 27 | # ./download_imagenet.sh [dirname] 28 | set -e 29 | 30 | if [ "x$IMAGENET_ACCESS_KEY" == x -o "x$IMAGENET_USERNAME" == x ]; then 31 | cat < ') 62 | sys.exit(-1) 63 | data_dir = sys.argv[1] 64 | validation_labels_file = sys.argv[2] 65 | 66 | # Read in the 50000 synsets associated with the validation data set. 67 | labels = [l.strip() for l in open(validation_labels_file).readlines()] 68 | unique_labels = set(labels) 69 | 70 | # Make all sub-directories in the validation data dir. 71 | for label in unique_labels: 72 | labeled_data_dir = os.path.join(data_dir, label) 73 | os.makedirs(labeled_data_dir) 74 | 75 | # Move all of the image to the appropriate sub-directory. 
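# Validation files are named ILSVRC2012_val_00000001.JPEG through
# ILSVRC2012_val_00050000.JPEG, and labels[i] holds the synset for image i + 1.
# Note that xrange is Python 2 only; substitute range under Python 3.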
76 | for i in xrange(len(labels)): 77 | basename = 'ILSVRC2012_val_000%.5d.JPEG' % (i + 1) 78 | original_filename = os.path.join(data_dir, basename) 79 | if not os.path.exists(original_filename): 80 | print('Failed to find: %s' % original_filename) 81 | sys.exit(-1) 82 | new_filename = os.path.join(data_dir, labels[i], basename) 83 | os.rename(original_filename, new_filename) 84 | -------------------------------------------------------------------------------- /deployment/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /deployment/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/deployment/__init__.pyc -------------------------------------------------------------------------------- /download_and_convert_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Downloads and converts a particular dataset. 
16 | 17 | Usage: 18 | ```shell 19 | 20 | $ python download_and_convert_data.py \ 21 | --dataset_name=mnist \ 22 | --dataset_dir=/tmp/mnist 23 | 24 | $ python download_and_convert_data.py \ 25 | --dataset_name=cifar10 \ 26 | --dataset_dir=/tmp/cifar10 27 | 28 | $ python download_and_convert_data.py \ 29 | --dataset_name=flowers \ 30 | --dataset_dir=/tmp/flowers 31 | ``` 32 | """ 33 | from __future__ import absolute_import 34 | from __future__ import division 35 | from __future__ import print_function 36 | 37 | import tensorflow as tf 38 | 39 | from datasets import download_and_convert_cifar10 40 | from datasets import download_and_convert_flowers 41 | from datasets import download_and_convert_mnist 42 | 43 | FLAGS = tf.app.flags.FLAGS 44 | 45 | tf.app.flags.DEFINE_string( 46 | 'dataset_name', 47 | None, 48 | 'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".') 49 | 50 | tf.app.flags.DEFINE_string( 51 | 'dataset_dir', 52 | None, 53 | 'The directory where the output TFRecords and temporary files are saved.') 54 | 55 | 56 | def main(_): 57 | if not FLAGS.dataset_name: 58 | raise ValueError('You must supply the dataset name with --dataset_name') 59 | if not FLAGS.dataset_dir: 60 | raise ValueError('You must supply the dataset directory with --dataset_dir') 61 | 62 | if FLAGS.dataset_name == 'cifar10': 63 | download_and_convert_cifar10.run(FLAGS.dataset_dir) 64 | elif FLAGS.dataset_name == 'flowers': 65 | download_and_convert_flowers.run(FLAGS.dataset_dir) 66 | elif FLAGS.dataset_name == 'mnist': 67 | download_and_convert_mnist.run(FLAGS.dataset_dir) 68 | else: 69 | raise ValueError( 70 | 'dataset_name [%s] was not recognized.' % FLAGS.dataset_name) 71 | 72 | if __name__ == '__main__': 73 | tf.app.run() 74 | -------------------------------------------------------------------------------- /export_inference_graph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for export_inference_graph.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | 25 | import tensorflow as tf 26 | 27 | from tensorflow.python.platform import gfile 28 | import export_inference_graph 29 | 30 | 31 | class ExportInferenceGraphTest(tf.test.TestCase): 32 | 33 | def testExportInferenceGraph(self): 34 | tmpdir = self.get_temp_dir() 35 | output_file = os.path.join(tmpdir, 'inception_v3.pb') 36 | flags = tf.app.flags.FLAGS 37 | flags.output_file = output_file 38 | flags.model_name = 'inception_v3' 39 | flags.dataset_dir = tmpdir 40 | export_inference_graph.main(None) 41 | self.assertTrue(gfile.Exists(output_file)) 42 | 43 | if __name__ == '__main__': 44 | tf.test.main() 45 | -------------------------------------------------------------------------------- /nets/Robin_network/mobilenetv2_version3.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import tensorflow.contrib as tc 3 | 4 | import numpy as np 5 | import time 6 | 7 | import tensorflow.contrib.slim as slim 8 | 9 | 10 | class MobileNetV2(object): 11 | def __init__(self, is_training=True, input_size=224): 12 | self.input_size = input_size 13 | self.is_training = is_training 14 | self.normalizer = tc.layers.batch_norm 15 | self.bn_params = {'is_training': self.is_training} 16 | 17 | with tf.variable_scope('MobileNetV2'): 18 | self._create_placeholders() 19 | self._build_model() 20 | 21 | def _create_placeholders(self): 22 | self.input = tf.placeholder(dtype=tf.float32, shape=[None, self.input_size, self.input_size, 3]) 23 | 24 | 25 | def _build_model(self): 26 | self.i = 0 27 | with tf.variable_scope('init_conv'): 28 | output = tc.layers.conv2d(self.input, 32, 3, 2, 29 | normalizer_fn=self.normalizer, normalizer_params=self.bn_params) 30 | print(output.get_shape()) 31 | self.output = self._inverted_bottleneck(output, 1, 16, 0) 32 | self.output = self._inverted_bottleneck(self.output, 6, 24, 1) 33 | self.output = self._inverted_bottleneck(self.output, 6, 24, 0) 34 | self.output = self._inverted_bottleneck(self.output, 6, 32, 1) 35 | self.output = self._inverted_bottleneck(self.output, 6, 32, 0) 36 | self.output = self._inverted_bottleneck(self.output, 6, 32, 0) 37 | self.output = self._inverted_bottleneck(self.output, 6, 64, 1) 38 | self.output = self._inverted_bottleneck(self.output, 6, 64, 0) 39 | self.output = self._inverted_bottleneck(self.output, 6, 64, 0) 40 | self.output = self._inverted_bottleneck(self.output, 6, 64, 0) 41 | self.output = self._inverted_bottleneck(self.output, 6, 96, 0) 42 | self.output = self._inverted_bottleneck(self.output, 6, 96, 0) 43 | self.output = self._inverted_bottleneck(self.output, 6, 96, 0) 44 | self.output = self._inverted_bottleneck(self.output, 6, 160, 1) 45 | self.output = self._inverted_bottleneck(self.output, 6, 160, 0) 46 | self.output = self._inverted_bottleneck(self.output, 6, 160, 0) 47 | self.output = self._inverted_bottleneck(self.output, 6, 320, 0) 48 | self.output = tc.layers.conv2d(self.output, 1280, 1, normalizer_fn=self.normalizer, normalizer_params=self.bn_params) 49 | self.output = tc.layers.avg_pool2d(self.output, 7) 50 | self.output = tc.layers.conv2d(self.output, 1000, 1, activation_fn=None) 51 | 52 | 53 | def _inverted_bottleneck(self, input, up_sample_rate, channels, subsample): 54 | with 
tf.variable_scope('inverted_bottleneck{}_{}_{}'.format(self.i, up_sample_rate, subsample)): 55 | self.i += 1 56 | stride = 2 if subsample else 1 57 | output = tc.layers.conv2d(input, up_sample_rate*input.get_shape().as_list()[-1], 1, 58 | activation_fn=tf.nn.relu6, 59 | normalizer_fn=self.normalizer, normalizer_params=self.bn_params) 60 | output = tc.layers.separable_conv2d(output, None, 3, 1, stride=stride, 61 | activation_fn=tf.nn.relu6, 62 | normalizer_fn=self.normalizer, normalizer_params=self.bn_params) 63 | output = tc.layers.conv2d(output, channels, 1, activation_fn=None, 64 | normalizer_fn=self.normalizer, normalizer_params=self.bn_params) 65 | if input.get_shape().as_list()[-1] == channels: 66 | output = tf.add(input, output) 67 | return output 68 | 69 | 70 | 71 | if __name__ == '__main__': 72 | model = MobileNetV2(False) 73 | print(model.output.get_shape()) 74 | board_writer = tf.summary.FileWriter(logdir='logs', graph=tf.get_default_graph()) 75 | 76 | print("Parameters") 77 | for v in slim.get_model_variables(): 78 | print('name = {}, shape = {}'.format(v.name, v.get_shape())) 79 | 80 | fake_data = np.ones(shape=(1, 224, 224, 3)) 81 | 82 | sess_config = tf.ConfigProto(device_count={'GPU':0}) 83 | with tf.Session(config=sess_config) as sess: 84 | sess.run(tf.global_variables_initializer()) 85 | 86 | cnt = 0 87 | for i in range(101): 88 | t1 = time.time() 89 | output = sess.run(model.output, feed_dict={model.input: fake_data}) 90 | if i != 0: 91 | cnt += time.time() - t1 92 | print(cnt / 100) 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nets/cyclegan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for tensorflow.contrib.slim.nets.cyclegan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import cyclegan 24 | 25 | 26 | # TODO(joelshor): Add a test to check generator endpoints. 
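# The tests below build small graphs and assert only on static output shapes;
# cyclegan_generator_resnet requires the input height and width to each be a
# multiple of 4, which the *_not_multiple_of_four tests exercise.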
27 | class CycleganTest(tf.test.TestCase): 28 | 29 | def test_generator_inference(self): 30 | """Check one inference step.""" 31 | img_batch = tf.zeros([2, 32, 32, 3]) 32 | model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | sess.run(model_output) 36 | 37 | def _test_generator_graph_helper(self, shape): 38 | """Check that generator can take small and non-square inputs.""" 39 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape)) 40 | self.assertAllEqual(shape, output_imgs.shape.as_list()) 41 | 42 | def test_generator_graph_small(self): 43 | self._test_generator_graph_helper([4, 32, 32, 3]) 44 | 45 | def test_generator_graph_medium(self): 46 | self._test_generator_graph_helper([3, 128, 128, 3]) 47 | 48 | def test_generator_graph_nonsquare(self): 49 | self._test_generator_graph_helper([2, 80, 400, 3]) 50 | 51 | def test_generator_unknown_batch_dim(self): 52 | """Check that generator can take unknown batch dimension inputs.""" 53 | img = tf.placeholder(tf.float32, shape=[None, 32, None, 3]) 54 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(img) 55 | 56 | self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list()) 57 | 58 | def _input_and_output_same_shape_helper(self, kernel_size): 59 | img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3]) 60 | output_img_batch, _ = cyclegan.cyclegan_generator_resnet( 61 | img_batch, kernel_size=kernel_size) 62 | 63 | self.assertAllEqual(img_batch.shape.as_list(), 64 | output_img_batch.shape.as_list()) 65 | 66 | def input_and_output_same_shape_kernel3(self): 67 | self._input_and_output_same_shape_helper(3) 68 | 69 | def input_and_output_same_shape_kernel4(self): 70 | self._input_and_output_same_shape_helper(4) 71 | 72 | def input_and_output_same_shape_kernel5(self): 73 | self._input_and_output_same_shape_helper(5) 74 | 75 | def input_and_output_same_shape_kernel6(self): 76 | self._input_and_output_same_shape_helper(6) 77 | 78 | def _error_if_height_not_multiple_of_four_helper(self, height): 79 | self.assertRaisesRegexp( 80 | ValueError, 81 | 'The input height must be a multiple of 4.', 82 | cyclegan.cyclegan_generator_resnet, 83 | tf.placeholder(tf.float32, shape=[None, height, 32, 3])) 84 | 85 | def test_error_if_height_not_multiple_of_four_height29(self): 86 | self._error_if_height_not_multiple_of_four_helper(29) 87 | 88 | def test_error_if_height_not_multiple_of_four_height30(self): 89 | self._error_if_height_not_multiple_of_four_helper(30) 90 | 91 | def test_error_if_height_not_multiple_of_four_height31(self): 92 | self._error_if_height_not_multiple_of_four_helper(31) 93 | 94 | def _error_if_width_not_multiple_of_four_helper(self, width): 95 | self.assertRaisesRegexp( 96 | ValueError, 97 | 'The input width must be a multiple of 4.', 98 | cyclegan.cyclegan_generator_resnet, 99 | tf.placeholder(tf.float32, shape=[None, 32, width, 3])) 100 | 101 | def test_error_if_width_not_multiple_of_four_width29(self): 102 | self._error_if_width_not_multiple_of_four_helper(29) 103 | 104 | def test_error_if_width_not_multiple_of_four_width30(self): 105 | self._error_if_width_not_multiple_of_four_helper(30) 106 | 107 | def test_error_if_width_not_multiple_of_four_width31(self): 108 | self._error_if_width_not_multiple_of_four_helper(31) 109 | 110 | 111 | if __name__ == '__main__': 112 | tf.test.main() 113 | -------------------------------------------------------------------------------- /nets/dcgan_test.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for dcgan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from six.moves import xrange # pylint: disable=redefined-builtin 22 | import tensorflow as tf 23 | from nets import dcgan 24 | 25 | 26 | class DCGANTest(tf.test.TestCase): 27 | 28 | def test_generator_run(self): 29 | tf.set_random_seed(1234) 30 | noise = tf.random_normal([100, 64]) 31 | image, _ = dcgan.generator(noise) 32 | with self.test_session() as sess: 33 | sess.run(tf.global_variables_initializer()) 34 | image.eval() 35 | 36 | def test_generator_graph(self): 37 | tf.set_random_seed(1234) 38 | # Check graph construction for a number of image size/depths and batch 39 | # sizes. 40 | for i, batch_size in zip(xrange(3, 7), xrange(3, 8)): 41 | tf.reset_default_graph() 42 | final_size = 2 ** i 43 | noise = tf.random_normal([batch_size, 64]) 44 | image, end_points = dcgan.generator( 45 | noise, 46 | depth=32, 47 | final_size=final_size) 48 | 49 | self.assertAllEqual([batch_size, final_size, final_size, 3], 50 | image.shape.as_list()) 51 | 52 | expected_names = ['deconv%i' % j for j in xrange(1, i)] + ['logits'] 53 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 54 | 55 | # Check layer depths. 56 | for j in range(1, i): 57 | layer = end_points['deconv%i' % j] 58 | self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1]) 59 | 60 | def test_generator_invalid_input(self): 61 | wrong_dim_input = tf.zeros([5, 32, 32]) 62 | with self.assertRaises(ValueError): 63 | dcgan.generator(wrong_dim_input) 64 | 65 | correct_input = tf.zeros([3, 2]) 66 | with self.assertRaisesRegexp(ValueError, 'must be a power of 2'): 67 | dcgan.generator(correct_input, final_size=30) 68 | 69 | with self.assertRaisesRegexp(ValueError, 'must be greater than 8'): 70 | dcgan.generator(correct_input, final_size=4) 71 | 72 | def test_discriminator_run(self): 73 | image = tf.random_uniform([5, 32, 32, 3], -1, 1) 74 | output, _ = dcgan.discriminator(image) 75 | with self.test_session() as sess: 76 | sess.run(tf.global_variables_initializer()) 77 | output.eval() 78 | 79 | def test_discriminator_graph(self): 80 | # Check graph construction for a number of image size/depths and batch 81 | # sizes. 
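# zip(xrange(1, 6), xrange(3, 8)) pairs image widths 2**1..2**5 with batch
# sizes 3..7, i.e. a 2x2 input with batch 3 up through a 32x32 input with
# batch 7.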
82 | for i, batch_size in zip(xrange(1, 6), xrange(3, 8)): 83 | tf.reset_default_graph() 84 | img_w = 2 ** i 85 | image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1) 86 | output, end_points = dcgan.discriminator( 87 | image, 88 | depth=32) 89 | 90 | self.assertAllEqual([batch_size, 1], output.get_shape().as_list()) 91 | 92 | expected_names = ['conv%i' % j for j in xrange(1, i+1)] + ['logits'] 93 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 94 | 95 | # Check layer depths. 96 | for j in range(1, i+1): 97 | layer = end_points['conv%i' % j] 98 | self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1]) 99 | 100 | def test_discriminator_invalid_input(self): 101 | wrong_dim_img = tf.zeros([5, 32, 32]) 102 | with self.assertRaises(ValueError): 103 | dcgan.discriminator(wrong_dim_img) 104 | 105 | spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3]) 106 | with self.assertRaises(ValueError): 107 | dcgan.discriminator(spatially_undefined_shape) 108 | 109 | not_square = tf.zeros([5, 32, 16, 3]) 110 | with self.assertRaisesRegexp(ValueError, 'not have equal width and height'): 111 | dcgan.discriminator(not_square) 112 | 113 | not_power_2 = tf.zeros([5, 30, 30, 3]) 114 | with self.assertRaisesRegexp(ValueError, 'not a power of 2'): 115 | dcgan.discriminator(not_power_2) 116 | 117 | 118 | if __name__ == '__main__': 119 | tf.test.main() 120 | -------------------------------------------------------------------------------- /nets/deeplab_v1-3/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/nets/deeplab_v1-3/__init__.py -------------------------------------------------------------------------------- /nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_resnet_v2 import inception_resnet_v2_base 25 | from nets.inception_v1 import inception_v1 26 | from nets.inception_v1 import inception_v1_arg_scope 27 | from nets.inception_v1 import inception_v1_base 28 | from nets.inception_v2 import inception_v2 29 | from nets.inception_v2 import inception_v2_arg_scope 30 | from nets.inception_v2 import inception_v2_base 31 | from nets.inception_v3 import inception_v3 32 | from nets.inception_v3 import inception_v3_arg_scope 33 | from nets.inception_v3 import inception_v3_base 34 | from nets.inception_v4 import inception_v4 35 | from nets.inception_v4 import inception_v4_arg_scope 36 | from nets.inception_v4 import inception_v4_base 37 | # pylint: enable=unused-import 38 | -------------------------------------------------------------------------------- /nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001, 36 | activation_fn=tf.nn.relu): 37 | """Defines the default arg scope for inception models. 38 | 39 | Args: 40 | weight_decay: The weight decay to use for regularizing the model. 41 | use_batch_norm: If `True`, batch_norm is applied after each convolution. 42 | batch_norm_decay: Decay for batch norm moving average. 43 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 44 | in batch norm. 45 | activation_fn: Activation function for conv2d. 46 | 47 | Returns: 48 | An `arg_scope` to use for the inception models. 49 | """ 50 | batch_norm_params = { 51 | # Decay for the moving averages. 52 | 'decay': batch_norm_decay, 53 | # epsilon to prevent 0s in variance. 54 | 'epsilon': batch_norm_epsilon, 55 | # collection containing update_ops. 
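# (These ops maintain the moving mean/variance and must be run alongside
# the train op, e.g. via tf.control_dependencies on
# tf.get_collection(tf.GraphKeys.UPDATE_OPS); slim's create_train_op
# handles this automatically.)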
56 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 57 | # use fused batch norm if possible. 58 | 'fused': None, 59 | } 60 | if use_batch_norm: 61 | normalizer_fn = slim.batch_norm 62 | normalizer_params = batch_norm_params 63 | else: 64 | normalizer_fn = None 65 | normalizer_params = {} 66 | # Set weight_decay for weights in Conv and FC layers. 67 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 68 | weights_regularizer=slim.l2_regularizer(weight_decay)): 69 | with slim.arg_scope( 70 | [slim.conv2d], 71 | weights_initializer=slim.variance_scaling_initializer(), 72 | activation_fn=activation_fn, 73 | normalizer_fn=normalizer_fn, 74 | normalizer_params=normalizer_params) as sc: 75 | return sc 76 | -------------------------------------------------------------------------------- /nets/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the LeNet model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def lenet(images, num_classes=10, is_training=False, 27 | dropout_keep_prob=0.5, 28 | prediction_fn=slim.softmax, 29 | scope='LeNet'): 30 | """Creates a variant of the LeNet model. 31 | 32 | Note that since the output is a set of 'logits', the values fall in the 33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 34 | probability distribution over the characters, one will need to convert them 35 | using the softmax function: 36 | 37 | logits = lenet.lenet(images, is_training=False) 38 | probabilities = tf.nn.softmax(logits) 39 | predictions = tf.argmax(logits, 1) 40 | 41 | Args: 42 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 43 | num_classes: the number of classes in the dataset. If 0 or None, the logits 44 | layer is omitted and the input features to the logits layer are returned 45 | instead. 46 | is_training: specifies whether or not we're currently training the model. 47 | This variable will determine the behaviour of the dropout layer. 48 | dropout_keep_prob: the percentage of activation values that are retained. 49 | prediction_fn: a function to get predictions out of logits. 50 | scope: Optional variable_scope. 51 | 52 | Returns: 53 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 54 | is a non-zero integer, or the non-dropped-out input to the logits layer 55 | if num_classes is 0 or None. 56 | end_points: a dictionary from components of the network to the corresponding 57 | activation. 
58 | """ 59 | end_points = {} 60 | 61 | with tf.variable_scope(scope, 'LeNet', [images]): 62 | net = end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1') 63 | net = end_points['pool1'] = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 64 | net = end_points['conv2'] = slim.conv2d(net, 64, [5, 5], scope='conv2') 65 | net = end_points['pool2'] = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 66 | net = slim.flatten(net) 67 | end_points['Flatten'] = net 68 | 69 | net = end_points['fc3'] = slim.fully_connected(net, 1024, scope='fc3') 70 | if not num_classes: 71 | return net, end_points 72 | net = end_points['dropout3'] = slim.dropout( 73 | net, dropout_keep_prob, is_training=is_training, scope='dropout3') 74 | logits = end_points['Logits'] = slim.fully_connected( 75 | net, num_classes, activation_fn=None, scope='fc4') 76 | 77 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 78 | 79 | return logits, end_points 80 | lenet.default_image_size = 28 81 | 82 | 83 | def lenet_arg_scope(weight_decay=0.0): 84 | """Defines the default lenet argument scope. 85 | 86 | Args: 87 | weight_decay: The weight decay to use for regularizing the model. 88 | 89 | Returns: 90 | An `arg_scope` to use for the lenet model. 91 | """ 92 | with slim.arg_scope( 93 | [slim.conv2d, slim.fully_connected], 94 | weights_regularizer=slim.l2_regularizer(weight_decay), 95 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 96 | activation_fn=tf.nn.relu) as sc: 97 | return sc 98 | -------------------------------------------------------------------------------- /nets/mobilenet/README.md: -------------------------------------------------------------------------------- 1 | # MobileNetV2 2 | This folder contains building code for MobileNetV2, based on 3 | [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) 4 | 5 | # Performance 6 | ## Latency 7 | This is the timing of [MobileNetV1](../mobilenet_v1.md) vs MobileNetV2 using 8 | TF-Lite on the large core of a Pixel 1 phone. 9 | 10 | ![mnet_v1_vs_v2_pixel1_latency.png](mnet_v1_vs_v2_pixel1_latency.png) 11 | 12 | ## MACs 13 | MACs, also sometimes known as MADDs, are the number of multiply-accumulates needed 14 | to compute an inference on a single image, and are a common metric for measuring the efficiency of a model. 15 | 16 | Below is a graph comparing MobileNetV2 against a few selected networks. The size 17 | of each blob represents the number of parameters. Note that for [ShuffleNet](https://arxiv.org/abs/1707.01083) there 18 | are no published size numbers; we estimate them to be comparable to MobileNetV2's. 
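As a rough back-of-the-envelope illustration (not taken from the paper): a standard convolution with a k x k kernel, C_in input channels, C_out filters, and an H x W output map costs about H * W * k^2 * C_in * C_out MACs, so MobileNetV2's first 3x3 stride-2 layer (224x224x3 input, 112x112x32 output) alone contributes roughly 112 * 112 * 9 * 3 * 32 ≈ 10.8 million MACs.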
19 | 20 | ![madds_top1_accuracy](madds_top1_accuracy.png) 21 | 22 | # Pretrained models 23 | ## Imagenet Checkpoints 24 | 25 | Classification Checkpoint | MACs (M)| Parameters (M)| Top 1 Accuracy| Top 5 Accuracy | Mobile CPU (ms) Pixel 1 26 | ---------------------------|---------|---------------|---------|----|------------- 27 | | [mobilenet_v2_1.4_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) | 582 | 6.06 | 75.0 | 92.5 | 138.0 28 | | [mobilenet_v2_1.3_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.3_224.tgz) | 509 | 5.34 | 74.4 | 92.1 | 123.0 29 | | [mobilenet_v2_1.0_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz) | 300 | 3.47 | 71.8 | 91.0 | 73.8 30 | | [mobilenet_v2_1.0_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_192.tgz) | 221 | 3.47 | 70.7 | 90.1 | 55.1 31 | | [mobilenet_v2_1.0_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_160.tgz) | 154 | 3.47 | 68.8 | 89.0 | 40.2 32 | | [mobilenet_v2_1.0_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_128.tgz) | 99 | 3.47 | 65.3 | 86.9 | 27.6 33 | | [mobilenet_v2_1.0_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz) | 56 | 3.47 | 60.3 | 83.2 | 17.6 34 | | [mobilenet_v2_0.75_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_224.tgz) | 209 | 2.61 | 69.8 | 89.6 | 55.8 35 | | [mobilenet_v2_0.75_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_192.tgz) | 153 | 2.61 | 68.7 | 88.9 | 41.6 36 | | [mobilenet_v2_0.75_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_160.tgz) | 107 | 2.61 | 66.4 | 87.3 | 30.4 37 | | [mobilenet_v2_0.75_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_128.tgz) | 69 | 2.61 | 63.2 | 85.3 | 21.9 38 | | [mobilenet_v2_0.75_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_96.tgz) | 39 | 2.61 | 58.8 | 81.6 | 14.2 39 | | [mobilenet_v2_0.5_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_224.tgz) | 97 | 1.95 | 65.4 | 86.4 | 28.7 40 | | [mobilenet_v2_0.5_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_192.tgz) | 71 | 1.95 | 63.9 | 85.4 | 21.1 41 | | [mobilenet_v2_0.5_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_160.tgz) | 50 | 1.95 | 61.0 | 83.2 | 14.9 42 | | [mobilenet_v2_0.5_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_128.tgz) | 32 | 1.95 | 57.7 | 80.8 | 9.9 43 | | [mobilenet_v2_0.5_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_96.tgz) | 18 | 1.95 | 51.2 | 75.8 | 6.4 44 | | [mobilenet_v2_0.35_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_224.tgz) | 59 | 1.66 | 60.3 | 82.9 | 19.7 45 | | [mobilenet_v2_0.35_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_192.tgz) | 43 | 1.66 | 58.2 | 81.2 | 14.6 46 | | [mobilenet_v2_0.35_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_160.tgz) | 30 | 1.66 | 55.7 | 79.1 | 10.5 47 | | [mobilenet_v2_0.35_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_128.tgz) | 20 | 1.66 | 50.8 | 75.0 | 6.9 48 | | [mobilenet_v2_0.35_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_96.tgz) | 11 | 1.66 | 45.5 | 70.4 | 4.5 49 | 50 | # Example 
51 | 52 | See this [ipython notebook](mobilenet_example.ipynb) or open and run the network directly in [Colaboratory](https://colab.research.google.com/github/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_example.ipynb). 53 | 54 | -------------------------------------------------------------------------------- /nets/mobilenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/nets/mobilenet/__init__.py -------------------------------------------------------------------------------- /nets/mobilenet/madds_top1_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/nets/mobilenet/madds_top1_accuracy.png -------------------------------------------------------------------------------- /nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png -------------------------------------------------------------------------------- /nets/mobilenet_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/nets/mobilenet_v1.png -------------------------------------------------------------------------------- /nets/nasnet/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Slim NASNet-A Implementation/Checkpoints 2 | This directory contains the code for the NASNet-A model from the paper 3 | [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012) by Zoph et al. 4 | In nasnet.py there are three different configurations of NASNet-A that are implemented. One of the models is the NASNet-A built for CIFAR-10 and the 5 | other two are variants of NASNet-A trained on ImageNet, which are listed below. 6 | 7 | # Pre-Trained Models 8 | Two NASNet-A checkpoints are available that have been trained on the 9 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 10 | image classification dataset. Accuracies were computed by evaluating using a single image crop. 11 | 12 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 13 | :----:|:------------:|:----------:|:-------:|:-------:| 14 | [NASNet-A_Mobile_224](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|564|5.3|74.0|91.6| 15 | [NASNet-A_Large_331](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|23800|88.9|82.7|96.2| 16 | 17 | 18 | Here is an example of how to download the NASNet-A_Mobile_224 checkpoint; downloading NASNet-A_Large_331 works the same way.
19 | 20 | ```shell 21 | CHECKPOINT_DIR=/tmp/checkpoints 22 | mkdir ${CHECKPOINT_DIR} 23 | cd ${CHECKPOINT_DIR} 24 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz 25 | tar -xvf nasnet-a_mobile_04_10_2017.tar.gz 26 | rm nasnet-a_mobile_04_10_2017.tar.gz 27 | ``` 28 | More information on integrating NASNet Models into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 29 | 30 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/). 31 | 32 | ## Sample Commands for using NASNet-A Mobile and Large Checkpoints for Inference 33 | ------- 34 | Run eval with the NASNet-A mobile ImageNet model 35 | 36 | ```shell 37 | DATASET_DIR=/tmp/imagenet 38 | EVAL_DIR=/tmp/tfmodel/eval 39 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 40 | python tensorflow_models/research/slim/eval_image_classifier.py \ 41 | --checkpoint_path=${CHECKPOINT_DIR} \ 42 | --eval_dir=${EVAL_DIR} \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --dataset_name=imagenet \ 45 | --dataset_split_name=validation \ 46 | --model_name=nasnet_mobile \ 47 | --eval_image_size=224 48 | ``` 49 | 50 | Run eval with the NASNet-A large ImageNet model 51 | 52 | ```shell 53 | DATASET_DIR=/tmp/imagenet 54 | EVAL_DIR=/tmp/tfmodel/eval 55 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 56 | python tensorflow_models/research/slim/eval_image_classifier.py \ 57 | --checkpoint_path=${CHECKPOINT_DIR} \ 58 | --eval_dir=${EVAL_DIR} \ 59 | --dataset_dir=${DATASET_DIR} \ 60 | --dataset_name=imagenet \ 61 | --dataset_split_name=validation \ 62 | --model_name=nasnet_large \ 63 | --eval_image_size=331 64 | ``` 65 | -------------------------------------------------------------------------------- /nets/nasnet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nets/nasnet/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/nets/nasnet/__init__.pyc -------------------------------------------------------------------------------- /nets/nasnet/nasnet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/nets/nasnet/nasnet.pyc -------------------------------------------------------------------------------- /nets/nasnet/nasnet_utils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/nets/nasnet/nasnet_utils.pyc -------------------------------------------------------------------------------- /nets/nasnet/nasnet_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for slim.nets.nasnet.nasnet_utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets.nasnet import nasnet_utils 24 | 25 | 26 | class NasnetUtilsTest(tf.test.TestCase): 27 | 28 | def testCalcReductionLayers(self): 29 | num_cells = 18 30 | num_reduction_layers = 2 31 | reduction_layers = nasnet_utils.calc_reduction_layers( 32 | num_cells, num_reduction_layers) 33 | self.assertEqual(len(reduction_layers), 2) 34 | self.assertEqual(reduction_layers[0], 6) 35 | self.assertEqual(reduction_layers[1], 12) 36 | 37 | def testGetChannelIndex(self): 38 | data_formats = ['NHWC', 'NCHW'] 39 | for data_format in data_formats: 40 | index = nasnet_utils.get_channel_index(data_format) 41 | correct_index = 3 if data_format == 'NHWC' else 1 42 | self.assertEqual(index, correct_index) 43 | 44 | def testGetChannelDim(self): 45 | data_formats = ['NHWC', 'NCHW'] 46 | shape = [10, 20, 30, 40] 47 | for data_format in data_formats: 48 | dim = nasnet_utils.get_channel_dim(shape, data_format) 49 | correct_dim = shape[3] if data_format == 'NHWC' else shape[1] 50 | self.assertEqual(dim, correct_dim) 51 | 52 | def testGlobalAvgPool(self): 53 | data_formats = ['NHWC', 'NCHW'] 54 | inputs = tf.placeholder(tf.float32, (5, 10, 20, 10)) 55 | for data_format in data_formats: 56 | output = nasnet_utils.global_avg_pool( 57 | inputs, data_format) 58 | self.assertEqual(output.shape, [5, 10]) 59 | 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for nets_factory.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFnFirstHalf(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in list(nets_factory.networks_map.keys())[:10]: 34 | with tf.Graph().as_default() as g, self.test_session(g): 35 | net_fn = nets_factory.get_network_fn(net, num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 39 | logits, end_points = net_fn(inputs) 40 | self.assertTrue(isinstance(logits, tf.Tensor)) 41 | self.assertTrue(isinstance(end_points, dict)) 42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 44 | 45 | def testGetNetworkFnSecondHalf(self): 46 | batch_size = 5 47 | num_classes = 1000 48 | for net in list(nets_factory.networks_map.keys())[10:]: 49 | with tf.Graph().as_default() as g, self.test_session(g): 50 | net_fn = nets_factory.get_network_fn(net, num_classes) 51 | # Most networks use 224 as their default_image_size 52 | image_size = getattr(net_fn, 'default_image_size', 224) 53 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 54 | logits, end_points = net_fn(inputs) 55 | self.assertTrue(isinstance(logits, tf.Tensor)) 56 | self.assertTrue(isinstance(end_points, dict)) 57 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 58 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 59 | 60 | if __name__ == '__main__': 61 | tf.test.main() 62 | -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /preprocessing/lenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities for preprocessing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def preprocess_image(image, output_height, output_width, is_training): 27 | """Preprocesses the given image. 28 | 29 | Args: 30 | image: A `Tensor` representing an image of arbitrary size.
31 | output_height: The height of the image after preprocessing. 32 | output_width: The width of the image after preprocessing. 33 | is_training: `True` if we're preprocessing the image for training and 34 | `False` otherwise. 35 | 36 | Returns: 37 | A preprocessed image. 38 | """ 39 | image = tf.to_float(image) 40 | image = tf.image.resize_image_with_crop_or_pad( 41 | image, output_height, output_width) 42 | image = tf.subtract(image, 128.0) 43 | image = tf.div(image, 128.0) 44 | return image 45 | -------------------------------------------------------------------------------- /preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various preprocessing functions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import cifarnet_preprocessing 24 | from preprocessing import inception_preprocessing 25 | from preprocessing import lenet_preprocessing 26 | from preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocesses a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 43 | 44 | Raises: 45 | ValueError: If Preprocessing `name` is not recognized.
46 | """ 47 | preprocessing_fn_map = { 48 | 'cifarnet': cifarnet_preprocessing, 49 | 'inception': inception_preprocessing, 50 | 'inception_v1': inception_preprocessing, 51 | 'inception_v2': inception_preprocessing, 52 | 'inception_v3': inception_preprocessing, 53 | 'inception_v4': inception_preprocessing, 54 | 'inception_resnet_v2': inception_preprocessing, 55 | 'lenet': lenet_preprocessing, 56 | 'mobilenet_v1': inception_preprocessing, 57 | 'nasnet_mobile': inception_preprocessing, 58 | 'nasnet_large': inception_preprocessing, 59 | 'pnasnet_large': inception_preprocessing, 60 | 'resnet_v1_50': vgg_preprocessing, 61 | 'resnet_v1_101': vgg_preprocessing, 62 | 'resnet_v1_152': vgg_preprocessing, 63 | 'resnet_v1_200': vgg_preprocessing, 64 | 'resnet_v2_50': vgg_preprocessing, 65 | 'resnet_v2_101': vgg_preprocessing, 66 | 'resnet_v2_152': vgg_preprocessing, 67 | 'resnet_v2_200': vgg_preprocessing, 68 | 'vgg': vgg_preprocessing, 69 | 'vgg_a': vgg_preprocessing, 70 | 'vgg_16': vgg_preprocessing, 71 | 'vgg_19': vgg_preprocessing, 72 | 'shufflenet':inception_preprocessing, 73 | 'squeezenet':inception_preprocessing, 74 | } 75 | 76 | if name not in preprocessing_fn_map: 77 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 78 | 79 | def preprocessing_fn(image, output_height, output_width, **kwargs): 80 | return preprocessing_fn_map[name].preprocess_image( 81 | image, output_height, output_width, is_training=is_training, **kwargs) 82 | 83 | return preprocessing_fn 84 | -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | 1,train_lenet_on_mnist.sh 2 | train and eval the lenet by dataset of mnist,you should set TRAIN_DIR and DATASET_DIR 3 | 4 | 2,train_cifarnet_on_cifar10.sh 5 | train and eval the cifarnet by dataset of cifar10,you should set TRAIN_DIR and DATASET_DIR 6 | 7 | 3,finetune_mobilenet_v1_on_flowers.sh 8 | finetune ,train and eval the mobilenet_v1 by dataset of flowers,you should set TRAIN_DIR and DATASET_DIR.when finetuning the mobilenet, 9 | you must download mobilenet_v1_1.0_224_2017_06_14.tar.gz and set PRETRAINED_CHECKPOINT_DIR 10 | 11 | -------------------------------------------------------------------------------- /scripts/create_tfrecord.sh: -------------------------------------------------------------------------------- 1 | python ../../dataset_tools/create_pet_tf_record.py \ 2 | --data_dir=/workspace/zhangbin/master/tensorflow_models/models/research/object_detection/01_pet_dataset/dataset \ 3 | --output_dir=/workspace/zhangbin/master/tensorflow_models/models/research/object_detection/01_pet_dataset/dataset \ 4 | --label_map_path=/workspace/zhangbin/master/tensorflow_models/models/research/object_detection/01_pet_dataset/dataset/pet_label_map.pbtxt \ 5 | --faces_only=False 6 | -------------------------------------------------------------------------------- /scripts/download_COCO.sh: -------------------------------------------------------------------------------- 1 | wget http://msvocds.blob.core.windows.net/coco2014/train2014.zip 2 | -------------------------------------------------------------------------------- /scripts/download_VOC.sh: -------------------------------------------------------------------------------- 1 | wget https://pjreddie.com/media/files/VOCtrainval_11-May-2012.tar 2 | wget https://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar 3 | wget https://pjreddie.com/media/files/VOCtest_06-Nov-2007.tar 4 | tar xf 
VOCtrainval_11-May-2012.tar 5 | tar xf VOCtrainval_06-Nov-2007.tar 6 | tar xf VOCtest_06-Nov-2007.tar 7 | -------------------------------------------------------------------------------- /scripts/download_classification_Pre-trained_Models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz 4 | wget http://download.tensorflow.org/models/inception_v2_2016_08_28.tar.gz 5 | wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz 6 | wget http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz 7 | wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz 8 | wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz 9 | wget http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz 10 | wget http://download.tensorflow.org/models/resnet_v1_152_2016_08_28.tar.gz 11 | wget http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz 12 | wget http://download.tensorflow.org/models/resnet_v2_101_2017_04_14.tar.gz 13 | wget http://download.tensorflow.org/models/resnet_v2_152_2017_04_14.tar.gz 14 | wget http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz 15 | wget http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz 16 | wget http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz 17 | wget http://download.tensorflow.org/models/mobilenet_v1_0.50_160_2017_06_14.tar.gz 18 | wget http://download.tensorflow.org/models/mobilenet_v1_0.25_128_2017_06_14.tar.gz 19 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz 20 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz 21 | -------------------------------------------------------------------------------- /scripts/download_classification_mobile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz 4 | wget http://download.tensorflow.org/models/mobilenet_v1_1.0_192_2017_06_14.tar.gz 5 | wget http://download.tensorflow.org/models/mobilenet_v1_1.0_160_2017_06_14.tar.gz 6 | wget http://download.tensorflow.org/models/mobilenet_v1_1.0_128_2017_06_14.tar.gz 7 | wget http://download.tensorflow.org/models/mobilenet_v1_0.75_224_2017_06_14.tar.gz 8 | wget http://download.tensorflow.org/models/mobilenet_v1_0.75_192_2017_06_14.tar.gz 9 | wget http://download.tensorflow.org/models/mobilenet_v1_0.75_160_2017_06_14.tar.gz 10 | wget http://download.tensorflow.org/models/mobilenet_v1_0.75_128_2017_06_14.tar.gz 11 | wget http://download.tensorflow.org/models/mobilenet_v1_0.50_224_2017_06_14.tar.gz 12 | wget http://download.tensorflow.org/models/mobilenet_v1_0.50_192_2017_06_14.tar.gz 13 | wget http://download.tensorflow.org/models/mobilenet_v1_0.50_160_2017_06_14.tar.gz 14 | wget http://download.tensorflow.org/models/mobilenet_v1_0.50_128_2017_06_14.tar.gz 15 | wget http://download.tensorflow.org/models/mobilenet_v1_0.25_224_2017_06_14.tar.gz 16 | wget http://download.tensorflow.org/models/mobilenet_v1_0.25_192_2017_06_14.tar.gz 17 | wget http://download.tensorflow.org/models/mobilenet_v1_0.25_160_2017_06_14.tar.gz 18 | wget http://download.tensorflow.org/models/mobilenet_v1_0.25_128_2017_06_14.tar.gz 19 | -------------------------------------------------------------------------------- 
/scripts/download_detection_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | wget http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_2017_11_17.tar.gz 4 | wget http://download.tensorflow.org/models/object_detection/ssd_inception_v2_coco_2017_11_17.tar.gz 5 | wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_v2_coco_2017_11_08.tar.gz 6 | wget http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_coco_2017_11_08.tar.gz 7 | wget http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet50_lowproposals_coco_2017_11_08.tar.gz 8 | wget http://download.tensorflow.org/models/object_detection/rfcn_resnet101_coco_2017_11_08.tar.gz 9 | wget http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_coco_2017_11_08.tar.gz 10 | wget http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_lowproposals_coco_2017_11_08.tar.gz 11 | wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_resnet_v2_atrous_coco_2017_11_08.tar.gz 12 | wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_resnet_v2_atrous_lowproposals_coco_2017_11_08.tar.gz 13 | wget http://download.tensorflow.org/models/object_detection/faster_rcnn_nas_coco_2017_11_08.tar.gz 14 | wget http://download.tensorflow.org/models/object_detection/faster_rcnn_nas_lowproposals_coco_2017_11_08.tar.gz 15 | wget http://download.tensorflow.org/models/object_detection/faster_rcnn_resnet101_kitti_2017_11_08.tar.gz 16 | wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_resnet_v2_atrous_oid_2017_11_08.tar.gz 17 | wget http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_resnet_v2_atrous_lowproposals_oid_2017_11_08.tar.gz 18 | -------------------------------------------------------------------------------- /scripts/download_imagenet-data.sh: -------------------------------------------------------------------------------- 1 | #wget -t 0 -c http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_val.tar 2 | #wget -t 0 -c http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_test.tar 3 | wget -t 0 -c http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train.tar 4 | #wget -t 0 -c http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_img_train_t3.tar 5 | #wget -t 0 -c http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_bbox_train_v2.tar.gz 6 | #wget -t 0 -c http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_bbox_train_dogs.tar.gz 7 | #wget -t 0 -c http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_bbox_val_v3.tgz 8 | #wget -t 0 -c http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_bbox_test_dogs.zip 9 | #wget -t 0 -c http://www.image-net.org/challenges/LSVRC/2012/nnoupb/ILSVRC2012_devkit_t12.tar 10 | -------------------------------------------------------------------------------- /scripts/download_kitti.sh: -------------------------------------------------------------------------------- 1 | wget http://kitti.is.tue.mpg.de/kitti/data_object_label_2.zip 2 | wget http://kitti.is.tue.mpg.de/kitti/data_object_image_2.zip 3 | -------------------------------------------------------------------------------- /scripts/download_pet_dataset.sh: -------------------------------------------------------------------------------- 1 | 2 | wget
http://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz 3 | wget http://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz 4 | -------------------------------------------------------------------------------- /scripts/export_interface_graph.sh: -------------------------------------------------------------------------------- 1 | SLIM_NAME=mobilenet_v1 2 | MODEL_FOLDER=./tmp/ 3 | DATASET_DIR=/workspace/zhangbin/dataset_robin/flowers 4 | echo "Freezing graph to ${MODEL_FOLDER}/unfrozen_graph.pb" 5 | python export_inference_graph.py \ 6 | --model_name=${SLIM_NAME} \ 7 | --image_size=224 \ 8 | --dataset_name=flowers \ 9 | --dataset_dir=${DATASET_DIR} \ 10 | --logtostderr \ 11 | --output_file=${MODEL_FOLDER}/unfrozen_graph.pb 12 | 13 | 14 | echo "*******" 15 | echo "Freezing graph to ${MODEL_FOLDER}/frozen_graph.pb" 16 | echo "*******" 17 | CHECKPOINT=./tmp/flowers-models/mobilenet_v1/all/model.ckpt-5032 18 | OUTPUT_NODE_NAMES=MobilenetV1/Predictions/Reshape_1 19 | python tools/freeze_graph.py \ 20 | --input_graph=${MODEL_FOLDER}/unfrozen_graph.pb \ 21 | --input_checkpoint=${CHECKPOINT} \ 22 | --input_binary=true \ 23 | --output_graph=${MODEL_FOLDER}/frozen_graph.pb \ 24 | --output_node_names=${OUTPUT_NODE_NAMES} 25 | 26 | python tools/optimize_for_inference.py \ 27 | --input=${MODEL_FOLDER}/frozen_graph.pb \ 28 | --output=${MODEL_FOLDER}/optimized_graph.pb \ 29 | --frozen_graph=True \ 30 | --input_names=input \ 31 | --output_names=${OUTPUT_NODE_NAMES} 32 | 33 | 34 | python tools/quantization/quantize_graph.py \ 35 | --input=${MODEL_FOLDER}/frozen_graph.pb \ 36 | --output_node_names=${OUTPUT_NODE_NAMES} \ 37 | --print_nodes \ 38 | --output=${MODEL_FOLDER}/quantized_graph.pb \ 39 | --mode=eightbit \ 40 | --logtostderr -------------------------------------------------------------------------------- /scripts/export_interface_graph_for_mobilenet.sh: -------------------------------------------------------------------------------- 1 | SLIM_NAME=mobilenet_v1 2 | MODEL_FOLDER=./tmp/ 3 | DATASET_DIR=/workspace/zhangbin/dataset_robin/flowers 4 | echo "Freezing graph to ${MODEL_FOLDER}/unfrozen_graph.pb" 5 | python export_inference_graph.py \ 6 | --model_name=${SLIM_NAME} \ 7 | --image_size=224 \ 8 | --dataset_name=flowers \ 9 | --dataset_dir=${DATASET_DIR} \ 10 | --logtostderr \ 11 | --output_file=${MODEL_FOLDER}/unfrozen_graph.pb 12 | 13 | 14 | echo "*******" 15 | echo "Freezing graph to ${MODEL_FOLDER}/frozen_graph.pb" 16 | echo "*******" 17 | CHECKPOINT=./tmp/flowers-models/mobilenet_v1/all/model.ckpt-5032 18 | OUTPUT_NODE_NAMES=MobilenetV1/Predictions/Reshape_1 19 | python tools/freeze_graph.py \ 20 | --input_graph=${MODEL_FOLDER}/unfrozen_graph.pb \ 21 | --input_checkpoint=${CHECKPOINT} \ 22 | --input_binary=true \ 23 | --output_graph=${MODEL_FOLDER}/frozen_graph.pb \ 24 | --output_node_names=${OUTPUT_NODE_NAMES} 25 | 26 | python tools/optimize_for_inference.py \ 27 | --input=${MODEL_FOLDER}/frozen_graph.pb \ 28 | --output=${MODEL_FOLDER}/optimized_graph.pb \ 29 | --frozen_graph=True \ 30 | --input_names=input \ 31 | --output_names=${OUTPUT_NODE_NAMES} 32 | 33 | 34 | python tools/quantization/quantize_graph.py \ 35 | --input=${MODEL_FOLDER}/frozen_graph.pb \ 36 | --output_node_names=${OUTPUT_NODE_NAMES} \ 37 | --print_nodes \ 38 | --output=${MODEL_FOLDER}/quantized_graph.pb \ 39 | --mode=eightbit \ 40 | --logtostderr -------------------------------------------------------------------------------- /scripts/export_interface_graph_for_squeezenet.sh: 
-------------------------------------------------------------------------------- 1 | SLIM_NAME=squeezenet 2 | MODEL_FOLDER=./tmp/ 3 | DATASET_DIR=/workspace/zhangbin/dataset_robin/flowers 4 | echo "Freezing graph to ${MODEL_FOLDER}/unfrozen_graph.pb" 5 | python export_inference_graph.py \ 6 | --model_name=${SLIM_NAME} \ 7 | --image_size=224 \ 8 | --dataset_name=flowers \ 9 | --dataset_dir=${DATASET_DIR} \ 10 | --logtostderr \ 11 | --output_file=${MODEL_FOLDER}/unfrozen_graph.pb 12 | 13 | 14 | echo "*******" 15 | echo "Freezing graph to ${MODEL_FOLDER}/frozen_graph.pb" 16 | echo "*******" 17 | CHECKPOINT=./tmp/flowers-models/squeezenet/all/model.ckpt-43419 18 | OUTPUT_NODE_NAMES=SqueezeNet/Predictions/Reshape_1 19 | python tools/freeze_graph.py \ 20 | --input_graph=${MODEL_FOLDER}/unfrozen_graph.pb \ 21 | --input_checkpoint=${CHECKPOINT} \ 22 | --input_binary=true \ 23 | --output_graph=${MODEL_FOLDER}/frozen_graph.pb \ 24 | --output_node_names=${OUTPUT_NODE_NAMES} 25 | 26 | python tools/optimize_for_inference.py \ 27 | --input=${MODEL_FOLDER}/frozen_graph.pb \ 28 | --output=${MODEL_FOLDER}/optimized_graph.pb \ 29 | --frozen_graph=True \ 30 | --input_names=input \ 31 | --output_names=${OUTPUT_NODE_NAMES} 32 | 33 | 34 | python tools/quantization/quantize_graph.py \ 35 | --input=${MODEL_FOLDER}/frozen_graph.pb \ 36 | --output_node_names=${OUTPUT_NODE_NAMES} \ 37 | --print_nodes \ 38 | --output=${MODEL_FOLDER}/quantized_graph.pb \ 39 | --mode=eightbit \ 40 | --logtostderr 41 | -------------------------------------------------------------------------------- /scripts/finetune_inception_resnet_v2_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes an Inception Resnet V2 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_resnet_v2_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the pre-trained Inception Resnet V2 checkpoint is saved to. 31 | MODEL_NAME=inception_resnet_v2 32 | 33 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 34 | TRAIN_DIR=/tmp/flowers-models/${MODEL_NAME} 35 | 36 | # Where the dataset is saved to. 37 | DATASET_DIR=/tmp/flowers 38 | 39 | # Download the pre-trained checkpoint. 40 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 41 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 42 | fi 43 | if [ ! 
-f ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt ]; then 44 | wget http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz 45 | tar -xvf inception_resnet_v2_2016_08_30.tar.gz 46 | mv inception_resnet_v2.ckpt ${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt 47 | rm inception_resnet_v2_2016_08_30.tar.gz 48 | fi 49 | 50 | # Download the dataset 51 | python download_and_convert_data.py \ 52 | --dataset_name=flowers \ 53 | --dataset_dir=${DATASET_DIR} 54 | 55 | # Fine-tune only the new layers for 1000 steps. 56 | python train_image_classifier.py \ 57 | --train_dir=${TRAIN_DIR} \ 58 | --dataset_name=flowers \ 59 | --dataset_split_name=train \ 60 | --dataset_dir=${DATASET_DIR} \ 61 | --model_name=${MODEL_NAME} \ 62 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/${MODEL_NAME}.ckpt \ 63 | --checkpoint_exclude_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 64 | --trainable_scopes=InceptionResnetV2/Logits,InceptionResnetV2/AuxLogits \ 65 | --max_number_of_steps=1000 \ 66 | --batch_size=32 \ 67 | --learning_rate=0.01 \ 68 | --learning_rate_decay_type=fixed \ 69 | --save_interval_secs=60 \ 70 | --save_summaries_secs=60 \ 71 | --log_every_n_steps=10 \ 72 | --optimizer=rmsprop \ 73 | --weight_decay=0.00004 74 | 75 | # Run evaluation. 76 | python eval_image_classifier.py \ 77 | --checkpoint_path=${TRAIN_DIR} \ 78 | --eval_dir=${TRAIN_DIR} \ 79 | --dataset_name=flowers \ 80 | --dataset_split_name=validation \ 81 | --dataset_dir=${DATASET_DIR} \ 82 | --model_name=${MODEL_NAME} 83 | 84 | # Fine-tune all the new layers for 500 steps. 85 | python train_image_classifier.py \ 86 | --train_dir=${TRAIN_DIR}/all \ 87 | --dataset_name=flowers \ 88 | --dataset_split_name=train \ 89 | --dataset_dir=${DATASET_DIR} \ 90 | --model_name=${MODEL_NAME} \ 91 | --checkpoint_path=${TRAIN_DIR} \ 92 | --max_number_of_steps=500 \ 93 | --batch_size=32 \ 94 | --learning_rate=0.0001 \ 95 | --learning_rate_decay_type=fixed \ 96 | --save_interval_secs=60 \ 97 | --save_summaries_secs=60 \ 98 | --log_every_n_steps=10 \ 99 | --optimizer=rmsprop \ 100 | --weight_decay=0.00004 101 | 102 | # Run evaluation. 103 | python eval_image_classifier.py \ 104 | --checkpoint_path=${TRAIN_DIR}/all \ 105 | --eval_dir=${TRAIN_DIR}/all \ 106 | --dataset_name=flowers \ 107 | --dataset_split_name=validation \ 108 | --dataset_dir=${DATASET_DIR} \ 109 | --model_name=${MODEL_NAME} 110 | -------------------------------------------------------------------------------- /scripts/finetune_inception_v1_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes an InceptionV1 model on the Flowers training set. 20 | # 3. 
Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_v1_on_flowers.sh 25 | #set -e 26 | 27 | # Where the pre-trained InceptionV1 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=./tmp/checkpoints/inception_v1 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=./tmp/flowers-models/inception_v1 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/workspace/zhangbin/dataset_robin/flowers 35 | 36 | # Download the pre-trained checkpoint. 37 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 38 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 39 | fi 40 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt ]; then 41 | wget http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz 42 | tar -xvf inception_v1_2016_08_28.tar.gz 43 | mv inception_v1.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt 44 | # rm inception_v1_2016_08_28.tar.gz 45 | fi 46 | 47 | # Download the dataset 48 | python download_and_convert_data.py \ 49 | --dataset_name=flowers \ 50 | --dataset_dir=${DATASET_DIR} 51 | 52 | # Fine-tune only the new layers for 2000 steps. 53 | python train_image_classifier.py \ 54 | --train_dir=${TRAIN_DIR} \ 55 | --dataset_name=flowers \ 56 | --dataset_split_name=train \ 57 | --dataset_dir=${DATASET_DIR} \ 58 | --model_name=inception_v1 \ 59 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt \ 60 | --checkpoint_exclude_scopes=InceptionV1/Logits \ 61 | --trainable_scopes=InceptionV1/Logits \ 62 | --max_number_of_steps=3000 \ 63 | --batch_size=32 \ 64 | --learning_rate=0.01 \ 65 | --save_interval_secs=60 \ 66 | --save_summaries_secs=60 \ 67 | --log_every_n_steps=100 \ 68 | --optimizer=rmsprop \ 69 | --weight_decay=0.00004 \ 70 | --clone_on_cpu=True 71 | 72 | # Run evaluation. 73 | python eval_image_classifier.py \ 74 | --checkpoint_path=${TRAIN_DIR} \ 75 | --eval_dir=${TRAIN_DIR} \ 76 | --dataset_name=flowers \ 77 | --dataset_split_name=validation \ 78 | --dataset_dir=${DATASET_DIR} \ 79 | --model_name=inception_v1 80 | 81 | # Fine-tune all the new layers for 1000 steps. 82 | python train_image_classifier.py \ 83 | --train_dir=${TRAIN_DIR}/all \ 84 | --dataset_name=flowers \ 85 | --dataset_split_name=train \ 86 | --dataset_dir=${DATASET_DIR} \ 87 | --checkpoint_path=${TRAIN_DIR} \ 88 | --model_name=inception_v1 \ 89 | --max_number_of_steps=1000 \ 90 | --batch_size=32 \ 91 | --learning_rate=0.001 \ 92 | --save_interval_secs=60 \ 93 | --save_summaries_secs=60 \ 94 | --log_every_n_steps=100 \ 95 | --optimizer=rmsprop \ 96 | --weight_decay=0.00004 \ 97 | --clone_on_cpu=True 98 | 99 | # Run evaluation. 100 | python eval_image_classifier.py \ 101 | --checkpoint_path=${TRAIN_DIR}/all \ 102 | --eval_dir=${TRAIN_DIR}/all \ 103 | --dataset_name=flowers \ 104 | --dataset_split_name=validation \ 105 | --dataset_dir=${DATASET_DIR} \ 106 | --model_name=inception_v1 107 | -------------------------------------------------------------------------------- /scripts/finetune_inception_v3_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes an InceptionV3 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_v3_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained InceptionV3 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=/tmp/flowers-models/inception_v3 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/tmp/flowers 35 | 36 | # Download the pre-trained checkpoint. 37 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 38 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 39 | fi 40 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt ]; then 41 | wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz 42 | tar -xvf inception_v3_2016_08_28.tar.gz 43 | mv inception_v3.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt 44 | rm inception_v3_2016_08_28.tar.gz 45 | fi 46 | 47 | # Download the dataset 48 | python download_and_convert_data.py \ 49 | --dataset_name=flowers \ 50 | --dataset_dir=${DATASET_DIR} 51 | 52 | # Fine-tune only the new layers for 1000 steps. 53 | python train_image_classifier.py \ 54 | --train_dir=${TRAIN_DIR} \ 55 | --dataset_name=flowers \ 56 | --dataset_split_name=train \ 57 | --dataset_dir=${DATASET_DIR} \ 58 | --model_name=inception_v3 \ 59 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt \ 60 | --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \ 61 | --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \ 62 | --max_number_of_steps=1000 \ 63 | --batch_size=32 \ 64 | --learning_rate=0.01 \ 65 | --learning_rate_decay_type=fixed \ 66 | --save_interval_secs=60 \ 67 | --save_summaries_secs=60 \ 68 | --log_every_n_steps=100 \ 69 | --optimizer=rmsprop \ 70 | --weight_decay=0.00004 71 | 72 | # Run evaluation. 73 | python eval_image_classifier.py \ 74 | --checkpoint_path=${TRAIN_DIR} \ 75 | --eval_dir=${TRAIN_DIR} \ 76 | --dataset_name=flowers \ 77 | --dataset_split_name=validation \ 78 | --dataset_dir=${DATASET_DIR} \ 79 | --model_name=inception_v3 80 | 81 | # Fine-tune all the new layers for 500 steps. 82 | python train_image_classifier.py \ 83 | --train_dir=${TRAIN_DIR}/all \ 84 | --dataset_name=flowers \ 85 | --dataset_split_name=train \ 86 | --dataset_dir=${DATASET_DIR} \ 87 | --model_name=inception_v3 \ 88 | --checkpoint_path=${TRAIN_DIR} \ 89 | --max_number_of_steps=500 \ 90 | --batch_size=32 \ 91 | --learning_rate=0.0001 \ 92 | --learning_rate_decay_type=fixed \ 93 | --save_interval_secs=60 \ 94 | --save_summaries_secs=60 \ 95 | --log_every_n_steps=10 \ 96 | --optimizer=rmsprop \ 97 | --weight_decay=0.00004 98 | 99 | # Run evaluation. 
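# (eval_image_classifier.py accepts either a checkpoint file or a directory
# for --checkpoint_path; given a directory it evaluates the most recent
# checkpoint inside it, so this picks up the fine-tune above.)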
100 | python eval_image_classifier.py \ 101 | --checkpoint_path=${TRAIN_DIR}/all \ 102 | --eval_dir=${TRAIN_DIR}/all \ 103 | --dataset_name=flowers \ 104 | --dataset_split_name=validation \ 105 | --dataset_dir=${DATASET_DIR} \ 106 | --model_name=inception_v3 107 | -------------------------------------------------------------------------------- /scripts/finetune_mobilenet_v1_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes a MobileNet V1 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_mobilenet_v1_on_flowers.sh 25 | #set -e 26 | 27 | # Where the pre-trained MobileNet V1 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=./tmp/checkpoints/mobilenet_v1_1.0_224 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=./tmp/flowers-models/mobilenet_v1 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/workspace/zhangbin/dataset_robin/flowers 35 | 36 | # Download the pre-trained checkpoint. 37 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 38 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 39 | fi 40 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/mobilenet_v1_1.0_224.ckpt ]; then 41 | #wget http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz 42 | tar -xvf mobilenet_v1_1.0_224_2017_06_14.tar.gz 43 | #mv inception_v1.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt 44 | #rm inception_v1_2016_08_28.tar.gz 45 | fi 46 | 47 | # Download the dataset 48 | python download_and_convert_data.py \ 49 | --dataset_name=flowers \ 50 | --dataset_dir=${DATASET_DIR} 51 | 52 | # Fine-tune only the new layers for 3000 steps. 53 | python train_image_classifier.py \ 54 | --train_dir=${TRAIN_DIR} \ 55 | --dataset_name=flowers \ 56 | --dataset_split_name=train \ 57 | --dataset_dir=${DATASET_DIR} \ 58 | --model_name=mobilenet_v1 \ 59 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/mobilenet_v1_1.0_224.ckpt \ 60 | --checkpoint_exclude_scopes=MobilenetV1/Logits \ 61 | --trainable_scopes=MobilenetV1/Logits \ 62 | --max_number_of_steps=3000 \ 63 | --batch_size=32 \ 64 | --learning_rate=0.01 \ 65 | --save_interval_secs=60 \ 66 | --save_summaries_secs=60 \ 67 | --log_every_n_steps=100 \ 68 | --optimizer=rmsprop \ 69 | --weight_decay=0.00004 \ 70 | --clone_on_cpu=True 71 | 72 | # Run evaluation.
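# (The eval job reports the streaming Accuracy and Recall_5 metrics over the
# validation split.)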
73 | python eval_image_classifier.py \ 74 | --checkpoint_path=${TRAIN_DIR} \ 75 | --eval_dir=${TRAIN_DIR} \ 76 | --dataset_name=flowers \ 77 | --dataset_split_name=validation \ 78 | --dataset_dir=${DATASET_DIR} \ 79 | --model_name=mobilenet_v1 80 | 81 | # Fine-tune all the new layers for 1000 steps. 82 | python train_image_classifier.py \ 83 | --train_dir=${TRAIN_DIR}/all \ 84 | --dataset_name=flowers \ 85 | --dataset_split_name=train \ 86 | --dataset_dir=${DATASET_DIR} \ 87 | --checkpoint_path=${TRAIN_DIR} \ 88 | --model_name=mobilenet_v1 \ 89 | --max_number_of_steps=1000 \ 90 | --batch_size=32 \ 91 | --learning_rate=0.001 \ 92 | --save_interval_secs=60 \ 93 | --save_summaries_secs=60 \ 94 | --log_every_n_steps=100 \ 95 | --optimizer=rmsprop \ 96 | --weight_decay=0.00004 \ 97 | --clone_on_cpu=True 98 | 99 | # Run evaluation. 100 | python eval_image_classifier.py \ 101 | --checkpoint_path=${TRAIN_DIR}/all \ 102 | --eval_dir=${TRAIN_DIR}/all \ 103 | --dataset_name=flowers \ 104 | --dataset_split_name=validation \ 105 | --dataset_dir=${DATASET_DIR} \ 106 | --model_name=mobilenet_v1 107 | python test_image_classifier.py \ 108 | --checkpoint_path=${TRAIN_DIR}/all \ 109 | --test_path=/workspace/zhangbin/dataset_robin/flowers/flower_photos/roses/3550491463_3eb092054c_m.jpg \ 110 | --num_classes=5 \ 111 | --label_path=/workspace/zhangbin/dataset_robin/flowers/labels.txt \ 112 | --model_name=mobilenet_v1 113 | -------------------------------------------------------------------------------- /scripts/finetune_resnet_v1_50_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes a ResNetV1-50 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained ResNetV1-50 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=/tmp/flowers-models/resnet_v1_50 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/tmp/flowers 35 | 36 | # Download the pre-trained checkpoint. 37 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 38 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 39 | fi 40 | if [ ! 
-f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then 41 | wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz 42 | tar -xvf resnet_v1_50_2016_08_28.tar.gz 43 | mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt 44 | rm resnet_v1_50_2016_08_28.tar.gz 45 | fi 46 | 47 | # Download the dataset 48 | python download_and_convert_data.py \ 49 | --dataset_name=flowers \ 50 | --dataset_dir=${DATASET_DIR} 51 | 52 | # Fine-tune only the new layers for 3000 steps. 53 | python train_image_classifier.py \ 54 | --train_dir=${TRAIN_DIR} \ 55 | --dataset_name=flowers \ 56 | --dataset_split_name=train \ 57 | --dataset_dir=${DATASET_DIR} \ 58 | --model_name=resnet_v1_50 \ 59 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \ 60 | --checkpoint_exclude_scopes=resnet_v1_50/logits \ 61 | --trainable_scopes=resnet_v1_50/logits \ 62 | --max_number_of_steps=3000 \ 63 | --batch_size=32 \ 64 | --learning_rate=0.01 \ 65 | --save_interval_secs=60 \ 66 | --save_summaries_secs=60 \ 67 | --log_every_n_steps=100 \ 68 | --optimizer=rmsprop \ 69 | --weight_decay=0.00004 70 | 71 | # Run evaluation. 72 | python eval_image_classifier.py \ 73 | --checkpoint_path=${TRAIN_DIR} \ 74 | --eval_dir=${TRAIN_DIR} \ 75 | --dataset_name=flowers \ 76 | --dataset_split_name=validation \ 77 | --dataset_dir=${DATASET_DIR} \ 78 | --model_name=resnet_v1_50 79 | 80 | # Fine-tune all the new layers for 1000 steps. 81 | python train_image_classifier.py \ 82 | --train_dir=${TRAIN_DIR}/all \ 83 | --dataset_name=flowers \ 84 | --dataset_split_name=train \ 85 | --dataset_dir=${DATASET_DIR} \ 86 | --checkpoint_path=${TRAIN_DIR} \ 87 | --model_name=resnet_v1_50 \ 88 | --max_number_of_steps=1000 \ 89 | --batch_size=32 \ 90 | --learning_rate=0.001 \ 91 | --save_interval_secs=60 \ 92 | --save_summaries_secs=60 \ 93 | --log_every_n_steps=100 \ 94 | --optimizer=rmsprop \ 95 | --weight_decay=0.00004 96 | 97 | # Run evaluation. 98 | python eval_image_classifier.py \ 99 | --checkpoint_path=${TRAIN_DIR}/all \ 100 | --eval_dir=${TRAIN_DIR}/all \ 101 | --dataset_name=flowers \ 102 | --dataset_split_name=validation \ 103 | --dataset_dir=${DATASET_DIR} \ 104 | --model_name=resnet_v1_50 105 | -------------------------------------------------------------------------------- /scripts/finetune_shufflenet_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes a ShuffleNet model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set.
21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_shufflenet_on_flowers.sh 25 | #set -e 26 | 27 | # Where the pre-trained checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=./tmp/checkpoints/shufflenet 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=./tmp/flowers-models/shufflenet 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/workspace/zhangbin/dataset_robin/flowers 35 | 36 | if [ 0 -gt 1 ]; then 37 | # Download the pre-trained checkpoint. 38 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 39 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 40 | fi 41 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/mobilenet_v1_1.0_224.ckpt ]; then 42 | #wget http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz 43 | tar -xvf mobilenet_v1_1.0_224_2017_06_14.tar.gz 44 | #mv inception_v1.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt 45 | #rm inception_v1_2016_08_28.tar.gz 46 | fi 47 | 48 | # Download the dataset 49 | python download_and_convert_data.py \ 50 | --dataset_name=flowers \ 51 | --dataset_dir=${DATASET_DIR} 52 | 53 | # Fine-tune only the new layers for 3000 steps. 54 | python train_image_classifier.py \ 55 | --train_dir=${TRAIN_DIR} \ 56 | --dataset_name=flowers \ 57 | --dataset_split_name=train \ 58 | --dataset_dir=${DATASET_DIR} \ 59 | --model_name=mobilenet_v1 \ 60 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/mobilenet_v1_1.0_224.ckpt \ 61 | --checkpoint_exclude_scopes=MobilenetV1/Logits \ 62 | --trainable_scopes=MobilenetV1/Logits \ 63 | --max_number_of_steps=3000 \ 64 | --batch_size=32 \ 65 | --learning_rate=0.01 \ 66 | --save_interval_secs=60 \ 67 | --save_summaries_secs=60 \ 68 | --log_every_n_steps=100 \ 69 | --optimizer=rmsprop \ 70 | --weight_decay=0.00004 \ 71 | --clone_on_cpu=True 72 | 73 | # Run evaluation. 74 | python eval_image_classifier.py \ 75 | --checkpoint_path=${TRAIN_DIR} \ 76 | --eval_dir=${TRAIN_DIR} \ 77 | --dataset_name=flowers \ 78 | --dataset_split_name=validation \ 79 | --dataset_dir=${DATASET_DIR} \ 80 | --model_name=mobilenet_v1 81 | fi 82 | 83 | # Fine-tune all the new layers for 30000 steps. 84 | python train_image_classifier.py \ 85 | --train_dir=${TRAIN_DIR}/all \ 86 | --dataset_name=flowers \ 87 | --dataset_split_name=train \ 88 | --dataset_dir=${DATASET_DIR} \ 89 | --model_name=shufflenet \ 90 | --max_number_of_steps=30000 \ 91 | --batch_size=32 \ 92 | --learning_rate=0.001 \ 93 | --save_interval_secs=600 \ 94 | --save_summaries_secs=6000 \ 95 | --log_every_n_steps=1 \ 96 | --optimizer=rmsprop \ 97 | --weight_decay=0.00004 \ 98 | --clone_on_cpu=True 99 | 100 | # Run evaluation. 101 | python eval_image_classifier.py \ 102 | --checkpoint_path=${TRAIN_DIR}/all \ 103 | --eval_dir=${TRAIN_DIR}/all \ 104 | --dataset_name=flowers \ 105 | --dataset_split_name=validation \ 106 | --dataset_dir=${DATASET_DIR} \ 107 | --model_name=shufflenet 108 | python test_image_classifier.py \ 109 | --checkpoint_path=${TRAIN_DIR}/all \ 110 | --test_path=/workspace/zhangbin/dataset_robin/flowers/flower_photos/roses/3550491463_3eb092054c_m.jpg \ 111 | --num_classes=5 \ 112 | --label_path=/workspace/zhangbin/dataset_robin/flowers/labels.txt \ 113 | --model_name=shufflenet 114 | -------------------------------------------------------------------------------- /scripts/finetune_shufflenet_on_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 | #
17 | # This script performs the following operations:
18 | # 1. Optionally downloads a pre-trained checkpoint (block disabled below)
19 | # 2. Trains a ShuffleNet model on the ImageNet training set.
20 | # 3. Evaluates the model on the ImageNet validation set.
21 | #
22 | # Usage:
23 | # cd slim
24 | # ./scripts/finetune_shufflenet_on_imagenet.sh
25 | #set -e
26 |
27 | # Where the pre-trained checkpoint is saved to.
28 | PRETRAINED_CHECKPOINT_DIR=./tmp/checkpoints/shufflenet
29 |
30 | # Where the training (fine-tuned) checkpoint and logs will be saved to.
31 | TRAIN_DIR=./tmp/imagenet-models/shufflenet
32 |
33 | # Where the dataset is saved to.
34 | DATASET_DIR=/workspace/zhangbin/dataset_robin/imagenet-data/raw-data/tfrecod
35 |
36 | if false; then  # disabled block (the original `[ 0>1 ]` test is a shell redirection, not a comparison)
37 | # Download the pre-trained checkpoint.
38 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then
39 | mkdir -p ${PRETRAINED_CHECKPOINT_DIR}
40 | fi
41 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/mobilenet_v1_1.0_224.ckpt ]; then
42 | #wget http://download.tensorflow.org/models/mobilenet_v1_1.0_224_2017_06_14.tar.gz
43 | tar -xvf mobilenet_v1_1.0_224_2017_06_14.tar.gz
44 | #mv inception_v1.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt
45 | #rm inception_v1_2016_08_28.tar.gz
46 | fi
47 |
48 | # Download the dataset
49 | python download_and_convert_data.py \
50 | --dataset_name=flowers \
51 | --dataset_dir=${DATASET_DIR}
52 |
53 | # Fine-tune only the new layers for 3000 steps.
54 | python train_image_classifier.py \
55 | --train_dir=${TRAIN_DIR} \
56 | --dataset_name=flowers \
57 | --dataset_split_name=train \
58 | --dataset_dir=${DATASET_DIR} \
59 | --model_name=mobilenet_v1 \
60 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/mobilenet_v1_1.0_224.ckpt \
61 | --checkpoint_exclude_scopes=MobilenetV1/Logits \
62 | --trainable_scopes=MobilenetV1/Logits \
63 | --max_number_of_steps=3000 \
64 | --batch_size=32 \
65 | --learning_rate=0.01 \
66 | --save_interval_secs=60 \
67 | --save_summaries_secs=60 \
68 | --log_every_n_steps=100 \
69 | --optimizer=rmsprop \
70 | --weight_decay=0.00004 \
71 | --clone_on_cpu=True
72 |
73 | # Run evaluation.
74 | python eval_image_classifier.py \
75 | --checkpoint_path=${TRAIN_DIR} \
76 | --eval_dir=${TRAIN_DIR} \
77 | --dataset_name=flowers \
78 | --dataset_split_name=validation \
79 | --dataset_dir=${DATASET_DIR} \
80 | --model_name=mobilenet_v1
81 | fi
82 |
83 | # Train ShuffleNet on ImageNet from scratch for 300000 steps.
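# Note: the run below reads the 'validation' split via --dataset_split_name;
# for an actual training run, --dataset_split_name=train is presumably intended.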
84 | python train_image_classifier.py \
85 | --train_dir=${TRAIN_DIR}/all \
86 | --dataset_name=imagenet \
87 | --dataset_split_name=validation \
88 | --dataset_dir=${DATASET_DIR} \
89 | --model_name=shufflenet \
90 | --max_number_of_steps=300000 \
91 | --batch_size=32 \
92 | --learning_rate=0.0001 \
93 | --save_interval_secs=600 \
94 | --save_summaries_secs=6000 \
95 | --log_every_n_steps=1 \
96 | --optimizer=rmsprop \
97 | --weight_decay=0.00004 \
98 | --clone_on_cpu=True
99 |
100 | # Run evaluation.
101 | python eval_image_classifier.py \
102 | --checkpoint_path=${TRAIN_DIR}/all \
103 | --eval_dir=${TRAIN_DIR}/all \
104 | --dataset_name=imagenet \
105 | --dataset_split_name=validation \
106 | --dataset_dir=${DATASET_DIR} \
107 | --model_name=shufflenet
108 |
-------------------------------------------------------------------------------- /scripts/generate_validation.sh: --------------------------------------------------------------------------------
1 | # Note the locations of the train and validation data.
2 | TRAIN_DIRECTORY="train"
3 | VALIDATION_DIRECTORY="validation"
4 |
5 | # list of 5 labels: daisy, dandelion, roses, sunflowers, tulips
6 | LABELS_FILE="labels.txt"
7 | ls -1 "${TRAIN_DIRECTORY}" | grep -v 'LICENSE' | sed 's/\///' | sort > "${LABELS_FILE}"
8 |
9 | # Generate the validation data set.
10 | while read LABEL; do
11 | VALIDATION_DIR_FOR_LABEL="${VALIDATION_DIRECTORY}/${LABEL}"
12 | TRAIN_DIR_FOR_LABEL="${TRAIN_DIRECTORY}/${LABEL}"
13 |
14 | # Move 5 randomly selected images per label to the validation set.
15 | mkdir -p "${VALIDATION_DIR_FOR_LABEL}"
16 | VALIDATION_IMAGES=$(ls -1 "${TRAIN_DIR_FOR_LABEL}" | shuf | head -5)
17 | for IMAGE in ${VALIDATION_IMAGES}; do
18 | mv -f "${TRAIN_DIRECTORY}/${LABEL}/${IMAGE}" "${VALIDATION_DIR_FOR_LABEL}"
19 | done
20 | done < "${LABELS_FILE}"
21 |
-------------------------------------------------------------------------------- /scripts/run.sh: --------------------------------------------------------------------------------
1 | protoc object_detection/protos/*.proto --python_out=.
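# protoc compiles the object_detection .proto definitions into Python modules;
# run it once from the tensorflow/models/research/ directory before using the API.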
2 | export PYTHONPATH=
3 | export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
4 | echo $PYTHONPATH
5 |
6 | PATH_TO_YOUR_PIPELINE_CONFIG=`pwd`/object_detection/01_pet_dataset/model/mask_rcnn_inception_v2_coco_2018_01_28/mask_rcnn_inception_v2_coco.config
7 | PATH_TO_TRAIN_DIR=`pwd`/object_detection/01_pet_dataset/model/mask_rcnn_inception_v2_coco_2018_01_28/train
8 | PATH_TO_EVAL_DIR=`pwd`/object_detection/01_pet_dataset/model/mask_rcnn_inception_v2_coco_2018_01_28/eval
9 |
10 | echo $PATH_TO_YOUR_PIPELINE_CONFIG
11 | ## Test the environment
12 | python object_detection/builders/model_builder_test.py
13 |
14 |
15 | echo "Input command (train, eval, or export):"
16 | read a
17 | echo "Input is: $a"
18 |
19 |
20 | ## Create the tfrecord files
21 | #python object_detection/create_pet_tf_record.py \
22 | # --label_map_path=`pwd`/object_detection/pets_tf_tutorials/data/pet_label_map.pbtxt \
23 | # --data_dir=`pwd`/object_detection/pets_tf_tutorials/data/dataset/ \
24 | # --output_dir=`pwd`/object_detection/pets_tf_tutorials/data/
25 |
26 | if [ "$a" = "train" ] ; then
27 |
28 | ## From the tensorflow/models/research/ directory
29 | #python object_detection/train.py \
30 | # --logtostderr \
31 | # --pipeline_config_path=${PATH_TO_YOUR_PIPELINE_CONFIG} \
32 | # --train_dir=${PATH_TO_TRAIN_DIR}
33 |
34 | python object_detection/train.py \
35 | --logtostderr \
36 | --pipeline_config_path=${PATH_TO_YOUR_PIPELINE_CONFIG} \
37 | --train_dir=${PATH_TO_TRAIN_DIR}
38 | fi
39 |
40 |
41 | if [ "$a" = "eval" ] ; then
42 | ## From the tensorflow/models/research/ directory
43 | #python object_detection/eval.py \
44 | # --logtostderr \
45 | # --pipeline_config_path=${PATH_TO_YOUR_PIPELINE_CONFIG} \
46 | # --checkpoint_dir=${PATH_TO_TRAIN_DIR} \
47 | # --eval_dir=${PATH_TO_EVAL_DIR}
48 |
49 |
50 | python object_detection/eval.py \
51 | --logtostderr \
52 | --pipeline_config_path=${PATH_TO_YOUR_PIPELINE_CONFIG} \
53 | --checkpoint_dir=${PATH_TO_TRAIN_DIR} \
54 | --eval_dir=${PATH_TO_EVAL_DIR}
55 | fi
56 |
57 |
58 | if [ "$a" = "export" ] ; then
59 | ## From tensorflow/models/research/
60 | #python object_detection/export_inference_graph.py \
61 | # --input_type image_tensor \
62 | # --pipeline_config_path ${PIPELINE_CONFIG_PATH} \
63 | # --trained_checkpoint_prefix ${TRAIN_PATH} \
64 | # --output_directory output_inference_graph.pb
65 |
66 | python object_detection/export_inference_graph.py \
67 | --input_type image_tensor \
68 | --pipeline_config_path=${PATH_TO_YOUR_PIPELINE_CONFIG} \
69 | --trained_checkpoint_prefix=`pwd`/object_detection/01_pet_dataset/model/mask_rcnn_inception_v2_coco_2018_01_28/train/model.ckpt-83540 \
70 | --output_directory=`pwd`/object_detection/01_pet_dataset/model/mask_rcnn_inception_v2_coco_2018_01_28/fine_tuned_model/model-83540
71 |
72 |
73 | #python object_detection/export_inference_graph.py \
74 | # --input_type image_tensor \
75 | # --pipeline_config_path=${PATH_TO_YOUR_PIPELINE_CONFIG} \
76 | # --trained_checkpoint_prefix=`pwd`/object_detection/02_hands_tf_tutorials/Egohands_models/ssd_inception_v2_coco_2017_11_17/train/model.ckpt-45322 \
77 | # --output_directory=`pwd`/object_detection/02_hands_tf_tutorials/Egohands_models/ssd_inception_v2_coco_2017_11_17/fine_tuned_model/model-45322
78 | fi
79 |
-------------------------------------------------------------------------------- /scripts/run_create_kitti_tf_record.sh: --------------------------------------------------------------------------------
1 | python object_detection/dataset_tools/create_kitti_tf_record.py \
2 |
--data_dir=/workspace/zhangbin/dataset_robin/kitti \ 3 | --label_map_path=object_detection/data/kitti_label_map.pbtxt \ 4 | --output_path=/workspace/zhangbin/dataset_robin/kitti/kitti 5 | -------------------------------------------------------------------------------- /scripts/run_create_pascal_tf_record.sh: -------------------------------------------------------------------------------- 1 | python object_detection/dataset_tools/create_pascal_tf_record.py \ 2 | --data_dir=/workspace/zhangbin/dataset_robin/VOC2012/VOCdevkit \ 3 | --year=VOC2012 \ 4 | --label_map_path=object_detection/data/pascal_label_map.pbtxt \ 5 | --output_path=/workspace/zhangbin/dataset_robin/VOC2012/pascal_2012.record 6 | 7 | -------------------------------------------------------------------------------- /scripts/run_download_and_preprocess_mscoco.sh: -------------------------------------------------------------------------------- 1 | bash object_detection/dataset_tools/download_and_preprocess_mscoco.sh \ 2 | /workspace/zhangbin/dataset_robin/mscoco 3 | -------------------------------------------------------------------------------- /scripts/train_cifarnet_on_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Cifar10 dataset 19 | # 2. Trains a CifarNet model on the Cifar10 training set. 20 | # 3. Evaluates the model on the Cifar10 testing set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./scripts/train_cifarnet_on_cifar10.sh 25 | #set -e 26 | 27 | # Where the checkpoint and logs will be saved to. 28 | TRAIN_DIR=./tmp/cifarnet-model 29 | 30 | # Where the dataset is saved to. 31 | DATASET_DIR=/workspace/zhangbin/dataset_robin/cifar10 32 | 33 | # Download the dataset 34 | python download_and_convert_data.py \ 35 | --dataset_name=cifar10 \ 36 | --dataset_dir=${DATASET_DIR} 37 | 38 | # Run training. 39 | python train_image_classifier.py \ 40 | --train_dir=${TRAIN_DIR} \ 41 | --dataset_name=cifar10 \ 42 | --dataset_split_name=train \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --model_name=cifarnet \ 45 | --preprocessing_name=cifarnet \ 46 | --max_number_of_steps=100000 \ 47 | --batch_size=128 \ 48 | --save_interval_secs=120 \ 49 | --save_summaries_secs=120 \ 50 | --log_every_n_steps=100 \ 51 | --optimizer=sgd \ 52 | --learning_rate=0.1 \ 53 | --learning_rate_decay_factor=0.1 \ 54 | --num_epochs_per_decay=200 \ 55 | --weight_decay=0.004 \ 56 | --clone_on_cpu=True 57 | 58 | 59 | # Run evaluation. 
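# Note: eval_image_classifier.py accepts a directory for --checkpoint_path and
# evaluates the most recent checkpoint found there.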
60 | python eval_image_classifier.py \ 61 | --checkpoint_path=${TRAIN_DIR} \ 62 | --eval_dir=${TRAIN_DIR} \ 63 | --dataset_name=cifar10 \ 64 | --dataset_split_name=test \ 65 | --dataset_dir=${DATASET_DIR} \ 66 | --model_name=cifarnet 67 | -------------------------------------------------------------------------------- /scripts/train_lenet_on_mnist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the MNIST dataset 19 | # 2. Trains a LeNet model on the MNIST training set. 20 | # 3. Evaluates the model on the MNIST testing set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/train_lenet_on_mnist.sh 25 | #set -e 26 | 27 | # Where the checkpoint and logs will be saved to. 28 | TRAIN_DIR=./tmp/lenet-model 29 | 30 | # Where the dataset is saved to. 31 | DATASET_DIR=/workspace/zhangbin/dataset_robin/mnist/ 32 | 33 | # Download the dataset 34 | python download_and_convert_data.py \ 35 | --dataset_name=mnist \ 36 | --dataset_dir=${DATASET_DIR} 37 | 38 | # Run training. 39 | python train_image_classifier.py \ 40 | --train_dir=${TRAIN_DIR} \ 41 | --dataset_name=mnist \ 42 | --dataset_split_name=train \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --model_name=lenet \ 45 | --preprocessing_name=lenet \ 46 | --max_number_of_steps=20000 \ 47 | --batch_size=50 \ 48 | --learning_rate=0.01 \ 49 | --save_interval_secs=60 \ 50 | --save_summaries_secs=60 \ 51 | --log_every_n_steps=100 \ 52 | --optimizer=sgd \ 53 | --learning_rate_decay_type=fixed \ 54 | --weight_decay=0 \ 55 | --clone_on_cpu=True 56 | 57 | # Run evaluation. 58 | python eval_image_classifier.py \ 59 | --checkpoint_path=${TRAIN_DIR} \ 60 | --eval_dir=${TRAIN_DIR} \ 61 | --dataset_name=mnist \ 62 | --dataset_split_name=test \ 63 | --dataset_dir=${DATASET_DIR} \ 64 | --model_name=lenet 65 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Setup script for slim.""" 16 | 17 | from setuptools import find_packages 18 | from setuptools import setup 19 | 20 | 21 | setup( 22 | name='slim', 23 | version='0.1', 24 | include_package_data=True, 25 | packages=find_packages(), 26 | description='tf-slim', 27 | ) 28 | -------------------------------------------------------------------------------- /slim_models_demo/First_Student_IC_school_bus_202076.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/slim_models_demo/First_Student_IC_school_bus_202076.jpg -------------------------------------------------------------------------------- /slim_models_demo/Inception_v1_demo.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os 4 | import tensorflow as tf 5 | 6 | 7 | #%matplotlib inline 8 | from matplotlib import pyplot as plt 9 | try: 10 | import urllib2 as urllib 11 | except ImportError: 12 | import urllib.request as urllib 13 | 14 | from datasets import imagenet 15 | from nets import inception 16 | from preprocessing import inception_preprocessing 17 | 18 | from tensorflow.contrib import slim 19 | 20 | from datasets import dataset_utils 21 | 22 | url = "http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz" 23 | checkpoints_dir = '../tmp/checkpoints' 24 | 25 | # if not tf.gfile.Exists(checkpoints_dir): 26 | # tf.gfile.MakeDirs(checkpoints_dir) 27 | # 28 | # dataset_utils.download_and_uncompress_tarball(url, checkpoints_dir) 29 | 30 | image_size = inception.inception_v1.default_image_size 31 | 32 | with tf.Graph().as_default(): 33 | url = ("https://upload.wikimedia.org/wikipedia/commons/d/d9/" 34 | "First_Student_IC_school_bus_202076.jpg") 35 | image_string = urllib.urlopen(url).read() 36 | image = tf.image.decode_jpeg(image_string, channels=3) 37 | processed_image = inception_preprocessing.preprocess_image(image, image_size, image_size, is_training=False) 38 | processed_images = tf.expand_dims(processed_image, 0) 39 | 40 | # Create the model, use the default arg scope to configure the batch norm parameters. 
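# slim.arg_scope (used below) injects shared keyword defaults (weight decay,
# batch-norm settings, activations) into every slim layer built inside the block.
# A minimal sketch of the mechanism (illustrative only, not part of this demo):
#   with slim.arg_scope([slim.conv2d], padding='SAME', stride=1):
#       net = slim.conv2d(processed_images, 64, [3, 3])  # inherits the defaults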
41 | with slim.arg_scope(inception.inception_v1_arg_scope()): 42 | logits, _ = inception.inception_v1(processed_images, num_classes=1001, is_training=False) 43 | probabilities = tf.nn.softmax(logits) 44 | 45 | init_fn = slim.assign_from_checkpoint_fn( 46 | os.path.join(checkpoints_dir, 'inception_v1.ckpt'), 47 | slim.get_model_variables('InceptionV1')) 48 | 49 | with tf.Session() as sess: 50 | init_fn(sess) 51 | np_image, probabilities = sess.run([image, probabilities]) 52 | probabilities = probabilities[0, 0:] 53 | sorted_inds = [i[0] for i in sorted(enumerate(-probabilities), key=lambda x:x[1])] 54 | 55 | plt.figure() 56 | plt.imshow(np_image.astype(np.uint8)) 57 | plt.axis('off') 58 | plt.show() 59 | 60 | names = imagenet.create_readable_names_for_imagenet_labels() 61 | for i in range(5): 62 | index = sorted_inds[i] 63 | print('Probability %0.2f%% => [%s]' % (probabilities[index] * 100, names[index])) -------------------------------------------------------------------------------- /slim_models_demo/Inception_v1_demo_locally.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import os 4 | import tensorflow as tf 5 | 6 | 7 | #%matplotlib inline 8 | from matplotlib import pyplot as plt 9 | try: 10 | import urllib2 as urllib 11 | except ImportError: 12 | import urllib.request as urllib 13 | import cv2 14 | from datasets import imagenet 15 | from nets import inception 16 | from preprocessing import inception_preprocessing 17 | 18 | from tensorflow.contrib import slim 19 | 20 | from datasets import dataset_utils 21 | 22 | url = "http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz" 23 | checkpoints_dir = '../tmp/checkpoints' 24 | 25 | # if not tf.gfile.Exists(checkpoints_dir): 26 | # tf.gfile.MakeDirs(checkpoints_dir) 27 | # 28 | # dataset_utils.download_and_uncompress_tarball(url, checkpoints_dir) 29 | 30 | image_size = inception.inception_v1.default_image_size 31 | 32 | with tf.Graph().as_default(): 33 | image =cv2.imread("First_Student_IC_school_bus_202076.jpg") 34 | image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB)# change channel 35 | image = tf.cast(image, tf.float32) 36 | # print(image.dtype) 37 | # image = tf.image.decode_jpeg(tf.read_file("First_Student_IC_school_bus_202076.jpg"), channels=3) 38 | processed_image = inception_preprocessing.preprocess_image(image, image_size, image_size, is_training=False) 39 | processed_images = tf.expand_dims(processed_image, 0) 40 | 41 | # Create the model, use the default arg scope to configure the batch norm parameters. 
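# num_classes=1001 below: slim Inception checkpoints reserve label 0 for a
# "background" class, so ImageNet names line up with names[index] directly
# (the 1000-class ResNet/VGG demos must use names[index+1] instead).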
42 | with slim.arg_scope(inception.inception_v1_arg_scope()): 43 | logits, _ = inception.inception_v1(processed_images, num_classes=1001, is_training=False) 44 | probabilities = tf.nn.softmax(logits) 45 | 46 | init_fn = slim.assign_from_checkpoint_fn( 47 | os.path.join(checkpoints_dir, 'inception_v1.ckpt'), 48 | slim.get_model_variables('InceptionV1')) 49 | 50 | saver = tf.train.Saver()#method 2 for restore 51 | 52 | with tf.Session() as sess: 53 | # init_fn(sess)#method 1 for restore 54 | saver.restore(sess, os.path.join(checkpoints_dir, 'inception_v1.ckpt'))#method 2 for restore 55 | 56 | np_image ,network_input ,probabilities = sess.run([image,processed_image,probabilities]) 57 | 58 | print(probabilities.shape) 59 | probabilities = probabilities[0,:] 60 | sorted_inds = [i[0] for i in sorted(enumerate(-probabilities), 61 | key=lambda x:x[1])] 62 | 63 | 64 | plt.figure() 65 | plt.imshow(np_image.astype(np.uint8)) 66 | plt.suptitle("Downloaded image", fontsize=14, fontweight='bold') 67 | plt.axis('off') 68 | plt.show() 69 | 70 | # to show the image. 71 | plt.imshow( network_input) 72 | plt.suptitle("Resized, Cropped and Mean-Centered input to network", 73 | fontsize=14, fontweight='bold') 74 | plt.axis('off') 75 | plt.show() 76 | 77 | names = imagenet.create_readable_names_for_imagenet_labels() 78 | for i in range(5): 79 | index = sorted_inds[i] 80 | print('Probability %0.2f%% => [%s]' % (probabilities[index] * 100, names[index])) -------------------------------------------------------------------------------- /slim_models_demo/frozen_graph.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.framework import graph_util 3 | import os 4 | import PIL.Image as Image 5 | import numpy as np 6 | def freeze_graph(model_dir, output_node_names): 7 | """ 8 | freeze the saved checkpoints/graph to *.pb 9 | """ 10 | checkpoint = tf.train.get_checkpoint_state(model_dir) 11 | input_checkpoint = checkpoint.model_checkpoint_path 12 | 13 | output_graph = os.path.join(model_dir, "frozen_graph.pb") 14 | 15 | saver = tf.train.import_meta_graph(input_checkpoint + ".meta", 16 | clear_devices=True) 17 | 18 | graph = tf.get_default_graph() 19 | input_graph_def = graph.as_graph_def() 20 | 21 | with tf.Session() as sess: 22 | saver.restore(sess, input_checkpoint) 23 | 24 | output_graph_def = graph_util.convert_variables_to_constants(sess, 25 | input_graph_def, 26 | output_node_names.split(",")) 27 | 28 | with tf.gfile.GFile(output_graph, "wb") as f: 29 | f.write(output_graph_def.SerializeToString()) 30 | print("%d ops in the final graph" % (len(output_graph_def.node))) 31 | 32 | 33 | def load_graph(frozen_graph_filename): 34 | """ 35 | Loads Frozen graph 36 | """ 37 | with tf.gfile.GFile(frozen_graph_filename, "rb") as f: 38 | graph_def = tf.GraphDef() 39 | graph_def.ParseFromString(f.read()) 40 | 41 | with tf.Graph().as_default() as graph: 42 | tf.import_graph_def(graph_def) 43 | return graph 44 | 45 | #mobilenet 46 | # freeze_graph("tf_files/mobilenet/", output_node_names="final_result") 47 | # graph = load_graph("tf_files/mobilenet/frozen_graph.pb") 48 | # 49 | # for op in graph.get_operations(): 50 | # print(op.name) 51 | # 52 | # 53 | # input_x = graph.get_tensor_by_name("import/input:0") 54 | # print(input_x) 55 | # out = graph.get_tensor_by_name("import/final_result:0") 56 | # print(out) 57 | # 58 | # input_operation = graph.get_operation_by_name('import/input') 59 | # print(input_operation.outputs[0]) 60 | # 
output_operation = graph.get_operation_by_name('import/final_result') 61 | # print(output_operation.outputs[0]) 62 | 63 | #inception 64 | freeze_graph("../tmp/checkpoints/with_placeholder/", output_node_names="vgg_16/fc8/squeezed") 65 | graph = load_graph("../tmp/checkpoints/with_placeholder/frozen_graph.pb") 66 | 67 | for op in graph.get_operations(): 68 | print(op.name) 69 | 70 | 71 | input_x = graph.get_tensor_by_name("import/input:0") 72 | print(input_x) 73 | out = graph.get_tensor_by_name("import/vgg_16/fc8/squeezed:0") 74 | print(out) 75 | 76 | input_operation = graph.get_operation_by_name('import/input') 77 | print(input_operation.outputs[0]) 78 | output_operation = graph.get_operation_by_name('import/vgg_16/fc8/squeezed') 79 | print(output_operation.outputs[0]) 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /slim_models_demo/label_flower_from_pb.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import tensorflow as tf 7 | import numpy as np 8 | import PIL.Image as Image 9 | from pylab import * 10 | import time 11 | from tensorflow.python.platform import gfile 12 | import os 13 | from datasets import imagenet 14 | 15 | 16 | # load the graph from pb file 17 | def load_graph(model_file): 18 | graph = tf.Graph() 19 | graph_def = tf.GraphDef() 20 | with open(model_file, "rb") as f: 21 | graph_def.ParseFromString(f.read()) 22 | with graph.as_default(): 23 | output_tensor,input_tensor = tf.import_graph_def(graph_def, name='', 24 | return_elements=["Softmax:0","DecodeJpeg/contents:0"]) 25 | with tf.Session() as sess: 26 | ops = sess.graph.get_operations() 27 | for op in ops: 28 | print(op.name) 29 | return graph, output_tensor,input_tensor 30 | 31 | # classify the picture and print the result 32 | def recognize(jpg_path, pb_file_path): 33 | with tf.Graph().as_default(): 34 | 35 | graph, output_tensor,input_tensor = load_graph(pb_file_path) 36 | 37 | with tf.Session(graph=graph) as sess: 38 | # # get the input tensor operation 39 | # input_x = graph.get_tensor_by_name("import/DecodeJpeg/contents:0") 40 | # # get the output tensor operation 41 | # output = graph.get_tensor_by_name("import/Softmax:0") 42 | # read the image 43 | image_data = gfile.FastGFile(jpg_path, 'rb').read() 44 | t1 = time.time() 45 | pre = sess.run(output_tensor, feed_dict={input_tensor:image_data}) 46 | t2 = time.time() 47 | writer = tf.summary.FileWriter("./logs_from_pb", graph=tf.get_default_graph()) 48 | 49 | results = np.squeeze(pre) 50 | prediction_labels = np.argmax(results, axis=0) 51 | names = imagenet.create_readable_names_for_imagenet_labels() 52 | top_k = results.argsort()[-5:][::-1] 53 | for i in top_k: 54 | print(names[i+1], results[i]) 55 | 56 | print('probability: %s: %.3g, running time: %.3g' % (names[prediction_labels+1],results[prediction_labels], t2-t1)) 57 | 58 | 59 | if __name__=="__main__": 60 | jpg_path = "First_Student_IC_school_bus_202076.jpg" 61 | pb_file_path=os.path.join("../tmp/checkpoints", 'vgg_16_freeze_graph.pb') 62 | recognize(jpg_path, pb_file_path) 63 | 64 | -------------------------------------------------------------------------------- /slim_models_demo/print_ops_from_pb.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import tensorflow as tf 3 | import 
numpy as np
4 | import os
5 |
6 |
7 | # Print all op names in a frozen graph.
8 | def print_ops(pb_path):
9 | with tf.gfile.FastGFile(os.path.join(pb_path), 'rb') as f:
10 | graph_def = tf.GraphDef()
11 | graph_def.ParseFromString(f.read())
12 | _ = tf.import_graph_def(graph_def, name='')
13 |
14 | with tf.Session() as sess:
15 |
16 | ops = sess.graph.get_operations()
17 | for op in ops:
18 | print(op.name)
19 |
20 | writer = tf.summary.FileWriter("log_print_ops/", sess.graph)
21 | writer.close()
22 |
23 |
24 | checkpoints_dir = '../tmp/checkpoints'
25 | # print_ops(os.path.join(checkpoints_dir, 'vgg_16_freeze_graph.pb'))
26 | print_ops("../tmp/checkpoints/with_placeholder/frozen_graph.pb")
27 |
-------------------------------------------------------------------------------- /slim_models_demo/resnet_demo.py: --------------------------------------------------------------------------------
1 | #coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 | import sys
6 | import os
7 |
8 | #%matplotlib inline
9 | from matplotlib import pyplot as plt
10 |
11 | import numpy as np
12 | import os
13 | import tensorflow as tf
14 | import urllib2
15 |
16 | from datasets import imagenet
17 | from nets import resnet_v1
18 | from preprocessing import vgg_preprocessing
19 | import cv2
20 |
21 |
22 | checkpoints_dir = '../tmp/checkpoints/'
23 |
24 | slim = tf.contrib.slim
25 |
26 |
27 | url = "http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz"
28 | checkpoints_dir = '../tmp/checkpoints'
29 |
30 | # if not tf.gfile.Exists(checkpoints_dir):
31 | # tf.gfile.MakeDirs(checkpoints_dir)
32 | #
33 | # dataset_utils.download_and_uncompress_tarball(url, checkpoints_dir)
34 |
35 | # The network expects inputs of a default size,
36 | # so the input image has to be resized first.
37 | image_size = resnet_v1.resnet_v1_50.default_image_size
38 |
39 | with tf.Graph().as_default():
40 |
41 | image = cv2.imread("First_Student_IC_school_bus_202076.jpg")
42 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # OpenCV loads BGR; convert to RGB
43 | # image = image [:, :, (2, 1, 0)] # equivalent channel swap
44 |
45 |
46 | # Resize the image while keeping the aspect ratio, then crop the central
47 | # region; the cropped size equals the network's default input size.
48 | processed_image = vgg_preprocessing.preprocess_image(image,
49 | image_size,
50 | image_size,
51 | is_training=False)
52 |
53 | # Images can be fed in batches; the first dimension is the
54 | # number of images per batch.
55 | # Here we feed a single image at a time.
56 | processed_images = tf.expand_dims(processed_image, 0)
57 |
58 | # Create the model with the default arg scope.
59 | # arg_scope is a common slim mechanism for setting
60 | # layer defaults such as stride, padding and so on.
61 | with slim.arg_scope(resnet_v1.resnet_arg_scope()):
62 | logits, _ = resnet_v1.resnet_v1_50(processed_images,
63 | num_classes=1000,
64 | is_training=False)
65 |
66 | # Apply softmax at the output so the results are probabilities.
67 | probabilities = tf.nn.softmax(logits)
68 |
69 | # Create a function that loads the network weights from the checkpoint.
70 | init_fn = slim.assign_from_checkpoint_fn(
71 | os.path.join(checkpoints_dir, 'resnet_v1_50.ckpt'),
72 | slim.get_model_variables('resnet_v1_50'))
73 |
74 | with tf.Session() as sess:
75 | # Load the weights.
76 | init_fn(sess)
77 |
78 | ops = sess.graph.get_operations()
79 | for op in ops:
80 | print(op.name)
81 |
82 | print("Parameters")
83 | for v in slim.get_model_variables():
84 | print('name = {}, shape = {}'.format(v.name, v.get_shape()))
85 |
86 |
87 | writer = tf.summary.FileWriter("./logs_resnet", graph=tf.get_default_graph())
88 |
89 | print("Finish!")
90 |
91 | # The resized and cropped image is fed to the network as a numpy array.
92 | network_input, probabilities = sess.run([processed_image,
93 | probabilities])
94 | probabilities = probabilities[0, 0:]
95 | sorted_inds = 
[i[0] for i in sorted(enumerate(-probabilities),
96 | key=lambda x:x[1])]
97 |
98 |
99 |
100 | names = imagenet.create_readable_names_for_imagenet_labels()
101 | for i in range(5):
102 | index = sorted_inds[i]
103 | # Print the top-5 predicted classes and their probabilities.
104 | print('Probability %0.2f => [%s]' % (probabilities[index], names[index+1]))
105 |
106 |
-------------------------------------------------------------------------------- /slim_models_demo/test_image_classifier.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 | import math
7 | import tensorflow as tf
8 | import numpy as np
9 | from nets import nets_factory
10 | from preprocessing import preprocessing_factory
11 | from datasets import imagenet
12 | slim = tf.contrib.slim
13 |
14 | tf.app.flags.DEFINE_string(
15 | 'master', '', 'The address of the TensorFlow master to use.')
16 |
17 | tf.app.flags.DEFINE_string(
18 | 'checkpoint_path', '../tmp/checkpoints/with_placeholder',
19 | 'The directory where the model was written to or an absolute path to a '
20 | 'checkpoint file.')
21 |
22 | tf.app.flags.DEFINE_string(
23 | 'test_path', 'First_Student_IC_school_bus_202076.jpg', 'Test image path.')
24 |
25 | tf.app.flags.DEFINE_integer(
26 | 'num_classes', 1000, 'Number of classes.')
27 |
28 | tf.app.flags.DEFINE_integer(
29 | 'labels_offset', 0,
30 | 'An offset for the labels in the dataset. This flag is primarily used to '
31 | 'evaluate the VGG and ResNet architectures which do not use a background '
32 | 'class for the ImageNet dataset.')
33 |
34 | tf.app.flags.DEFINE_string(
35 | 'model_name', 'vgg_16', 'The name of the architecture to evaluate.')
36 |
37 | tf.app.flags.DEFINE_string(
38 | 'preprocessing_name', None, 'The name of the preprocessing to use. 
If left '
39 | 'as `None`, then the model_name flag is used.')
40 |
41 | tf.app.flags.DEFINE_integer(
42 | 'test_image_size', None, 'Eval image size')
43 |
44 | FLAGS = tf.app.flags.FLAGS
45 |
46 |
47 | def main(_):
48 | if not FLAGS.test_path:
49 | raise ValueError('You must supply a test image path with --test_path')
50 |
51 | tf.logging.set_verbosity(tf.logging.INFO)
52 | with tf.Graph().as_default():
53 | # tf_global_step = slim.get_or_create_global_step()
54 |
55 | ####################
56 | # Select the model #
57 | ####################
58 | network_fn = nets_factory.get_network_fn(
59 | FLAGS.model_name,
60 | num_classes=(FLAGS.num_classes - FLAGS.labels_offset),
61 | is_training=False)
62 |
63 | #####################################
64 | # Select the preprocessing function #
65 | #####################################
66 | preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
67 | image_preprocessing_fn = preprocessing_factory.get_preprocessing(
68 | preprocessing_name,
69 | is_training=False)
70 |
71 | test_image_size = FLAGS.test_image_size or network_fn.default_image_size
72 |
73 | if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
74 | checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
75 | else:
76 | checkpoint_path = FLAGS.checkpoint_path
77 | print("Restoring from", checkpoint_path)
78 |
79 | with tf.Session() as sess:
80 | image = open(FLAGS.test_path, 'rb').read()
81 | image = tf.image.decode_jpeg(image, channels=3)
82 | processed_image = image_preprocessing_fn(image, test_image_size, test_image_size)
83 | processed_images = tf.expand_dims(processed_image, 0)
84 |
85 | logits, _ = network_fn(processed_images)
86 | probabilities = tf.nn.softmax(logits)
87 | saver = tf.train.Saver()
88 | saver.restore(sess, checkpoint_path)
89 |
90 | np_image, network_input, predictions = sess.run([image, processed_image, probabilities])
91 | probabilities = np.squeeze(predictions, 0)
92 | names = imagenet.create_readable_names_for_imagenet_labels()
93 |
94 | pre = np.argmax(probabilities, axis=0)
95 | print('{} {} {}'.format(FLAGS.test_path, pre, names[pre+1]))
96 | top_k = probabilities.argsort()[-5:][::-1]
97 | for index in top_k:
98 | print('Probability %0.2f => [%s]' % (probabilities[index], names[index+1]))
99 |
100 | if __name__ == '__main__':
101 | tf.app.run()
-------------------------------------------------------------------------------- /slim_models_demo/vgg_demo.py: --------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import sys
3 | import os
4 |
5 | #%matplotlib inline
6 | from matplotlib import pyplot as plt
7 |
8 | import numpy as np
9 | import os
10 | import tensorflow as tf
11 | import urllib2
12 |
13 | from datasets import imagenet
14 | from nets import vgg
15 | from preprocessing import vgg_preprocessing
16 | from tensorflow.python.framework import graph_util
17 | from tensorflow.python.platform import gfile
18 |
19 | checkpoints_dir = '../tmp/checkpoints/'
20 |
21 | slim = tf.contrib.slim
22 |
23 | # Download the vgg_16_2016_08_28.tar.gz checkpoint from the model zoo.
24 | url = "http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz"
25 | checkpoints_dir = '../tmp/checkpoints'
26 |
27 | # if not tf.gfile.Exists(checkpoints_dir):
28 | # tf.gfile.MakeDirs(checkpoints_dir)
29 | #
30 | # dataset_utils.download_and_uncompress_tarball(url, checkpoints_dir)
31 |
32 |
33 | # Set the default image size.
34 | image_size = vgg.vgg_16.default_image_size
35 |
36 |
37 |
38 | def 
save_graph_to_file(sess, graph, graph_file_name):
39 | output_graph_def = graph_util.convert_variables_to_constants(
40 | sess, graph, ["Softmax"])
41 | with gfile.FastGFile(graph_file_name, 'wb') as f:
42 | f.write(output_graph_def.SerializeToString())
43 | return
44 |
45 | with tf.Graph().as_default():
46 |
47 | url = ("https://upload.wikimedia.org/wikipedia/commons/d/d9/"
48 | "First_Student_IC_school_bus_202076.jpg")
49 |
50 | # Download the image from the internet.
51 | image_string = urllib2.urlopen(url).read()
52 |
53 | # Decode the image.
54 | image = tf.image.decode_jpeg(image_string, channels=3)
55 |
56 | # Resize the image while keeping the aspect ratio, then crop the central
57 | # region; the cropped size equals the network's default input size.
58 | processed_image = vgg_preprocessing.preprocess_image(image,
59 | image_size,
60 | image_size,
61 | is_training=False)
62 |
63 | # Images can be fed in batches; the first dimension is the
64 | # number of images per batch.
65 | # Here we feed a single image.
66 | processed_images = tf.expand_dims(processed_image, 0)
67 |
68 | # Create the model with the default arg scope.
69 | # arg_scope is a common slim mechanism for setting
70 | # layer defaults such as stride, padding and so on.
71 | with slim.arg_scope(vgg.vgg_arg_scope()):
72 | logits, _ = vgg.vgg_16(processed_images,
73 | num_classes=1000,
74 | is_training=False)
75 |
76 | # Apply softmax at the output so the results are probabilities.
77 | probabilities = tf.nn.softmax(logits)
78 |
79 | # Create a function that loads the network weights from the checkpoint.
80 | init_fn = slim.assign_from_checkpoint_fn(
81 | os.path.join(checkpoints_dir, 'vgg_16.ckpt'),
82 | slim.get_model_variables('vgg_16'))
83 |
84 | with tf.Session() as sess:
85 |
86 | # Load the weights.
87 | init_fn(sess)
88 |
89 | print("network operation")
90 | ops = sess.graph.get_operations()
91 | for op in ops:
92 | print(op.name)
93 | save_graph_to_file(sess, sess.graph_def, os.path.join(checkpoints_dir, 'vgg_16_freeze_graph.pb'))
94 | # The resized and cropped image is fed to the network as a numpy array.
95 | np_image, network_input, probabilities = sess.run([image,
96 | processed_image,
97 | probabilities])
98 | probabilities = probabilities[0, 0:]
99 | sorted_inds = [i[0] for i in sorted(enumerate(-probabilities),
100 | key=lambda x:x[1])]
101 |
102 | # Show the downloaded image.
103 | plt.figure()
104 | plt.imshow(np_image.astype(np.uint8))
105 | plt.suptitle("Downloaded image", fontsize=14, fontweight='bold')
106 | plt.axis('off')
107 | plt.show()
108 |
109 | # Show the image that is actually fed to the network.
110 | # The mean-centered pixel values are rescaled below purely for display.
111 | # to show the image. 
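# A more faithful display alternative is to add the VGG per-channel means back
# instead of rescaling (a sketch, assuming vgg_preprocessing's private
# constants _R_MEAN, _G_MEAN, _B_MEAN):
#   from preprocessing.vgg_preprocessing import _R_MEAN, _G_MEAN, _B_MEAN
#   restored = np.clip(network_input + [_R_MEAN, _G_MEAN, _B_MEAN], 0, 255)
#   plt.imshow(restored.astype(np.uint8))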
plt.imshow( network_input / (network_input.max() - network_input.min()) )
113 | plt.suptitle("Resized, Cropped and Mean-Centered input to network",
114 | fontsize=14, fontweight='bold')
115 | plt.axis('off')
116 | plt.show()
117 |
118 | names = imagenet.create_readable_names_for_imagenet_labels()
119 | for i in range(5):
120 | index = sorted_inds[i]
121 | # Print the top-5 predicted classes and their probabilities.
122 | print('Probability %0.2f => [%s]' % (probabilities[index], names[index+1]))
123 |
124 | res = slim.get_model_variables()
125 |
126 |
-------------------------------------------------------------------------------- /slim_models_demo/vgg_demo_Segmentation.py: --------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import sys
3 | import os
4 |
5 |
6 |
7 | #%matplotlib inline
8 | from matplotlib import pyplot as plt
9 |
10 | import numpy as np
11 | import os
12 | import tensorflow as tf
13 | import urllib2
14 |
15 | from datasets import imagenet
16 | from nets import vgg
17 | from preprocessing import vgg_preprocessing
18 | import cv2
19 |
20 | from preprocessing import vgg_preprocessing
21 |
22 |
23 | checkpoints_dir = '../tmp/checkpoints/'
24 |
25 | slim = tf.contrib.slim
26 | # Load the pixel means and related helper functions.
27 | from preprocessing.vgg_preprocessing import (_mean_image_subtraction,
28 | _R_MEAN, _G_MEAN, _B_MEAN)
29 |
30 | # Display the segmentation result, coloring each class differently.
31 | def discrete_matshow(data, labels_names=[], title=""):
32 | # Get a discrete colormap.
33 | cmap = plt.get_cmap('Paired', np.max(data)-np.min(data)+1)
34 | mat = plt.matshow(data,
35 | cmap=cmap,
36 | vmin = np.min(data)-.5,
37 | vmax = np.max(data)+.5)
38 | # Put ticks at the integer positions of the colorbar.
39 | cax = plt.colorbar(mat,
40 | ticks=np.arange(np.min(data),np.max(data)+1))
41 |
42 | # Add the class names.
43 | if labels_names:
44 | cax.ax.set_yticklabels(labels_names)
45 |
46 | if title:
47 | plt.suptitle(title, fontsize=14, fontweight='bold')
48 | plt.show()
49 |
50 | with tf.Graph().as_default():
51 |
52 | url = ("https://upload.wikimedia.org/wikipedia/commons/d/d9/"
53 | "First_Student_IC_school_bus_202076.jpg")
54 |
55 | image_string = urllib2.urlopen(url).read()
56 | image = tf.image.decode_jpeg(image_string, channels=3)
57 |
58 | # Convert the pixel values to 32-bit float before subtracting the means.
59 | image_float = tf.to_float(image, name='ToFloat')
60 |
61 | # Subtract the per-channel pixel means.
62 | processed_image = _mean_image_subtraction(image_float,
63 | [_R_MEAN, _G_MEAN, _B_MEAN])
64 |
65 | input_image = tf.expand_dims(processed_image, 0)
66 |
67 | with slim.arg_scope(vgg.vgg_arg_scope()):
68 |
69 | # spatial_squeeze=False keeps the fully convolutional output.
70 | logits, _ = vgg.vgg_16(input_image,
71 | num_classes=1000,
72 | is_training=False,
73 | spatial_squeeze=False)
74 |
75 | # For every pixel, pick the class with the highest score over all 1000 classes.
76 | # Strictly speaking these are logits, not probabilities (no softmax is applied),
77 | # but the argmax class is the same either way.
78 | pred = tf.argmax(logits, dimension=3)
79 |
80 | init_fn = slim.assign_from_checkpoint_fn(
81 | os.path.join(checkpoints_dir, 'vgg_16.ckpt'),
82 | slim.get_model_variables('vgg_16'))
83 |
84 | with tf.Session() as sess:
85 | init_fn(sess)
86 | segmentation, np_image = sess.run([pred, image])
87 |
88 | # Remove the singleton dimensions.
89 | segmentation = np.squeeze(segmentation)
90 |
91 | unique_classes, relabeled_image = np.unique(segmentation,
92 | return_inverse=True)
93 |
94 | segmentation_size = segmentation.shape
95 |
96 | relabeled_image = relabeled_image.reshape(segmentation_size)
97 |
98 | labels_names = []
99 | names = imagenet.create_readable_names_for_imagenet_labels()
100 |
101 |
102 | for index, current_class_number in enumerate(unique_classes):
103 |
104 | labels_names.append(str(index) + 
' ' + names[current_class_number+1])
105 |
106 | discrete_matshow(data=relabeled_image, labels_names=labels_names, title="Segmentation")
107 |
108 | res = slim.get_model_variables()
109 |
-------------------------------------------------------------------------------- /slim_models_demo/vgg_demo_locally.py: --------------------------------------------------------------------------------
1 | #coding=utf-8
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 | import sys
6 | import os
7 |
8 | #%matplotlib inline
9 | from matplotlib import pyplot as plt
10 |
11 | import numpy as np
12 | import os
13 | import tensorflow as tf
14 | import urllib2
15 |
16 | from datasets import imagenet
17 | from nets import vgg
18 | from preprocessing import vgg_preprocessing
19 | import cv2
20 |
21 | checkpoints_dir = '../tmp/checkpoints/'
22 |
23 | slim = tf.contrib.slim
24 |
25 |
26 |
27 | # The network expects inputs of a default size,
28 | # so the input image has to be resized first.
29 | image_size = vgg.vgg_16.default_image_size
30 |
31 | with tf.Graph().as_default():
32 |
33 | image = cv2.imread("First_Student_IC_school_bus_202076.jpg")
34 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # OpenCV loads BGR; convert to RGB
35 | # image = image [:, :, (2, 1, 0)] # equivalent channel swap
36 |
37 |
38 | # Resize the image while keeping the aspect ratio, then crop the central
39 | # region; the cropped size equals the network's default input size.
40 | processed_image = vgg_preprocessing.preprocess_image(image,
41 | image_size,
42 | image_size,
43 | is_training=False)
44 |
45 | # Images can be fed in batches; the first dimension is the
46 | # number of images per batch.
47 | # Here we feed a single image.
48 | processed_images = tf.expand_dims(processed_image, 0)
49 |
50 | # Create the model with the default arg scope.
51 | # arg_scope is a common slim mechanism for setting
52 | # layer defaults such as stride, padding and so on.
53 | with slim.arg_scope(vgg.vgg_arg_scope()):
54 | logits, _ = vgg.vgg_16(processed_images,
55 | num_classes=1000,
56 | is_training=False)
57 |
58 | # Apply softmax at the output so the results are probabilities.
59 | probabilities = tf.nn.softmax(logits)
60 |
61 | # Create a function that loads the network weights from the checkpoint.
62 | init_fn = slim.assign_from_checkpoint_fn(
63 | os.path.join(checkpoints_dir, 'vgg_16.ckpt'),
64 | slim.get_model_variables('vgg_16'))
65 |
66 | with tf.Session() as sess:
67 |
68 | # Load the weights.
69 | init_fn(sess)
70 |
71 | # The resized and cropped image is fed to the network as a numpy array.
72 | network_input, probabilities = sess.run([processed_image,
73 | probabilities])
74 | probabilities = probabilities[0, 0:]
75 | sorted_inds = [i[0] for i in sorted(enumerate(-probabilities),
76 | key=lambda x:x[1])]
77 |
78 | # Show the loaded image.
79 | plt.figure()
80 | plt.imshow(image.astype(np.uint8))
81 | plt.suptitle("Downloaded image", fontsize=14, fontweight='bold')
82 | plt.axis('off')
83 | plt.show()
84 |
85 |
86 | # Show the image that is actually fed to the network.
87 | # The mean-centered pixel values are rescaled below purely for display.
88 | # to show the image. 
plt.imshow( network_input / (network_input.max() - network_input.min()) )
90 | plt.suptitle("Resized, Cropped and Mean-Centered input to network",
91 | fontsize=14, fontweight='bold')
92 | plt.axis('off')
93 | plt.show()
94 |
95 | cv2.imshow("Downloaded image",image) # note: image is RGB here, so cv2.imshow (which expects BGR) swaps red and blue
96 | cv2.imshow("Resized, Cropped and Mean-Centered input to network",network_input)
97 | cv2.waitKey(0)
98 |
99 | names = imagenet.create_readable_names_for_imagenet_labels()
100 | for i in range(5):
101 | index = sorted_inds[i]
102 | # Print the top-5 predicted classes and their probabilities.
103 | print('Probability %0.2f => [%s]' % (probabilities[index], names[index+1]))
104 |
105 | # res = slim.get_model_variables()
106 |
107 |
-------------------------------------------------------------------------------- /tfrecord_fine_tune_model_for_other_set_demo.py: --------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | from datasets import flowers
4 | from nets import inception
5 | from preprocessing import inception_preprocessing
6 |
7 | from tensorflow.contrib import slim
8 |
9 |
10 | image_size = inception.inception_v1.default_image_size
11 | batch_size = 3
12 | flowers_data_dir = "/home/robin/Dataset/flowers"
13 | checkpoints_dir = './tmp/checkpoints'  # assumed checkpoint location; get_init_fn below expects inception_v1.ckpt here
14 |
15 | def load_batch(dataset, batch_size=32, height=299, width=299, is_training=False):
16 | """Loads a single batch of data.
17 |
18 | Args:
19 | dataset: The dataset to load.
20 | batch_size: The number of images in the batch.
21 | height: The size of each image after preprocessing.
22 | width: The size of each image after preprocessing.
23 | is_training: Whether or not we're currently training or evaluating.
24 |
25 | Returns:
26 | images: A Tensor of size [batch_size, height, width, 3], image samples that have been preprocessed.
27 | images_raw: A Tensor of size [batch_size, height, width, 3], image samples that can be used for visualization.
28 | labels: A Tensor of size [batch_size], whose values range between 0 and dataset.num_classes.
29 | """
30 | data_provider = slim.dataset_data_provider.DatasetDataProvider(
31 | dataset, common_queue_capacity=32,
32 | common_queue_min=8)
33 | image_raw, label = data_provider.get(['image', 'label'])
34 |
35 | # Preprocess image for usage by Inception.
36 | image = inception_preprocessing.preprocess_image(image_raw, height, width, is_training=is_training)
37 |
38 | # Preprocess the image for display purposes.
39 | image_raw = tf.expand_dims(image_raw, 0)
40 | image_raw = tf.image.resize_images(image_raw, [height, width])
41 | image_raw = tf.squeeze(image_raw)
42 |
43 | # Batch it up. 
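# tf.train.batch builds a queue that prefetches preprocessed examples and
# assembles them into fixed-size batches; `capacity` bounds how many examples
# the queue may hold ahead of the consumer.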
44 | images, images_raw, labels = tf.train.batch( 45 | [image, image_raw, label], 46 | batch_size=batch_size, 47 | num_threads=1, 48 | capacity=2 * batch_size) 49 | 50 | return images, images_raw, labels 51 | 52 | 53 | def get_init_fn(): 54 | """Returns a function run by the chief worker to warm-start the training.""" 55 | checkpoint_exclude_scopes=["InceptionV1/Logits", "InceptionV1/AuxLogits"] 56 | 57 | exclusions = [scope.strip() for scope in checkpoint_exclude_scopes] 58 | 59 | variables_to_restore = [] 60 | for var in slim.get_model_variables(): 61 | for exclusion in exclusions: 62 | if var.op.name.startswith(exclusion): 63 | break 64 | else: 65 | variables_to_restore.append(var) 66 | 67 | return slim.assign_from_checkpoint_fn( 68 | os.path.join(checkpoints_dir, 'inception_v1.ckpt'), 69 | variables_to_restore) 70 | 71 | 72 | train_dir = '/tmp/inception_finetuned/' 73 | 74 | with tf.Graph().as_default(): 75 | tf.logging.set_verbosity(tf.logging.INFO) 76 | 77 | dataset = flowers.get_split('train', flowers_data_dir) 78 | images, _, labels = load_batch(dataset, height=image_size, width=image_size) 79 | 80 | # Create the model, use the default arg scope to configure the batch norm parameters. 81 | with slim.arg_scope(inception.inception_v1_arg_scope()): 82 | logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=True) 83 | 84 | # Specify the loss function: 85 | one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes) 86 | slim.losses.softmax_cross_entropy(logits, one_hot_labels) 87 | total_loss = slim.losses.get_total_loss() 88 | 89 | # Create some summaries to visualize the training process: 90 | tf.summary.scalar('losses/Total Loss', total_loss) 91 | 92 | # Specify the optimizer and create the train op: 93 | optimizer = tf.train.AdamOptimizer(learning_rate=0.01) 94 | train_op = slim.learning.create_train_op(total_loss, optimizer) 95 | 96 | # Run the training: 97 | final_loss = slim.learning.train( 98 | train_op, 99 | logdir=train_dir, 100 | init_fn=get_init_fn(), 101 | number_of_steps=2) 102 | 103 | 104 | print('Finished training. Last batch loss %f' % final_loss) -------------------------------------------------------------------------------- /tfrecord_inference_model_for_images_demo.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from datasets import flowers 4 | from preprocessing import inception_preprocessing 5 | from nets import inception 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | from tensorflow.contrib import slim 10 | 11 | image_size = inception.inception_v1.default_image_size 12 | batch_size = 3 13 | flowers_data_dir = "/home/robin/Dataset/flowers" 14 | train_dir = "" 15 | 16 | def load_batch(dataset, batch_size=32, height=299, width=299, is_training=False): 17 | """Loads a single batch of data. 18 | 19 | Args: 20 | dataset: The dataset to load. 21 | batch_size: The number of images in the batch. 22 | height: The size of each image after preprocessing. 23 | width: The size of each image after preprocessing. 24 | is_training: Whether or not we're currently training or evaluating. 25 | 26 | Returns: 27 | images: A Tensor of size [batch_size, height, width, 3], image samples that have been preprocessed. 28 | images_raw: A Tensor of size [batch_size, height, width, 3], image samples that can be used for visualization. 29 | labels: A Tensor of size [batch_size], whose values range between 0 and dataset.num_classes. 
30 | """ 31 | data_provider = slim.dataset_data_provider.DatasetDataProvider( 32 | dataset, common_queue_capacity=32, 33 | common_queue_min=8) 34 | image_raw, label = data_provider.get(['image', 'label']) 35 | 36 | # Preprocess image for usage by Inception. 37 | image = inception_preprocessing.preprocess_image(image_raw, height, width, is_training=is_training) 38 | 39 | # Preprocess the image for display purposes. 40 | image_raw = tf.expand_dims(image_raw, 0) 41 | image_raw = tf.image.resize_images(image_raw, [height, width]) 42 | image_raw = tf.squeeze(image_raw) 43 | 44 | # Batch it up. 45 | images, images_raw, labels = tf.train.batch( 46 | [image, image_raw, label], 47 | batch_size=batch_size, 48 | num_threads=1, 49 | capacity=2 * batch_size) 50 | 51 | return images, images_raw, labels 52 | 53 | with tf.Graph().as_default(): 54 | tf.logging.set_verbosity(tf.logging.INFO) 55 | 56 | dataset = flowers.get_split('train', flowers_data_dir) 57 | images, images_raw, labels = load_batch(dataset, height=image_size, width=image_size) 58 | 59 | # Create the model, use the default arg scope to configure the batch norm parameters. 60 | with slim.arg_scope(inception.inception_v1_arg_scope()): 61 | logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=True) 62 | 63 | probabilities = tf.nn.softmax(logits) 64 | 65 | checkpoint_path = tf.train.latest_checkpoint(train_dir) 66 | init_fn = slim.assign_from_checkpoint_fn( 67 | checkpoint_path, 68 | slim.get_variables_to_restore()) 69 | 70 | with tf.Session() as sess: 71 | with slim.queues.QueueRunners(sess): 72 | sess.run(tf.initialize_local_variables()) 73 | init_fn(sess) 74 | np_probabilities, np_images_raw, np_labels = sess.run([probabilities, images_raw, labels]) 75 | 76 | for i in range(batch_size): 77 | image = np_images_raw[i, :, :, :] 78 | true_label = np_labels[i] 79 | predicted_label = np.argmax(np_probabilities[i, :]) 80 | predicted_name = dataset.labels_to_names[predicted_label] 81 | true_name = dataset.labels_to_names[true_label] 82 | 83 | plt.figure() 84 | plt.imshow(image.astype(np.uint8)) 85 | plt.title('Ground Truth: [%s], Prediction [%s]' % (true_name, predicted_name)) 86 | plt.axis('off') 87 | plt.show() -------------------------------------------------------------------------------- /tools/import_pb_to_tensorboard.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ================================ 15 | """Imports a protobuf model as a graph in Tensorboard.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import argparse 22 | import sys 23 | 24 | from tensorflow.core.framework import graph_pb2 25 | from tensorflow.python.client import session 26 | from tensorflow.python.framework import importer 27 | from tensorflow.python.framework import ops 28 | from tensorflow.python.platform import app 29 | from tensorflow.python.platform import gfile 30 | from tensorflow.python.summary import summary 31 | 32 | 33 | def import_to_tensorboard(model_dir, log_dir): 34 | """View an imported protobuf model (`.pb` file) as a graph in Tensorboard. 35 | 36 | Args: 37 | model_dir: The location of the protobuf (`pb`) model to visualize 38 | log_dir: The location for the Tensorboard log to begin visualization from. 39 | 40 | Usage: 41 | Call this function with your model location and desired log directory. 42 | Launch Tensorboard by pointing it to the log directory. 43 | View your imported `.pb` model as a graph. 44 | """ 45 | with session.Session(graph=ops.Graph()) as sess: 46 | with gfile.FastGFile(model_dir, "rb") as f: 47 | graph_def = graph_pb2.GraphDef() 48 | graph_def.ParseFromString(f.read()) 49 | importer.import_graph_def(graph_def) 50 | 51 | pb_visual_writer = summary.FileWriter(log_dir) 52 | pb_visual_writer.add_graph(sess.graph) 53 | print("Model Imported. Visualize by running: " 54 | "tensorboard --logdir={}".format(log_dir)) 55 | 56 | 57 | def main(unused_args): 58 | import_to_tensorboard(FLAGS.model_dir, FLAGS.log_dir) 59 | 60 | if __name__ == "__main__": 61 | parser = argparse.ArgumentParser() 62 | parser.register("type", "bool", lambda v: v.lower() == "true") 63 | parser.add_argument( 64 | "--model_dir", 65 | type=str, 66 | default="", 67 | required=True, 68 | help="The location of the protobuf (\'pb\') model to visualize.") 69 | parser.add_argument( 70 | "--log_dir", 71 | type=str, 72 | default="", 73 | required=True, 74 | help="The location for the Tensorboard log to begin visualization from.") 75 | FLAGS, unparsed = parser.parse_known_args() 76 | app.run(main=main, argv=[sys.argv[0]] + unparsed) 77 | -------------------------------------------------------------------------------- /tools/insert_placeholder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from nets import vgg 4 | from tensorflow.python.training import saver as saver_lib 5 | from tensorflow.python import pywrap_tensorflow 6 | slim = tf.contrib.slim 7 | input_checkpoint = '../tmp/checkpoints/vgg_16.ckpt' 8 | 9 | # Where to save the modified graph 10 | save_path = '../tmp/checkpoints/with_placeholder' 11 | 12 | # TODO(shizehao): use graph editor library insead 13 | with tf.Graph().as_default() as graph: 14 | input_images = tf.placeholder(tf.float32, [None, 224, 224, 3], 'input') 15 | with slim.arg_scope(vgg.vgg_arg_scope()): 16 | logits, _ = vgg.vgg_16(input_images, 17 | num_classes=1000, 18 | is_training=False) 19 | 20 | 21 | saver = tf.train.Saver() 22 | with tf.Session() as sess: 23 | var_list = {} 24 | reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint) 25 | var_to_shape_map = reader.get_variable_to_shape_map() 26 | for key in var_to_shape_map: 27 | try: 28 | tensor = sess.graph.get_tensor_by_name(key + ":0") 29 | except KeyError: 30 | # This tensor doesn't exist in the graph (for example it's 31 | # 
'global_step' or a similar housekeeping element) so skip it. 32 | continue 33 | var_list[key] = tensor 34 | saver = saver_lib.Saver(var_list=var_list) 35 | 36 | # Restore variables 37 | saver.restore(sess, input_checkpoint) 38 | 39 | # Save new checkpoint and the graph 40 | saver.save(sess, save_path+'/with_placeholder') 41 | tf.train.write_graph(graph, save_path, 'graph.pbtxt') 42 | 43 | 44 | -------------------------------------------------------------------------------- /tools/label_image/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # TensorFlow C++ inference example for labeling images. 3 | 4 | package( 5 | default_visibility = ["//tensorflow:internal"], 6 | ) 7 | 8 | licenses(["notice"]) # Apache 2.0 9 | 10 | exports_files(["LICENSE"]) 11 | 12 | load("//tensorflow:tensorflow.bzl", "tf_cc_binary") 13 | 14 | tf_cc_binary( 15 | name = "label_image", 16 | srcs = [ 17 | "main.cc", 18 | ], 19 | linkopts = select({ 20 | "//tensorflow:android": [ 21 | "-pie", 22 | "-landroid", 23 | "-ljnigraphics", 24 | "-llog", 25 | "-lm", 26 | "-z defs", 27 | "-s", 28 | "-Wl,--exclude-libs,ALL", 29 | ], 30 | "//conditions:default": ["-lm"], 31 | }), 32 | deps = select({ 33 | "//tensorflow:android": [ 34 | # cc:cc_ops is used to include image ops (for label_image) 35 | # Jpg, gif, and png related code won't be included 36 | "//tensorflow/cc:cc_ops", 37 | "//tensorflow/core:android_tensorflow_lib", 38 | # cc:android_tensorflow_image_op is for including jpeg/gif/png 39 | # decoder to enable real-image evaluation on Android 40 | "//tensorflow/core/kernels:android_tensorflow_image_op", 41 | ], 42 | "//conditions:default": [ 43 | "//tensorflow/cc:cc_ops", 44 | "//tensorflow/core:core_cpu", 45 | "//tensorflow/core:framework", 46 | "//tensorflow/core:framework_internal", 47 | "//tensorflow/core:lib", 48 | "//tensorflow/core:protos_all_cc", 49 | "//tensorflow/core:tensorflow", 50 | ], 51 | }), 52 | ) 53 | 54 | py_binary( 55 | name = "label_image_py", 56 | srcs = ["label_image.py"], 57 | main = "label_image.py", 58 | srcs_version = "PY2AND3", 59 | deps = [ 60 | "//tensorflow:tensorflow_py", 61 | ], 62 | ) 63 | 64 | filegroup( 65 | name = "all_files", 66 | srcs = glob( 67 | ["**/*"], 68 | exclude = [ 69 | "**/METADATA", 70 | "**/OWNERS", 71 | "bin/**", 72 | "gen/**", 73 | ], 74 | ), 75 | visibility = ["//tensorflow:__subpackages__"], 76 | ) 77 | -------------------------------------------------------------------------------- /tools/label_image/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow C++ and Python Image Recognition Demo 2 | 3 | This example shows how you can load a pre-trained TensorFlow network and use it 4 | to recognize objects in images in C++. For Java see the [Java 5 | README](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java), 6 | and for Go see the [godoc 7 | example](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go#ex-package). 8 | 9 | ## Description 10 | 11 | This demo uses a Google Inception model to classify image files that are passed 12 | in on the command line. 13 | 14 | ## To build/install/run 15 | 16 | The TensorFlow `GraphDef` that contains the model definition and weights is not 17 | packaged in the repo because of its size. 
Instead, you must first download the 18 | file to the `data` directory in the source tree: 19 | 20 | ```bash 21 | $ curl -L "https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz" | 22 | tar -C tensorflow/examples/label_image/data -xz 23 | ``` 24 | 25 | Then, as long as you've managed to build the main TensorFlow framework, you 26 | should have everything you need to run this example installed already. 27 | 28 | Once extracted, see the labels file in the data directory for the possible 29 | classifications, which are the 1,000 categories used in the ImageNet 30 | competition. 31 | 32 | To build it, run this command: 33 | 34 | ```bash 35 | $ bazel build tensorflow/examples/label_image/... 36 | ``` 37 | 38 | That should build a binary executable that you can then run like this: 39 | 40 | ```bash 41 | $ bazel-bin/tensorflow/examples/label_image/label_image 42 | ``` 43 | 44 | This uses the default example image that ships with the framework, and should 45 | output something similar to this: 46 | 47 | ``` 48 | I tensorflow/examples/label_image/main.cc:206] military uniform (653): 0.834306 49 | I tensorflow/examples/label_image/main.cc:206] mortarboard (668): 0.0218692 50 | I tensorflow/examples/label_image/main.cc:206] academic gown (401): 0.0103579 51 | I tensorflow/examples/label_image/main.cc:206] pickelhaube (716): 0.00800814 52 | I tensorflow/examples/label_image/main.cc:206] bulletproof vest (466): 0.00535088 53 | ``` 54 | 55 | In this case, we're using the default image of Admiral Grace Hopper, and you can 56 | see the network correctly spots that she's wearing a military uniform, with a high 57 | score of 0.8. 58 | 59 | Next, try it out on your own images by supplying the `--image=` argument, e.g. 60 | 61 | ```bash 62 | $ bazel-bin/tensorflow/examples/label_image/label_image --image=my_image.png 63 | ``` 64 | 65 | For a more detailed look at this code, you can check out the C++ section of the 66 | [Inception tutorial](https://www.tensorflow.org/tutorials/image_recognition/). 67 | 68 | ## Python implementation 69 | 70 | label_image.py is a Python implementation that provides code corresponding 71 | to the C++ code here. It gives a more intuitive mapping between C++ and 72 | Python than the Python code mentioned in the 73 | [Inception tutorial](https://www.tensorflow.org/tutorials/image_recognition/), 74 | and makes it easier to add visualization or debug code. 75 | 76 | 77 | `bazel-bin/tensorflow/examples/label_image/label_image_py` should exist after running: 78 | ```bash 79 | $ bazel build tensorflow/examples/label_image/...
80 | ``` 81 | 82 | Run 83 | 84 | ```bash 85 | $ bazel-bin/tensorflow/examples/label_image/label_image_py 86 | ``` 87 | 88 | Or, with the tensorflow Python package installed, you can run it like: 89 | ```bash 90 | $ python3 tensorflow/examples/label_image/label_image.py 91 | ``` 92 | 93 | You should get results similar to this: 94 | ``` 95 | military uniform 0.834305 96 | mortarboard 0.0218694 97 | academic gown 0.0103581 98 | pickelhaube 0.00800818 99 | bulletproof vest 0.0053509 100 | ``` 101 | To run it yourself from this repository, use a command like the following: 102 | ```bash 103 | python tools/label_image/label_image.py \ 104 | --image=/workspace/zhangbin/dataset_robin/flowers/flower_photos/daisy/21652746_cc379e0eea_m.jpg \ 105 | --graph=/workspace/zhangbin/master/models/research/slim/tmp/frozen_graph.pb \ 106 | --labels=/workspace/zhangbin/dataset_robin/flowers/labels.txt \ 107 | --input_height=224 \ 108 | --input_width=224 \ 109 | --input_layer="input" \ 110 | --output_layer="SqueezeNet/Predictions/Reshape_1" 111 | ```
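For reference, the classification flow inside `label_image.py` can be sketched in a few lines of TF 1.x Python. This is a simplified illustration rather than the script itself (flag parsing and mean/std normalization are omitted), and the file paths and tensor names are just the ones from the example command above:

```python
import numpy as np
import tensorflow as tf

def load_graph(pb_path):
    # Parse the frozen GraphDef and import it into a fresh graph.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
    return graph

def read_image(image_path, height, width):
    # Decode, resize, and batch the image in a throwaway graph/session.
    with tf.Graph().as_default(), tf.Session() as sess:
        image = tf.image.decode_jpeg(tf.read_file(image_path), channels=3)
        image = tf.image.resize_images(tf.cast(image, tf.float32), [height, width])
        return sess.run(tf.expand_dims(image, 0))

graph = load_graph("tmp/frozen_graph.pb")  # illustrative path
x = graph.get_tensor_by_name("input:0")
y = graph.get_tensor_by_name("SqueezeNet/Predictions/Reshape_1:0")
with tf.Session(graph=graph) as sess:
    probs = np.squeeze(sess.run(y, feed_dict={x: read_image("my_image.jpg", 224, 224)}))

labels = [l.strip() for l in tf.gfile.GFile("labels.txt").readlines()]
for i in probs.argsort()[-5:][::-1]:  # top-5 predictions
    print(labels[i], probs[i])
```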
-------------------------------------------------------------------------------- /tools/label_image/data/grace_hopper.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/tools/label_image/data/grace_hopper.jpg -------------------------------------------------------------------------------- /tools/label_load_freeze_graph.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | import argparse 3 | import numpy as np 4 | import tensorflow as tf 5 | import sys 6 | from tensorflow.python.platform import gfile 7 | 8 | def load_graph(frozen_graph_filename): 9 | # We parse the graph_def file 10 | with tf.gfile.GFile(frozen_graph_filename, "rb") as f: 11 | graph_def = tf.GraphDef() 12 | graph_def.ParseFromString(f.read()) 13 | 14 | # We load the graph_def in the default graph 15 | with tf.Graph().as_default() as graph: 16 | tf.import_graph_def( 17 | graph_def, 18 | input_map=None, 19 | return_elements=None, 20 | name="", 21 | op_dict=None, 22 | producer_op_list=None 23 | ) 24 | 25 | writer = tf.summary.FileWriter("./logs_inception_from_freeze_graph", graph=graph) 26 | writer.close() 27 | 28 | return graph 29 | 30 | def load_labels(label_file): 31 | label = [] 32 | proto_as_ascii_lines = tf.gfile.GFile(label_file).readlines() 33 | for l in proto_as_ascii_lines: 34 | label.append(l.rstrip()) 35 | return label 36 | 37 | def main(_): 38 | # Load the frozen graph and return it 39 | graph = load_graph(FLAGS.frozen_model_filename) 40 | 41 | # We can list operations 42 | # op.values() gives you a list of tensors it produces 43 | # op.name gives you the name 44 | for op in graph.get_operations(): 45 | # print(op.name, op.values()) 46 | print(op.name) 47 | # prefix/Placeholder/inputs_placeholder 48 | # ... 49 | # prefix/Accuracy/predictions 50 | # Ops include: prefix/Placeholder/inputs_placeholder 51 | # Ops include: prefix/Accuracy/predictions 52 | # To run inference we must find the tensors to feed and fetch, which requires their names. 53 | # Note: prefix/Placeholder/inputs_placeholder is only the op name; prefix/Placeholder/inputs_placeholder:0 is the tensor name. 54 | x = graph.get_tensor_by_name(FLAGS.input_tensor_name) 55 | y = graph.get_tensor_by_name(FLAGS.output_tensor_name) 56 | 57 | with tf.Session(graph=graph) as sess: 58 | image_data = gfile.FastGFile(FLAGS.image_dir, 'rb').read() 59 | pre = sess.run(y, feed_dict={x:image_data}) 60 | print(pre) 61 | 62 | results = np.squeeze(pre) 63 | classes = load_labels(FLAGS.output_labels) 64 | top_k = results.argsort()[-5:][::-1] 65 | for i in top_k: 66 | print(classes[i], results[i]) 67 | print("finish") 68 | 69 | 70 | if __name__ == '__main__': 71 | parser = argparse.ArgumentParser() 72 | 73 | parser.add_argument( 74 | "--frozen_model_filename", 75 | default="model/freeze_model.pb", 76 | type=str, 77 | help="Frozen model file to import" 78 | ) 79 | 80 | parser.add_argument( 81 | '--output_labels', 82 | type=str, 83 | default='tmp/retrained_labels.txt', 84 | help='Where to load the trained graph\'s labels.' 85 | ) 86 | 87 | parser.add_argument( 88 | '--image_dir', 89 | type=str, 90 | default='flower_data/sunflowers/1022552002_2b93faf9e7_n.jpg', 91 | help='Path to the image file to classify.' 92 | ) 93 | 94 | parser.add_argument( 95 | '--input_tensor_name', 96 | type=str, 97 | default='import/DecodeJpeg/contents:0', 98 | help="""\ 99 | The name of the input tensor to feed with encoded image data.\ 100 | """ 101 | ) 102 | 103 | parser.add_argument( 104 | '--output_tensor_name', 105 | type=str, 106 | default='final_training_ops/Softmax:0', 107 | help="""\ 108 | The name of the output classification layer in the retrained graph.\ 109 | """ 110 | ) 111 | 112 | FLAGS, unparsed = parser.parse_known_args() 113 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 114 | -------------------------------------------------------------------------------- /tools/print_ops_from_pb.py: -------------------------------------------------------------------------------- 1 | # -*- coding: UTF-8 -*- 2 | import tensorflow as tf 3 | import numpy as np 4 | import os 5 | import argparse 6 | 7 | 8 | # print all op names 9 | def print_ops(pb_path, output_layer): 10 | with tf.gfile.FastGFile(os.path.join(pb_path), 'rb') as f: 11 | graph_def = tf.GraphDef() 12 | graph_def.ParseFromString(f.read()) 13 | _ = tf.import_graph_def(graph_def, name='') 14 | 15 | with tf.Session() as sess: 16 | 17 | ops = sess.graph.get_operations() 18 | for op in ops: 19 | print(op.name) 20 | 21 | writer = tf.summary.FileWriter("log_print_ops/", graph=sess.graph) 22 | writer.close() 23 | 24 | graph = tf.get_default_graph() 25 | output = graph.get_tensor_by_name(output_layer) 26 | print(output) 27 | 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser() 32 | 33 | parser.add_argument("--pb_path", default='tmp/frozen_graph.pb', help="Path to the frozen graph (.pb) file") 34 | parser.add_argument("--output_layer", default='MobilenetV1/Predictions/Reshape_1:0', help="Name of the output tensor to inspect") 35 | args = parser.parse_args() 36 | print_ops(args.pb_path, args.output_layer) 37 | -------------------------------------------------------------------------------- /tools/print_selective_registration_header.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Prints a header file to be used with SELECTIVE_REGISTRATION. 16 | 17 | An example of command-line usage is: 18 | bazel build tensorflow/python/tools:print_selective_registration_header && \ 19 | bazel-bin/tensorflow/python/tools/print_selective_registration_header \ 20 | --graphs=path/to/graph.pb > ops_to_register.h 21 | 22 | Then when compiling tensorflow, include ops_to_register.h in the include search 23 | path and pass -DSELECTIVE_REGISTRATION and -DSUPPORT_SELECTIVE_REGISTRATION 24 | - see core/framework/selective_registration.h for more details. 25 | 26 | When compiling for Android: 27 | bazel build -c opt --copt="-DSELECTIVE_REGISTRATION" \ 28 | --copt="-DSUPPORT_SELECTIVE_REGISTRATION" \ 29 | //tensorflow/contrib/android:libtensorflow_inference.so \ 30 | --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ 31 | --crosstool_top=//external:android/crosstool --cpu=armeabi-v7a 32 | """ 33 | 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | 38 | import argparse 39 | import sys 40 | 41 | from tensorflow.python.platform import app 42 | from tensorflow.python.tools import selective_registration_header_lib 43 | 44 | FLAGS = None 45 | 46 | 47 | def main(unused_argv): 48 | graphs = FLAGS.graphs.split(',') 49 | print(selective_registration_header_lib.get_header( 50 | graphs, FLAGS.proto_fileformat, FLAGS.default_ops)) 51 | 52 | 53 | if __name__ == '__main__': 54 | parser = argparse.ArgumentParser() 55 | parser.register('type', 'bool', lambda v: v.lower() == 'true') 56 | parser.add_argument( 57 | '--graphs', 58 | type=str, 59 | default='', 60 | help='Comma-separated list of paths to model files to be analyzed.', 61 | required=True) 62 | parser.add_argument( 63 | '--proto_fileformat', 64 | type=str, 65 | default='rawproto', 66 | help='Format of proto file, either textproto or rawproto.') 67 | parser.add_argument( 68 | '--default_ops', 69 | type=str, 70 | default='NoOp:NoOp,_Recv:RecvOp,_Send:SendOp', 71 | help='Default operator:kernel pairs to always include implementation for.' 72 | 'Pass "all" to have all operators and kernels included; note that this ' 73 | 'should be used only when it is useful compared with simply not using ' 74 | 'selective registration, as it can in some cases limit the effect of ' 75 | 'compilation caches') 76 | 77 | FLAGS, unparsed = parser.parse_known_args() 78 | app.run(main=main, argv=[sys.argv[0]] + unparsed) 79 | -------------------------------------------------------------------------------- /tools/quantization/BUILD: -------------------------------------------------------------------------------- 1 | # Description: 2 | # Utilities for quantizing TensorFlow graphs to lower bit depths. 
3 | 4 | package(default_visibility = ["//visibility:public"]) 5 | 6 | licenses(["notice"]) # Apache 2.0 7 | 8 | exports_files(["LICENSE"]) 9 | 10 | load("//tensorflow:tensorflow.bzl", "py_test") 11 | 12 | py_library( 13 | name = "quantize_graph_lib", 14 | srcs = ["quantize_graph.py"], 15 | srcs_version = "PY2AND3", 16 | deps = [ 17 | "//tensorflow/core:protos_all_py", 18 | "//tensorflow/python:array_ops", 19 | "//tensorflow/python:constant_op", 20 | "//tensorflow/python:dtypes", 21 | "//tensorflow/python:framework", 22 | "//tensorflow/python:framework_ops", 23 | "//tensorflow/python:graph_util", 24 | "//tensorflow/python:platform", 25 | "//tensorflow/python:session", 26 | "//tensorflow/python:tensor_shape", 27 | "//tensorflow/python:tensor_util", 28 | "//third_party/py/numpy", 29 | ], 30 | ) 31 | 32 | py_binary( 33 | name = "quantize_graph", 34 | srcs = ["quantize_graph.py"], 35 | srcs_version = "PY2AND3", 36 | deps = [ 37 | "//tensorflow/core:protos_all_py", 38 | "//tensorflow/python", # TODO(b/34059704): remove when fixed 39 | "//tensorflow/python:array_ops", 40 | "//tensorflow/python:client", 41 | "//tensorflow/python:framework", 42 | "//tensorflow/python:framework_for_generated_wrappers", 43 | "//tensorflow/python:graph_util", 44 | "//tensorflow/python:platform", 45 | "//tensorflow/python:tensor_util", 46 | "//third_party/py/numpy", 47 | ], 48 | ) 49 | 50 | py_test( 51 | name = "quantize_graph_test", 52 | size = "small", 53 | srcs = ["quantize_graph_test.py"], 54 | srcs_version = "PY2AND3", 55 | tags = ["nomsan"], # http://b/32242946 56 | deps = [ 57 | ":quantize_graph", 58 | "//tensorflow/core:protos_all_py", 59 | "//tensorflow/python:client", 60 | "//tensorflow/python:client_testlib", 61 | "//tensorflow/python:framework", 62 | "//tensorflow/python:framework_for_generated_wrappers", 63 | "//tensorflow/python:graph_util", 64 | "//tensorflow/python:platform", 65 | "//third_party/py/numpy", 66 | ], 67 | ) 68 | 69 | py_binary( 70 | name = "graph_to_dot", 71 | srcs = ["graph_to_dot.py"], 72 | main = "graph_to_dot.py", 73 | srcs_version = "PY2AND3", 74 | deps = [ 75 | "//tensorflow/core:protos_all_py", 76 | "//tensorflow/python:platform", 77 | ], 78 | ) 79 | 80 | filegroup( 81 | name = "all_files", 82 | srcs = glob( 83 | ["**/*"], 84 | exclude = [ 85 | "**/METADATA", 86 | "**/OWNERS", 87 | ], 88 | ), 89 | visibility = ["//tensorflow:__subpackages__"], 90 | ) 91 | -------------------------------------------------------------------------------- /tools/quantization/graph_to_dot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Converts a GraphDef file into a DOT format suitable for visualization. 
16 | 17 | This script takes a GraphDef representing a network, and produces a DOT file 18 | that can then be visualized by GraphViz tools like dot and xdot. 19 | 20 | """ 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import re 26 | 27 | from google.protobuf import text_format 28 | 29 | from tensorflow.core.framework import graph_pb2 30 | from tensorflow.python.platform import app 31 | from tensorflow.python.platform import flags 32 | from tensorflow.python.platform import gfile 33 | 34 | FLAGS = flags.FLAGS 35 | 36 | flags.DEFINE_string("graph", "", """TensorFlow 'GraphDef' file to load.""") 37 | flags.DEFINE_bool("input_binary", True, 38 | """Whether the input files are in binary format.""") 39 | flags.DEFINE_string("dot_output", "", """Where to write the DOT output.""") 40 | 41 | 42 | def main(unused_args): 43 | if not gfile.Exists(FLAGS.graph): 44 | print("Input graph file '" + FLAGS.graph + "' does not exist!") 45 | return -1 46 | 47 | graph = graph_pb2.GraphDef() 48 | with open(FLAGS.graph, "rb") as f: 49 | if FLAGS.input_binary: 50 | graph.ParseFromString(f.read()) 51 | else: 52 | text_format.Merge(f.read().decode("utf-8"), graph) 53 | 54 | with open(FLAGS.dot_output, "w") as f: 55 | print("digraph graphname {", file=f) 56 | for node in graph.node: 57 | output_name = node.name 58 | print(" \"" + output_name + "\" [label=\"" + node.op + "\"];", file=f) 59 | for input_full_name in node.input: 60 | parts = input_full_name.split(":") 61 | input_name = re.sub(r"^\^", "", parts[0]) 62 | print(" \"" + input_name + "\" -> \"" + output_name + "\";", file=f) 63 | print("}", file=f) 64 | print("Created DOT file '" + FLAGS.dot_output + "'.") 65 | 66 | 67 | if __name__ == "__main__": 68 | app.run() 69 | 
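As a quick illustration of the conversion above, a tiny hypothetical GraphDef with one `Const` node named `Const` feeding an `Identity` node named `Identity` would come out of graph_to_dot.py looking like this:

```
digraph graphname {
 "Const" [label="Const"];
 "Identity" [label="Identity"];
 "Const" -> "Identity";
}
```

Each node becomes a DOT vertex keyed by its name and labeled with its op type, and each input becomes an edge, so the rendered graph reads in dataflow order.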
-------------------------------------------------------------------------------- /tools/saved_model_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """SavedModel utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from tensorflow.contrib.saved_model.python.saved_model import reader 22 | 23 | 24 | def get_meta_graph_def(saved_model_dir, tag_set): 25 | """Gets MetaGraphDef from SavedModel. 26 | 27 | Returns the MetaGraphDef for the given tag-set and SavedModel directory. 28 | 29 | Args: 30 | saved_model_dir: Directory containing the SavedModel to inspect or execute. 31 | tag_set: Group of tag(s) of the MetaGraphDef to load, in string format, 32 | separated by ','. If the tag-set contains multiple tags, all of them 33 | must be passed in. 34 | 35 | Raises: 36 | RuntimeError: An error when the given tag-set does not exist in the 37 | SavedModel. 38 | 39 | Returns: 40 | A MetaGraphDef corresponding to the tag-set. 41 | """ 42 | saved_model = reader.read_saved_model(saved_model_dir) 43 | set_of_tags = set(tag_set.split(',')) 44 | for meta_graph_def in saved_model.meta_graphs: 45 | if set(meta_graph_def.meta_info_def.tags) == set_of_tags: 46 | return meta_graph_def 47 | 48 | raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set + 49 | ' could not be found in SavedModel') 50 | -------------------------------------------------------------------------------- /tools/strip_unused.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Removes unneeded nodes from a GraphDef file. 16 | 17 | This script is designed to help streamline models, by taking the input and 18 | output nodes that will be used by an application and figuring out the smallest 19 | set of operations that are required to run for those arguments. The resulting 20 | minimal graph is then saved out. 21 | 22 | The advantages of running this script are: 23 | - You may be able to shrink the file size. 24 | - Operations that are unsupported on your platform but still present can be 25 | safely removed. 26 | The resulting graph may not be as flexible as the original though, since any 27 | input nodes that weren't explicitly mentioned may not be accessible any more. 28 | 29 | An example of command-line usage is: 30 | bazel build tensorflow/python/tools:strip_unused && \ 31 | bazel-bin/tensorflow/python/tools/strip_unused \ 32 | --input_graph=some_graph_def.pb \ 33 | --output_graph=/tmp/stripped_graph.pb \ 34 | --input_node_names=input0 \ 35 | --output_node_names=softmax 36 | 37 | You can also look at strip_unused_test.py for an example of how to use it.
38 | 39 | """ 40 | from __future__ import absolute_import 41 | from __future__ import division 42 | from __future__ import print_function 43 | 44 | import argparse 45 | import sys 46 | 47 | from tensorflow.python.framework import dtypes 48 | from tensorflow.python.platform import app 49 | from tensorflow.python.tools import strip_unused_lib 50 | 51 | FLAGS = None 52 | 53 | 54 | def main(unused_args): 55 | strip_unused_lib.strip_unused_from_files(FLAGS.input_graph, 56 | FLAGS.input_binary, 57 | FLAGS.output_graph, 58 | FLAGS.output_binary, 59 | FLAGS.input_node_names, 60 | FLAGS.output_node_names, 61 | FLAGS.placeholder_type_enum) 62 | 63 | 64 | if __name__ == '__main__': 65 | parser = argparse.ArgumentParser() 66 | parser.register('type', 'bool', lambda v: v.lower() == 'true') 67 | parser.add_argument( 68 | '--input_graph', 69 | type=str, 70 | default='', 71 | help='TensorFlow \'GraphDef\' file to load.') 72 | parser.add_argument( 73 | '--input_binary', 74 | nargs='?', 75 | const=True, 76 | type='bool', 77 | default=False, 78 | help='Whether the input files are in binary format.') 79 | parser.add_argument( 80 | '--output_graph', 81 | type=str, 82 | default='', 83 | help='Output \'GraphDef\' file name.') 84 | parser.add_argument( 85 | '--output_binary', 86 | nargs='?', 87 | const=True, 88 | type='bool', 89 | default=True, 90 | help='Whether to write a binary format graph.') 91 | parser.add_argument( 92 | '--input_node_names', 93 | type=str, 94 | default='', 95 | help='The name of the input nodes, comma separated.') 96 | parser.add_argument( 97 | '--output_node_names', 98 | type=str, 99 | default='', 100 | help='The name of the output nodes, comma separated.') 101 | parser.add_argument( 102 | '--placeholder_type_enum', 103 | type=int, 104 | default=dtypes.float32.as_datatype_enum, 105 | help='The AttrValue enum to use for placeholders.') 106 | FLAGS, unparsed = parser.parse_known_args() 107 | app.run(main=main, argv=[sys.argv[0]] + unparsed) 108 | -------------------------------------------------------------------------------- /tools/summary/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Event Processing 2 | 3 | This folder contains classes useful for analyzing and visualizing TensorFlow 4 | events files. The code is primarily being developed to support TensorBoard, 5 | but it can be used by anyone who wishes to analyze or visualize TensorFlow 6 | events files. 7 | 8 | If you wish to load TensorFlow events, you should use an EventAccumulator 9 | (to load from a single events file) or an EventMultiplexer (to load from 10 | multiple events files). 11 | 12 | The API around these tools has not solidified, and we may make backwards- 13 | incompatible changes without warning. 14 | 15 | If you have questions or requests, please contact danmane@google.com 16 | -------------------------------------------------------------------------------- /tools/summary/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Robinatp/Tensorflow_Model_Slim_Classify/4f652fe6937334b39d978f0e7f557bfc85f865ad/tools/summary/__init__.py -------------------------------------------------------------------------------- /tools/summary/plugin_asset_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from tensorflow.python.framework import ops 21 | from tensorflow.python.framework import test_util 22 | from tensorflow.python.platform import googletest 23 | from tensorflow.python.summary import plugin_asset 24 | 25 | 26 | class _UnnamedPluginAsset(plugin_asset.PluginAsset): 27 | """An example asset with a dummy serialize method provided, but no name.""" 28 | 29 | def assets(self): 30 | return {} 31 | 32 | 33 | class _ExamplePluginAsset(_UnnamedPluginAsset): 34 | """Simple example asset.""" 35 | plugin_name = "_ExamplePluginAsset" 36 | 37 | 38 | class _OtherExampleAsset(_UnnamedPluginAsset): 39 | """Simple example asset.""" 40 | plugin_name = "_OtherExampleAsset" 41 | 42 | 43 | class _ExamplePluginThatWillCauseCollision(_UnnamedPluginAsset): 44 | plugin_name = "_ExamplePluginAsset" 45 | 46 | 47 | class PluginAssetTest(test_util.TensorFlowTestCase): 48 | 49 | def testGetPluginAsset(self): 50 | epa = plugin_asset.get_plugin_asset(_ExamplePluginAsset) 51 | self.assertIsInstance(epa, _ExamplePluginAsset) 52 | epa2 = plugin_asset.get_plugin_asset(_ExamplePluginAsset) 53 | self.assertIs(epa, epa2) 54 | opa = plugin_asset.get_plugin_asset(_OtherExampleAsset) 55 | self.assertIsNot(epa, opa) 56 | 57 | def testUnnamedPluginFails(self): 58 | with self.assertRaises(ValueError): 59 | plugin_asset.get_plugin_asset(_UnnamedPluginAsset) 60 | 61 | def testPluginCollisionDetected(self): 62 | plugin_asset.get_plugin_asset(_ExamplePluginAsset) 63 | with self.assertRaises(ValueError): 64 | plugin_asset.get_plugin_asset(_ExamplePluginThatWillCauseCollision) 65 | 66 | def testGetAllPluginAssets(self): 67 | epa = plugin_asset.get_plugin_asset(_ExamplePluginAsset) 68 | opa = plugin_asset.get_plugin_asset(_OtherExampleAsset) 69 | self.assertItemsEqual(plugin_asset.get_all_plugin_assets(), [epa, opa]) 70 | 71 | def testRespectsGraphArgument(self): 72 | g1 = ops.Graph() 73 | g2 = ops.Graph() 74 | e1 = plugin_asset.get_plugin_asset(_ExamplePluginAsset, g1) 75 | e2 = plugin_asset.get_plugin_asset(_ExamplePluginAsset, g2) 76 | 77 | self.assertEqual(e1, plugin_asset.get_all_plugin_assets(g1)[0]) 78 | self.assertEqual(e2, plugin_asset.get_all_plugin_assets(g2)[0]) 79 | 80 | if __name__ == "__main__": 81 | googletest.main() 82 | -------------------------------------------------------------------------------- /tools/summary/summary_iterator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Provides a method for reading events from an event file via an iterator.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | from tensorflow.core.util import event_pb2 23 | from tensorflow.python.lib.io import tf_record 24 | 25 | 26 | def summary_iterator(path): 27 | # pylint: disable=line-too-long 28 | """An iterator for reading `Event` protocol buffers from an event file. 29 | 30 | You can use this function to read events written to an event file. It returns 31 | a Python iterator that yields `Event` protocol buffers. 32 | 33 | Example: Print the contents of an events file. 34 | 35 | ```python 36 | for e in tf.train.summary_iterator(path to events file): 37 | print(e) 38 | ``` 39 | 40 | Example: Print selected summary values. 41 | 42 | ```python 43 | # This example supposes that the events file contains summaries with a 44 | # summary value tag 'loss'. These could have been added by calling 45 | # `add_summary()`, passing the output of a scalar summary op created with 46 | # with: `tf.summary.scalar('loss', loss_tensor)`. 47 | for e in tf.train.summary_iterator(path to events file): 48 | for v in e.summary.value: 49 | if v.tag == 'loss': 50 | print(v.simple_value) 51 | ``` 52 | 53 | See the protocol buffer definitions of 54 | [Event](https://www.tensorflow.org/code/tensorflow/core/util/event.proto) 55 | and 56 | [Summary](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) 57 | for more information about their attributes. 58 | 59 | Args: 60 | path: The path to an event file created by a `SummaryWriter`. 61 | 62 | Yields: 63 | `Event` protocol buffers. 64 | """ 65 | # pylint: enable=line-too-long 66 | for r in tf_record.tf_record_iterator(path): 67 | yield event_pb2.Event.FromString(r) 68 | -------------------------------------------------------------------------------- /tools/summary/text_summary.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implements text_summary in TensorFlow, with TensorBoard support. 
16 | 17 | The text_summary is a wrapper around the generic tensor_summary that takes a 18 | string-type tensor and emits a TensorSummary op with SummaryMetadata that 19 | notes that this summary is textual data for the TensorBoard text plugin. 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | from tensorflow.core.framework import summary_pb2 27 | from tensorflow.python.framework import dtypes 28 | from tensorflow.python.ops.summary_ops import tensor_summary 29 | 30 | PLUGIN_NAME = "text" 31 | 32 | 33 | def text_summary(name, tensor, collections=None): 34 | """Summarizes textual data. 35 | 36 | Text data summarized via this plugin will be visible in the Text Dashboard 37 | in TensorBoard. The standard TensorBoard Text Dashboard will render markdown 38 | in the strings, and will automatically organize 1d and 2d tensors into tables. 39 | If a tensor with more than 2 dimensions is provided, a 2d subarray will be 40 | displayed along with a warning message. (Note that this behavior is not 41 | intrinsic to the text summary api, but rather to the default TensorBoard text 42 | plugin.) 43 | 44 | Args: 45 | name: A name for the generated node. Will also serve as a series name in 46 | TensorBoard. 47 | tensor: a string-type Tensor to summarize. 48 | collections: Optional list of ops.GraphKeys. The collections to add the 49 | summary to. Defaults to [_ops.GraphKeys.SUMMARIES] 50 | 51 | Returns: 52 | A TensorSummary op that is configured so that TensorBoard will recognize 53 | that it contains textual data. The TensorSummary is a scalar `Tensor` of 54 | type `string` which contains `Summary` protobufs. 55 | 56 | Raises: 57 | ValueError: If tensor has the wrong type. 58 | """ 59 | if tensor.dtype != dtypes.string: 60 | raise ValueError("Expected tensor %s to have dtype string, got %s" % 61 | (tensor.name, tensor.dtype)) 62 | 63 | summary_metadata = summary_pb2.SummaryMetadata( 64 | plugin_data=summary_pb2.SummaryMetadata.PluginData( 65 | plugin_name=PLUGIN_NAME)) 66 | t_summary = tensor_summary( 67 | name=name, 68 | tensor=tensor, 69 | summary_metadata=summary_metadata, 70 | collections=collections) 71 | return t_summary 72 | -------------------------------------------------------------------------------- /tools/summary/text_summary_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | from tensorflow.python.framework import test_util 21 | from tensorflow.python.ops import array_ops 22 | from tensorflow.python.platform import googletest 23 | from tensorflow.python.summary import text_summary 24 | 25 | 26 | class TextPluginTest(test_util.TensorFlowTestCase): 27 | """Test the Text Summary API. 28 | 29 | These tests are focused on testing the API design of the text_summary method. 30 | It doesn't test the PluginAsset and tensors registry functionality, because 31 | that is better tested by the text_plugin test that actually consumes that 32 | metadata. 33 | """ 34 | 35 | def testTextSummaryAPI(self): 36 | with self.test_session(): 37 | 38 | with self.assertRaises(ValueError): 39 | num = array_ops.constant(1) 40 | text_summary.text_summary("foo", num) 41 | 42 | # The API accepts vectors. 43 | arr = array_ops.constant(["one", "two", "three"]) 44 | summ = text_summary.text_summary("foo", arr) 45 | self.assertEqual(summ.op.type, "TensorSummaryV2") 46 | 47 | # The API accepts scalars. 48 | summ = text_summary.text_summary("foo", array_ops.constant("one")) 49 | self.assertEqual(summ.op.type, "TensorSummaryV2") 50 | 51 | 52 | if __name__ == "__main__": 53 | googletest.main() 54 | -------------------------------------------------------------------------------- /tools/summary/writer/writer_cache.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A cache for FileWriters.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import threading 22 | 23 | from tensorflow.python.framework import ops 24 | from tensorflow.python.summary.writer.writer import FileWriter 25 | from tensorflow.python.util.tf_export import tf_export 26 | 27 | 28 | @tf_export('summary.FileWriterCache') 29 | class FileWriterCache(object): 30 | """Cache for file writers. 31 | 32 | This class caches file writers, one per directory. 33 | """ 34 | # Cache, keyed by directory. 35 | _cache = {} 36 | 37 | # Lock protecting _cache. 38 | _lock = threading.RLock() 39 | 40 | @staticmethod 41 | def clear(): 42 | """Clear cached summary writers. Currently only used for unit tests.""" 43 | with FileWriterCache._lock: 44 | # Make sure all the writers are closed now (otherwise open file handles 45 | # may hang around, blocking deletions on Windows). 
46 | for item in FileWriterCache._cache.values(): 47 | item.close() 48 | FileWriterCache._cache = {} 49 | 50 | @staticmethod 51 | def get(logdir): 52 | """Returns the FileWriter for the specified directory. 53 | 54 | Args: 55 | logdir: str, name of the directory. 56 | 57 | Returns: 58 | A `FileWriter`. 59 | """ 60 | with FileWriterCache._lock: 61 | if logdir not in FileWriterCache._cache: 62 | FileWriterCache._cache[logdir] = FileWriter( 63 | logdir, graph=ops.get_default_graph()) 64 | return FileWriterCache._cache[logdir] 65 | --------------------------------------------------------------------------------
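A minimal usage sketch for the cache above (the log-directory path is hypothetical; in TF 1.x this class is exported as `tf.summary.FileWriterCache`):

```python
import tensorflow as tf

with tf.Graph().as_default():
    # Writers are cached per directory: both calls return the same object.
    w1 = tf.summary.FileWriterCache.get("/tmp/logs/run1")
    w2 = tf.summary.FileWriterCache.get("/tmp/logs/run1")
    assert w1 is w2

    # clear() closes every cached writer and empties the cache.
    tf.summary.FileWriterCache.clear()
```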