├── __init__.py ├── datasets ├── __init__.py ├── preprocess_imagenet_validation_data.py ├── process_downloaded_imagenet.sh ├── process_bounding_boxes.py ├── LICENSE ├── imagenet_lsvrc_2015_synsets.txt └── build_imagenet_data.py ├── requirements.txt ├── tensorflow_extentions ├── __init__.py └── grouped_convolution.py ├── tabby_cat.jpg ├── tools ├── __init__.py ├── fine_tune.py ├── tools.py └── stats.py ├── preprocess_imagenet.sh ├── configs ├── __init__.py ├── v_1_0_SqNxt_23.py ├── v_2_0_SqNxt_23.py ├── v_1_0_G_SqNxt_23.py ├── v_1_0_SqNxt_23_v5.py ├── v_2_0_SqNxt_23_v5.py └── v_1_0_SqNxt_23_mod.py ├── run_cpu_docker.sh ├── run_tensorboard.sh ├── run_train.sh ├── LICENSE ├── .gitignore ├── optimizer.py ├── predict.py ├── train.py ├── squeezenext_model.py ├── dataloader.py ├── README.md └── squeezenext_architecture.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==0.23.3 2 | numpy==1.11.0 3 | six==1.10.0 4 | -------------------------------------------------------------------------------- /tensorflow_extentions/__init__.py: -------------------------------------------------------------------------------- 1 | from grouped_convolution import grouped_convolution -------------------------------------------------------------------------------- /tabby_cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Timen/squeezenext-tensorflow/HEAD/tabby_cat.jpg -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | from tools import define_first_dim, get_checkpoint_step,get_or_create_global_step,warmup_phase 2 | import stats 3 | import fine_tune -------------------------------------------------------------------------------- /preprocess_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | DATA_DIR="/usr/local/share/Datasets/Imagenet/" 3 | 4 | PYTHONPATH=$PWD bash datasets/process_downloaded_imagenet.sh $DATA_DIR 5 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | configs = {} 3 | for module in os.listdir(os.path.dirname(__file__)): 4 | if module == '__init__.py' or module[-3:] != '.py': 5 | continue 6 | configs[module[:-3]] = __import__(module[:-3], locals(), globals()).training_params 7 | 8 | -------------------------------------------------------------------------------- /run_cpu_docker.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODELS="/usr/local/share/models/" 4 | 5 | docker run -it \ 6 | -v $(pwd):/usr/local/src/ \ 7 | -v $MODELS:$MODELS \ 8 | -v $DATA_DIR:$DATA_DIR \ 9 | tensorflow/tensorflow \ 10 | bash -c "export DATA_DIR="$DATA_DIR"; bash" 11 | 12 | -------------------------------------------------------------------------------- /run_tensorboard.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | MODELS="/usr/local/share/models/" 4 | 5 | docker run -it -p 6006:6006 \ 6 | -v $(pwd):/usr/local/src/ \ 7 | -v $MODELS:$MODELS \ 8 | -v $DATA_DIR:$DATA_DIR \ 9 | tensorflow/tensorflow python -m tensorboard.main --logdir=$MODELS 10 | -------------------------------------------------------------------------------- /run_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | DATE=`date '+%Y-%m-%d_%H-%M'` 3 | TRAIN_PATH="/usr/local/share/models/" 4 | TRAIN_DIR=$TRAIN_PATH$DATE 5 | 6 | if [[ ! -e $DATA_DIR ]]; then 7 | echo "Data dir $DATA_DIR does not exists." 1>&2 8 | exit 1 9 | fi 10 | if [[ ! -e $TRAIN_DIR ]]; then 11 | mkdir $TRAIN_DIR 12 | elif [[ ! -d $TRAIN_DIR ]]; then 13 | echo "Model dir $TRAIN_DIR already exists but is not a directory" 1>&2 14 | fi 15 | 16 | PYTHONPATH="./" python train.py \ 17 | --model_dir $TRAIN_DIR \ 18 | --configuration "v_1_0_SqNxt_23" \ 19 | --batch_size 256 \ 20 | --num_epochs 120 \ 21 | --training_file_pattern $DATA_DIR"tf-records/train-*" \ 22 | --validation_file_pattern $DATA_DIR"tf-records/validation-*" 23 | -------------------------------------------------------------------------------- /tools/fine_tune.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import tensorflow as tf 3 | slim = tf.contrib.slim 4 | 5 | def init_weights(scope_name, path): 6 | if path == None: 7 | return 8 | 9 | # look for checkpoint 10 | model_path = tf.train.latest_checkpoint(path) 11 | initializer_fn = None 12 | 13 | if model_path: 14 | # only restore variables in the scope_name scope 15 | variables_to_restore = slim.get_variables_to_restore(include=[scope_name]) 16 | # Create the saver which will be used to restore the variables. 17 | initializer_fn = slim.assign_from_checkpoint_fn(model_path, variables_to_restore) 18 | else: 19 | print("could not find the fine tune ckpt at {}".format(path)) 20 | exit() 21 | 22 | def InitFn(scaffold,sess): 23 | initializer_fn(sess) 24 | return InitFn -------------------------------------------------------------------------------- /configs/v_1_0_SqNxt_23.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | training_params = { 4 | # the base learning rate used in the polynomial decay 5 | "base_lr":0.1, 6 | 7 | # how many steps to warmup the learning rate for 8 | "warmup_iter":3120, 9 | 10 | # What learning rate to start with in the warmup phase (ramps up to base_lr) 11 | "warmup_start_lr":0.025, 12 | 13 | #input size 14 | "image_size":227, 15 | 16 | # Block defs each tuple(x,y,z) describes one block with x number of filters at it's largest depth 17 | # y number of repeated units or bottlenecks, z stride for the first unit of the block. 
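    # Worked example (illustrative, derived from the values in this file): the first tuple (32,6,1) below
    # expands to six stacked SqueezeNext units that all output 32 channels, with only the first unit of the
    # block applying the stride; on a 227x227 input the 7x7/2 input convolution and 3x3/2 max pool leave a
    # roughly 55x55 feature map, which the three stride-2 blocks then reduce to about 28, 14 and 7.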
18 | "block_defs":[(32,6,1),(64,6,2),(128,8,2),(256,1,2)], 19 | 20 | # definition of filters, kernel size and stride of the input convolution 21 | "input_def":(64,(7,7),2), 22 | 23 | # number of output classes 24 | "num_classes":1000, 25 | 26 | # How many groups to use for the grouped convolutions 27 | "groups": 1, 28 | 29 | # Whether to do relu before addition of the network and the residual 30 | "seperate_relus": 1 31 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tijmen Verhulsdonck 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/v_2_0_SqNxt_23.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | # Some parameters will be divided or multiplied by 4 to compensate for the reduced batch size 4 | # as the original used batch size 1024 and the gtx1080ti can only fit a batch size 256. 5 | 6 | training_params = { 7 | # the base learning rate used in the polynomial decay 8 | "base_lr":0.4, 9 | 10 | # how many steps to warmup the learning rate for 11 | "warmup_iter":780, 12 | 13 | # What learning rate to start with in the warmup phase (ramps up to base_lr) 14 | "warmup_start_lr":0.1, 15 | 16 | #input size 17 | "image_size":227, 18 | 19 | # Block defs each tuple(x,y,z) describes one block with x number of filters at it's largest depth 20 | # y number of repeated units or bottlenecks, z stride for the first unit of the block. 
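    # Same layout as v_1_0_SqNxt_23, but with every block's output depth doubled (64/128/256/512 instead of 32/64/128/256).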
21 | "block_defs":[(64,6,1),(128,6,2),(256,8,2),(512,1,2)], 22 | 23 | # definition of filters, kernel size and stride of the input convolution 24 | "input_def":(64,(7,7),2), 25 | 26 | # number of output classes 27 | "num_classes":1000, 28 | 29 | # How many groups to use for the grouped convolutions 30 | "groups": 1, 31 | 32 | # Whether to do relu before addition of the network and the residual 33 | "seperate_relus": 1 34 | } -------------------------------------------------------------------------------- /configs/v_1_0_G_SqNxt_23.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | # Some parameters will be divided or multiplied by 4 to compensate for the reduced batch size 4 | # as the original used batch size 1024 and the gtx1080ti can only fit a batch size 256. 5 | 6 | training_params = { 7 | # the base learning rate used in the polynomial decay 8 | "base_lr":0.4, 9 | 10 | # how many steps to warmup the learning rate for 11 | "warmup_iter":780, 12 | 13 | # What learning rate to start with in the warmup phase (ramps up to base_lr) 14 | "warmup_start_lr":0.1, 15 | 16 | #input size 17 | "image_size":227, 18 | 19 | # Block defs each tuple(x,y,z) describes one block with x number of filters at it's largest depth 20 | # y number of repeated units or bottlenecks, z stride for the first unit of the block. 21 | "block_defs":[(32,6,1),(64,6,2),(128,8,2),(256,1,2)], 22 | 23 | # definition of filters, kernel size and stride of the input convolution 24 | "input_def":(64,(7,7),2), 25 | 26 | # number of output classes 27 | "num_classes":1000, 28 | 29 | # How many groups to use for the grouped convolutions 30 | "groups": 2, 31 | 32 | # Whether to do relu before addition of the network and the residual 33 | "seperate_relus": 1 34 | } -------------------------------------------------------------------------------- /configs/v_1_0_SqNxt_23_v5.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | # Some parameters will be divided or multiplied by 4 to compensate for the reduced batch size 4 | # as the original used batch size 1024 and the gtx1080ti can only fit a batch size 256. 5 | 6 | training_params = { 7 | # the base learning rate used in the polynomial decay 8 | "base_lr":0.4, 9 | 10 | # how many steps to warmup the learning rate for 11 | "warmup_iter":780, 12 | 13 | # What learning rate to start with in the warmup phase (ramps up to base_lr) 14 | "warmup_start_lr":0.1, 15 | 16 | #input size 17 | "image_size":227, 18 | 19 | # Block defs each tuple(x,y,z) describes one block with x number of filters at it's largest depth 20 | # y number of repeated units or bottlenecks, z stride for the first unit of the block. 
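    # The v5 variant keeps the v1.0 depths but redistributes the units across the blocks (2/4/14/1 rather
    # than 6/6/8/1) and pairs this with the 5x5 input convolution defined below.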
21 | "block_defs":[(32,2,1),(64,4,2),(128,14,2),(256,1,2)], 22 | 23 | # definition of filters, kernel size and stride of the input convolution 24 | "input_def":(64,(5,5),2), 25 | 26 | # number of output classes 27 | "num_classes":1000, 28 | 29 | # How many groups to use for the grouped convolutions 30 | "groups": 1, 31 | 32 | # Whether to do relu before addition of the network and the residual 33 | "seperate_relus": 1 34 | } -------------------------------------------------------------------------------- /configs/v_2_0_SqNxt_23_v5.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | # Some parameters will be divided or multiplied by 4 to compensate for the reduced batch size 4 | # as the original used batch size 1024 and the gtx1080ti can only fit a batch size 256. 5 | 6 | training_params = { 7 | # the base learning rate used in the polynomial decay 8 | "base_lr":0.4, 9 | 10 | # how many steps to warmup the learning rate for 11 | "warmup_iter":780, 12 | 13 | # What learning rate to start with in the warmup phase (ramps up to base_lr) 14 | "warmup_start_lr":0.1, 15 | 16 | #input size 17 | "image_size":227, 18 | 19 | # Block defs each tuple(x,y,z) describes one block with x number of filters at it's largest depth 20 | # y number of repeated units or bottlenecks, z stride for the first unit of the block. 21 | "block_defs":[(64,2,1),(128,4,2),(256,14,2),(512,1,2)], 22 | 23 | # definition of filters, kernel size and stride of the input convolution 24 | "input_def":(64,(5,5),2), 25 | 26 | # number of output classes 27 | "num_classes":1000, 28 | 29 | # How many groups to use for the grouped convolutions 30 | "groups": 1, 31 | 32 | # Whether to do relu before addition of the network and the residual 33 | "seperate_relus": 1 34 | } -------------------------------------------------------------------------------- /configs/v_1_0_SqNxt_23_mod.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | # Some parameters will be divided or multiplied by 4 to compensate for the reduced batch size 4 | # as the original used batch size 1024 and the gtx1080ti can only fit a batch size 256. 5 | 6 | training_params = { 7 | # the base learning rate used in the polynomial decay 8 | "base_lr":0.4/4.0, 9 | 10 | # how many steps to warmup the learning rate for 11 | "warmup_iter":780*4, 12 | 13 | # What learning rate to start with in the warmup phase (ramps up to base_lr) 14 | "warmup_start_lr":0.1/4.0, 15 | 16 | #input size 17 | "image_size":227, 18 | 19 | # Block defs each tuple(x,y,z) describes one block with x number of filters at it's largest depth 20 | # y number of repeated units or bottlenecks, z stride for the first unit of the block. 
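    # Identical block layout to v_1_0_SqNxt_23; this "mod" config only rescales the schedule above for
    # batch size 256 (0.4/4 = 0.1, 780*4 = 3120, 0.1/4 = 0.025, the values used in v_1_0_SqNxt_23) and
    # disables the separate pre-addition relus below.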
21 | "block_defs":[(32,6,1),(64,6,2),(128,8,2),(256,1,2)], 22 | 23 | # definition of filters, kernel size and stride of the input convolution 24 | "input_def":(64,(7,7),2), 25 | 26 | # number of output classes 27 | "num_classes":1000, 28 | 29 | # How many groups to use for the grouped convolutions 30 | "groups": 1, 31 | 32 | # Whether to do relu before addition of the network and the residual 33 | "seperate_relus": 0 34 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # idea 107 | .idea/ 108 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import tensorflow as tf 4 | import tools 5 | slim = tf.contrib.slim 6 | 7 | 8 | 9 | class PolyOptimizer(object): 10 | def __init__(self, training_params): 11 | self.base_lr = training_params["base_lr"] 12 | self.warmup_steps = training_params["warmup_iter"] 13 | self.warmup_learning_rate = training_params["warmup_start_lr"] 14 | self.power = 2.0 15 | self.momentum = 0.9 16 | 17 | def optimize(self,loss, training,total_steps): 18 | """ 19 | Momentum optimizer using a polynomial decay and a warmup phas to match this 20 | prototxt: https://github.com/amirgholami/SqueezeNext/blob/master/1.0-SqNxt-23/solver.prototxt 21 | :param loss: 22 | Loss value scalar 23 | :param training: 24 | Whether or not the model is training used to prevent updating moving mean of batch norm during eval 25 | :param total_steps: 26 | Total steps of the model used in the polynomial decay 27 | :return: 28 | Train op created with slim.learning.create_train_op 29 | """ 30 | with 
tf.name_scope("PolyOptimizer"): 31 | global_step = tools.get_or_create_global_step() 32 | 33 | learning_rate_schedule = tf.train.polynomial_decay( 34 | learning_rate=self.base_lr, 35 | global_step=global_step, 36 | decay_steps=total_steps, 37 | power=self.power 38 | ) 39 | learning_rate_schedule = tools.warmup_phase(learning_rate_schedule,self.base_lr, self.warmup_steps,self.warmup_learning_rate) 40 | tf.summary.scalar("learning_rate",learning_rate_schedule) 41 | optimizer = tf.train.MomentumOptimizer(learning_rate_schedule,self.momentum) 42 | return slim.learning.create_train_op(loss, 43 | optimizer, 44 | global_step=global_step, 45 | aggregation_method=tf.AggregationMethod.ADD_N, 46 | update_ops=None if training else []) 47 | 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /tools/tools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import tensorflow as tf 4 | import os 5 | 6 | def define_first_dim(tensor_dict,dim_size): 7 | """ 8 | Define the first dim keeping the remaining dims the same 9 | :param tensor_dict: 10 | Dictionary of tensors 11 | :param dim_size: 12 | Size of first dimention 13 | :return: 14 | Dictionary of dimensions with the first dim defined as dim_size 15 | """ 16 | for key, tensor in tensor_dict.iteritems(): 17 | shape = tensor.get_shape().as_list()[1:] 18 | tensor_dict[key] = tf.reshape(tensor, [dim_size] + shape) 19 | return tensor_dict 20 | 21 | def get_checkpoint_step(checkpoint_dir): 22 | """ 23 | Get step at which checkpoint was saved from file name 24 | :param checkpoint_dir: 25 | Directory containing a checkpoint 26 | :return: 27 | Step at which checkpoint was saved 28 | """ 29 | ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 30 | if ckpt is None: 31 | return None 32 | else: 33 | return int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1]) 34 | 35 | 36 | def get_or_create_global_step(): 37 | """ 38 | Checks if global step variable exists otherwise creates it 39 | :return: 40 | Global step tensor 41 | """ 42 | global_step = tf.train.get_global_step() 43 | if global_step is None: 44 | global_step = tf.train.create_global_step() 45 | return global_step 46 | 47 | def warmup_phase(learning_rate_schedule,base_lr,warmup_steps,warmup_learning_rate): 48 | """ 49 | Ramps up the learning rate from warmup_learning_rate till base_lr in warmup_steps before 50 | switching to learning_rate_schedule. 51 | The warmup is linear and calculated using the below functions. 
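    As a concrete illustration, with the values from configs/v_1_0_SqNxt_23.py (warmup_start_lr 0.025,
    base_lr 0.1, warmup_iter 3120), these formulas give a warmup rate of roughly 0.0625 halfway through
    the warmup at step 1560 (the implementation actually uses the decayed schedule value in place of
    base_lr, so the exact number differs slightly).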
52 | slope = (base_lr - warmup_learning_rate) / warmup_steps 53 | warmup_rate = slope * global_step + warmup_learning_rate 54 | 55 | :param learning_rate_schedule: 56 | A regular learning rate schedule such as stepwise,exponential decay etc 57 | :param base_lr: 58 | The learning rate to which to ramp up to 59 | :param warmup_steps: 60 | The number of steps of the warmup phase 61 | :param warmup_learning_rate: 62 | The learning rate from which to start ramping up to base_lr 63 | :return: 64 | Warmup learning rate for global step < warmup_steps else returns learning_rate_schedule 65 | """ 66 | with tf.name_scope("warmup_learning_rate"): 67 | global_step = tf.cast(get_or_create_global_step(),tf.float32) 68 | if warmup_steps > 0: 69 | if base_lr < warmup_learning_rate: 70 | raise ValueError('learning_rate_base must be larger or equal to ' 71 | 'warmup_learning_rate.') 72 | slope = (learning_rate_schedule - warmup_learning_rate) / warmup_steps 73 | warmup_rate = slope * global_step + warmup_learning_rate 74 | learning_rate_schedule = tf.where(global_step < warmup_steps, warmup_rate, 75 | learning_rate_schedule) 76 | return learning_rate_schedule 77 | -------------------------------------------------------------------------------- /datasets/preprocess_imagenet_validation_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | r"""Process the ImageNet Challenge bounding boxes for TensorFlow model training. 17 | 18 | Associate the ImageNet 2012 Challenge validation data set with labels. 19 | 20 | The raw ImageNet validation data set is expected to reside in JPEG files 21 | located in the following directory structure. 22 | 23 | data_dir/ILSVRC2012_val_00000001.JPEG 24 | data_dir/ILSVRC2012_val_00000002.JPEG 25 | ... 26 | data_dir/ILSVRC2012_val_00050000.JPEG 27 | 28 | This script moves the files into a directory structure like such: 29 | data_dir/n01440764/ILSVRC2012_val_00000293.JPEG 30 | data_dir/n01440764/ILSVRC2012_val_00000543.JPEG 31 | ... 32 | where 'n01440764' is the unique synset label associated with 33 | these images. 34 | 35 | This directory reorganization requires a mapping from validation image 36 | number (i.e. suffix of the original file) to the associated label. This 37 | is provided in the ImageNet development kit via a Matlab file. 38 | 39 | In order to make life easier and divorce ourselves from Matlab, we instead 40 | supply a custom text file that provides this mapping for us. 
41 | 42 | Sample usage: 43 | ./preprocess_imagenet_validation_data.py ILSVRC2012_img_val \ 44 | imagenet_2012_validation_synset_labels.txt 45 | """ 46 | 47 | from __future__ import absolute_import 48 | from __future__ import division 49 | from __future__ import print_function 50 | 51 | import os 52 | import sys 53 | 54 | from six.moves import xrange # pylint: disable=redefined-builtin 55 | 56 | 57 | if __name__ == '__main__': 58 | if len(sys.argv) < 3: 59 | print('Invalid usage\n' 60 | 'usage: preprocess_imagenet_validation_data.py ' 61 | ' ') 62 | sys.exit(-1) 63 | data_dir = sys.argv[1] 64 | validation_labels_file = sys.argv[2] 65 | 66 | # Read in the 50000 synsets associated with the validation data set. 67 | labels = [l.strip() for l in open(validation_labels_file).readlines()] 68 | unique_labels = set(labels) 69 | 70 | # Make all sub-directories in the validation data dir. 71 | for label in unique_labels: 72 | labeled_data_dir = os.path.join(data_dir, label) 73 | os.makedirs(labeled_data_dir) 74 | 75 | # Move all of the image to the appropriate sub-directory. 76 | for i in xrange(len(labels)): 77 | basename = 'ILSVRC2012_val_000%.5d.JPEG' % (i + 1) 78 | original_filename = os.path.join(data_dir, basename) 79 | if not os.path.exists(original_filename): 80 | print('Failed to find: ', original_filename) 81 | sys.exit(-1) 82 | new_filename = os.path.join(data_dir, labels[i], basename) 83 | os.rename(original_filename, new_filename) -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import tensorflow as tf 3 | from configs import configs 4 | from squeezenext_model import Model 5 | from scipy import misc 6 | import numpy as np 7 | import scipy 8 | import argparse 9 | from datasets.build_imagenet_data import _build_synset_lookup 10 | 11 | parser = argparse.ArgumentParser(description='Process some integers.') 12 | parser.add_argument('image_path', type=str, 13 | help='Location of eval jpeg image') 14 | parser.add_argument('--model_dir', type=str, required=True, 15 | help='Location of model_dir') 16 | parser.add_argument('--configuration', type=str, default="v_1_0_SqNxt_23_mod", 17 | help='Name of model config file') 18 | parser.add_argument('--imagenet_metadata_file', type=str, default="./datasets/imagenet_metadata.txt", 19 | help='Path to metadata file') 20 | parser.add_argument('--labels_file', type=str, default="./datasets/imagenet_lsvrc_2015_synsets.txt", 21 | help='Path to labels file') 22 | args = parser.parse_args() 23 | 24 | 25 | def lookup_human_readable(res,synset,lookup_table): 26 | return lookup_table[synset[res]] 27 | 28 | def main(argv): 29 | """ 30 | Main function to start training 31 | :param argv: 32 | not used 33 | :return: 34 | None 35 | """ 36 | del(argv) 37 | 38 | # setup config dictionary 39 | config = configs[args.configuration] 40 | config["model_dir"] = args.model_dir 41 | config["output_train_images"] = False 42 | config["total_steps"] = 1 43 | config["fine_tune_ckpt"] = None 44 | 45 | # init model class 46 | model = Model(config,1) 47 | 48 | # create classifier 49 | classifier = tf.estimator.Estimator( 50 | model_dir=args.model_dir, 51 | model_fn=model.model_fn, 52 | params=config) 53 | 54 | # read image 55 | image = misc.imread(args.image_path) 56 | 57 | # resize to caffe standard size 58 | resized = scipy.misc.imresize(image, (256, 256, 3)) # 59 | 60 | #center crop 61 | crop_min = 
abs(config["image_size"] / 2 - (config["image_size"] / 2)) 62 | crop_max = crop_min + config["image_size"] 63 | image = resized[crop_min:crop_max, crop_min:crop_max, :] 64 | 65 | #subtract imagenet mean 66 | mean_sub = image.astype(np.float32) - np.array([123, 117, 104]).astype(np.float32) 67 | image = np.expand_dims(np.array(mean_sub), 0) 68 | my_input_fn = tf.estimator.inputs.numpy_input_fn( 69 | x={"image": image}, 70 | shuffle=False, 71 | batch_size=1) 72 | 73 | # setup synset lookup table for human readable labels 74 | lookup_table = _build_synset_lookup(args.imagenet_metadata_file) 75 | challenge_synsets = [l.strip() for l in 76 | tf.gfile.FastGFile(args.labels_file, 'r').readlines()] 77 | 78 | # perform prediction 79 | predictions = classifier.predict(input_fn=my_input_fn) 80 | 81 | # Print top 5 results 82 | for result in predictions: 83 | print("top 5: \n 1: {} \n 2: {} \n 3: {} \n 4: {} \n 5: {} \n".format( 84 | lookup_human_readable(result["top_5"][0], challenge_synsets, lookup_table), 85 | lookup_human_readable(result["top_5"][1], challenge_synsets, lookup_table), 86 | lookup_human_readable(result["top_5"][2], challenge_synsets, lookup_table), 87 | lookup_human_readable(result["top_5"][3], challenge_synsets, lookup_table), 88 | lookup_human_readable(result["top_5"][4], challenge_synsets, lookup_table) 89 | )) 90 | 91 | 92 | if __name__ == '__main__': 93 | tf.logging.set_verbosity(tf.logging.INFO) 94 | tf.app.run(main) 95 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import tensorflow as tf 3 | from configs import configs 4 | from squeezenext_model import Model 5 | import argparse 6 | import numpy as np 7 | import tools 8 | tf.logging.set_verbosity(tf.logging.INFO) 9 | 10 | parser = argparse.ArgumentParser(description='Training parser') 11 | parser.add_argument('--model_dir', type=str, required=True, 12 | help='Location of model_dir') 13 | parser.add_argument('--configuration', type=str, default="v_1_0_SqNxt_23", 14 | help='Name of model config file') 15 | parser.add_argument('--batch_size', type=int, default=64, 16 | help='Batch size during training') 17 | parser.add_argument('--num_examples_per_epoch', type=int, default=1281167, 18 | help='Number of examples in one epoch') 19 | parser.add_argument('--num_eval_examples', type=int, default=50000, 20 | help='Number of examples in one eval epoch') 21 | parser.add_argument('--num_epochs', type=int, default=120, 22 | help='Number of epochs for training') 23 | parser.add_argument('--training_file_pattern', type=str, required=True, 24 | help='Glob for training tf records') 25 | parser.add_argument('--validation_file_pattern', type=str, required=True, 26 | help='Glob for validation tf records') 27 | parser.add_argument('--eval_every_n_secs', type=int, default=1800, 28 | help='Run eval every N seconds, default is every half hour') 29 | parser.add_argument('--output_train_images', type=bool, default=True, 30 | help='Whether to save image summary during training (Warning: can lead to large event file sizes).') 31 | parser.add_argument('--fine_tune_ckpt', type=str, default=None, 32 | help='Ckpt used for initializing the variables') 33 | args = parser.parse_args() 34 | 35 | 36 | def main(argv): 37 | """ 38 | Main function to start training 39 | :param argv: 40 | not used 41 | :return: 42 | None 43 | """ 44 | del argv # not used 45 | 46 | # calculate steps per epoch 47 
| steps_per_epoch = (args.num_examples_per_epoch / args.batch_size) 48 | 49 | # setup config dictionary 50 | config = configs[args.configuration] 51 | config["model_dir"] = args.model_dir 52 | config["output_train_images"] = args.output_train_images 53 | config["total_steps"] = args.num_epochs * steps_per_epoch 54 | config["fine_tune_ckpt"] = args.fine_tune_ckpt 55 | # init model class 56 | model = Model(config, args.batch_size) 57 | 58 | # create classifier 59 | classifier = tf.estimator.Estimator( 60 | model_dir=args.model_dir, 61 | model_fn=model.model_fn, 62 | params=config) 63 | tf.logging.info("Total steps = {}, num_epochs = {}, batch size = {}".format(config["total_steps"], args.num_epochs, 64 | args.batch_size)) 65 | 66 | # setup train spec 67 | train_spec = tf.estimator.TrainSpec(input_fn=lambda: model.input_fn(args.training_file_pattern, True), 68 | max_steps=config["total_steps"]) 69 | 70 | # setup eval spec evaluating ever n seconds 71 | eval_spec = tf.estimator.EvalSpec( 72 | input_fn=lambda: model.input_fn(args.validation_file_pattern, False), 73 | steps=args.num_eval_examples / args.batch_size, 74 | throttle_secs=args.eval_every_n_secs) 75 | 76 | # run train and evaluate 77 | tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec) 78 | 79 | classifier.evaluate(input_fn=lambda: model.input_fn(args.validation_file_pattern, False), 80 | steps=args.num_eval_examples / args.batch_size) 81 | 82 | 83 | 84 | 85 | if __name__ == '__main__': 86 | tf.logging.set_verbosity(tf.logging.INFO) 87 | tf.app.run(main) 88 | -------------------------------------------------------------------------------- /datasets/process_downloaded_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This file is a modified version of the script found at: 4 | # https://github.com/tensorflow/models/blob/master/research/slim/datasets/download_and_convert_imagenet.sh 5 | # The modifications include removal of imagenet data downloads and different paths used for extraction and 6 | # output. 7 | 8 | 9 | # Copyright 2016 Google Inc. All Rights Reserved. 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License. 22 | # ============================================================================== 23 | 24 | # usage: 25 | # ./process_downloaded_imagenet.sh [OUTDIR] 26 | set -e 27 | 28 | 29 | WORK_DIR=$PYTHONPATH 30 | LABELS_FILE="${WORK_DIR}/datasets/imagenet_lsvrc_2015_synsets.txt" 31 | OUTDIR="${1:-./}" 32 | 33 | BBOX_DIR="${OUTDIR}bounding_boxes" 34 | mkdir -p "${BBOX_DIR}" 35 | cd "${OUTDIR}" 36 | 37 | # See here for details: http://www.image-net.org/download-bboxes 38 | BBOX_TAR_BALL="ILSVRC2012_bbox_train_v2.tar.gz" 39 | echo "Uncompressing bounding box annotations ..." 40 | tar xzf "${BBOX_TAR_BALL}" -C "${BBOX_DIR}" 41 | 42 | LABELS_ANNOTATED="${BBOX_DIR}/*" 43 | NUM_XML=$(ls -1 ${LABELS_ANNOTATED} | wc -l) 44 | echo "Identified ${NUM_XML} bounding box annotations." 
45 | 46 | # Convert the XML files for bounding box annotations into a single CSV. 47 | echo "Extracting bounding box information from XML." 48 | BOUNDING_BOX_SCRIPT="${WORK_DIR}/datasets/process_bounding_boxes.py" 49 | BOUNDING_BOX_FILE="${OUTDIR}/imagenet_2012_bounding_boxes.csv" 50 | BOUNDING_BOX_DIR="${OUTDIR}bounding_boxes/" 51 | "${BOUNDING_BOX_SCRIPT}" "${BOUNDING_BOX_DIR}" "${LABELS_FILE}" \ 52 | | sort >"${BOUNDING_BOX_FILE}" 53 | 54 | # Uncompress all images from the ImageNet 2012 validation dataset. 55 | VALIDATION_TARBALL="ILSVRC2012_img_val.tar" 56 | OUTPUT_PATH="${OUTDIR}validation/" 57 | mkdir -p "${OUTPUT_PATH}" 58 | tar xf "${VALIDATION_TARBALL}" -C "${OUTPUT_PATH}" 59 | 60 | # Umcompress all images from the ImageNet 2012 train dataset. 61 | TRAIN_TARBALL="ILSVRC2012_img_train.tar" 62 | OUTPUT_PATH="${OUTDIR}train/" 63 | mkdir -p "${OUTPUT_PATH}" 64 | tar xf "${TRAIN_TARBALL}" -C "${OUTPUT_PATH}" 65 | 66 | # Un-compress the individual tar-files within the train tar-file. 67 | echo "Uncompressing individual train tar-balls in the training data." 68 | # 69 | while read SYNSET; do 70 | echo "Processing: ${SYNSET}" 71 | 72 | # Create a directory and delete anything there. 73 | mkdir -p "${OUTPUT_PATH}/${SYNSET}" 74 | rm -rf "${OUTPUT_PATH}/${SYNSET}/*" 75 | 76 | # Uncompress into the directory. 77 | tar xf "${TRAIN_TARBALL}" "${SYNSET}.tar" 78 | tar xf "${SYNSET}.tar" -C "${OUTPUT_PATH}/${SYNSET}/" 79 | rm -f "${SYNSET}.tar" 80 | 81 | echo "Finished processing: ${SYNSET}" 82 | done < "${LABELS_FILE}" 83 | 84 | # Note the locations of the train and validation data. 85 | TRAIN_DIRECTORY="${OUTDIR}train/" 86 | VALIDATION_DIRECTORY="${OUTDIR}validation/" 87 | 88 | 89 | # Preprocess the validation data by moving the images into the appropriate 90 | # sub-directory based on the label (synset) of the image. 91 | echo "Organizing the validation data into sub-directories." 92 | PREPROCESS_VAL_SCRIPT="${WORK_DIR}/datasets/preprocess_imagenet_validation_data.py" 93 | VAL_LABELS_FILE="${WORK_DIR}/datasets/imagenet_2012_validation_synset_labels.txt" 94 | 95 | "${PREPROCESS_VAL_SCRIPT}" "${VALIDATION_DIRECTORY}" "${VAL_LABELS_FILE}" 96 | 97 | echo "Finished downloading and preprocessing the ImageNet data." 98 | 99 | # Build the TFRecords version of the ImageNet data. 
100 | BUILD_SCRIPT="${WORK_DIR}/datasets/build_imagenet_data.py" 101 | OUTPUT_DIRECTORY="${OUTDIR}/tf-records/" 102 | IMAGENET_METADATA_FILE="${WORK_DIR}/datasets/imagenet_metadata.txt" 103 | 104 | "${BUILD_SCRIPT}" \ 105 | --train_directory="${TRAIN_DIRECTORY}" \ 106 | --validation_directory="${VALIDATION_DIRECTORY}" \ 107 | --output_directory="${OUTPUT_DIRECTORY}" \ 108 | --imagenet_metadata_file="${IMAGENET_METADATA_FILE}" \ 109 | --labels_file="${LABELS_FILE}" \ 110 | --bounding_box_file="${BOUNDING_BOX_FILE}" -------------------------------------------------------------------------------- /tools/stats.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import tensorflow as tf 4 | from tensorflow.python.framework import ops 5 | from collections import defaultdict 6 | import pandas as pd 7 | import os 8 | supported_stat_ops = "Conv2D, MatMul, VariableV2, MaxPool,AvgPool,Add" 9 | exclude_in_name = ["gradients", "Initializer", "Regularizer", "AssignMovingAvg", "Momentum", "BatchNorm"] 10 | 11 | 12 | class ModelStats(tf.train.SessionRunHook): 13 | """Logs model stats to a csv.""" 14 | 15 | def __init__(self, scope_name, path,batch_size): 16 | """ 17 | Set class variables 18 | :param scope_name: 19 | Used to filter for tensors which name contain that specific variable scope 20 | :param path: 21 | path to model dir 22 | :param batch_size: 23 | batch size during training 24 | """ 25 | self.scope_name = scope_name 26 | self.batch_size = batch_size 27 | self.path = path 28 | self.inc_bef =0 29 | self.inc_after = 0 30 | 31 | def begin(self): 32 | """ 33 | Method to output statistics of the model to an easy to read csv, listing the multiply accumulates(maccs) and 34 | number of parameters, in the model dir. 
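        A macc (multiply-accumulate) counts one multiply plus one add, which is why the Conv2D/MatMul
        branch below halves the flop count reported by TensorFlow and then divides by the batch size
        to obtain per-image numbers.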
35 | :param session: 36 | Tensorflow session 37 | :param coord: 38 | unused 39 | """ 40 | # get graph and operations 41 | graph = tf.get_default_graph() 42 | operations = graph.get_operations() 43 | # setup dictionaries 44 | biases = defaultdict(lambda: None) 45 | stat_dict = defaultdict(lambda: {"params":0,"maccs":0,"adds":0, "comps":0}) 46 | 47 | # iterate over tensors 48 | for tensor in operations: 49 | name = tensor.name 50 | # check is scope_name is in name, or any of the excluded strings 51 | if not self.scope_name in name or any(exclude_name in name for exclude_name in exclude_in_name): 52 | continue 53 | # Check if type is considered for the param and macc calcualtion 54 | if not tensor.type in supported_stat_ops: 55 | continue 56 | 57 | base_name = "/".join(name.split("/")[:-1]) 58 | 59 | if name.endswith("weights"): 60 | shape = tensor.node_def.attr["shape"].shape.dim 61 | sizes = [int(size.size) for size in shape] 62 | if any(base_name + "/BatchNorm" in operation.name for operation in operations) or any( 63 | base_name + "/biases" in operation.name for operation in operations): 64 | biases[base_name] = int(sizes[-1]) 65 | params = 1 66 | for dim in sizes: 67 | params = params * dim 68 | 69 | if biases[base_name] is not None: 70 | params = params + biases[base_name] 71 | stat_dict[base_name]["params"] = params 72 | elif tensor.type == "Add": 73 | flops = ops.get_stats_for_node_def(graph, tensor.node_def, 'flops').value 74 | if flops is not None: 75 | stat_dict[name]["adds"] = flops / self.batch_size 76 | elif tensor.type == "MaxPool": 77 | flops = ops.get_stats_for_node_def(graph, tensor.node_def, 'comps').value 78 | if flops is not None: 79 | stat_dict[name]["comps"] = flops / self.batch_size 80 | elif tensor.type == "AvgPool": 81 | flops = ops.get_stats_for_node_def(graph, tensor.node_def, 'flops').value 82 | if flops is not None: 83 | stat_dict[name]["adds"] = flops / self.batch_size 84 | elif tensor.type == "MatMul" or tensor.type == "Conv2D": 85 | flops = ops.get_stats_for_node_def(graph, tensor.node_def, 'flops').value 86 | if flops is not None: 87 | stat_dict[base_name]["maccs"] += int(flops / 2 / self.batch_size) 88 | elif name.endswith("biases"): 89 | pass 90 | else: 91 | print(name,tensor.type) 92 | exit() 93 | total_params = 0 94 | total_maccs = 0 95 | total_comps = 0 96 | total_adds = 0 97 | for key,stat in stat_dict.iteritems(): 98 | total_maccs += stat["maccs"] 99 | total_params += stat["params"] 100 | total_adds += stat["adds"] 101 | total_comps += stat["comps"] 102 | stat_dict["total"] = {"maccs":total_maccs,"params":total_params, "adds":total_adds, "comps":total_comps} 103 | df = pd.DataFrame.from_dict(stat_dict, orient='index') 104 | df.to_csv(os.path.join(self.path,'model_stats.csv')) 105 | -------------------------------------------------------------------------------- /squeezenext_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import tensorflow as tf 4 | 5 | slim = tf.contrib.slim 6 | metrics = tf.contrib.metrics 7 | import squeezenext_architecture as squeezenext 8 | from optimizer import PolyOptimizer 9 | from dataloader import ReadTFRecords 10 | import tools 11 | import os 12 | metrics = tf.contrib.metrics 13 | 14 | class Model(object): 15 | def __init__(self, config, batch_size): 16 | self.image_size = config["image_size"] 17 | self.num_classes = config["num_classes"] 18 | self.batch_size = batch_size 19 | self.read_tf_records = ReadTFRecords(self.image_size, 
self.batch_size, self.num_classes) 20 | 21 | def define_batch_size(self, features, labels): 22 | """ 23 | Define batch size of dictionary 24 | :param features: 25 | Feature dict 26 | :param labels: 27 | Labels dict 28 | :return: 29 | (features,label) 30 | """ 31 | features = tools.define_first_dim(features, self.batch_size) 32 | labels = tools.define_first_dim(labels, self.batch_size) 33 | return (features, labels) 34 | 35 | def input_fn(self, file_pattern,training): 36 | """ 37 | Input fn of model 38 | :param file_pattern: 39 | Glob file pattern 40 | :param training: 41 | Whether or not the model is training 42 | :return: 43 | Input generator 44 | """ 45 | return self.define_batch_size(*self.read_tf_records(file_pattern,training=training)) 46 | 47 | def model_fn(self, features, labels, mode, params): 48 | """ 49 | Function to create squeezenext model and setup training environment 50 | :param features: 51 | Feature dict from estimators input fn 52 | :param labels: 53 | Label dict from estimators input fn 54 | :param mode: 55 | What mode the model is in tf.estimator.ModeKeys 56 | :param params: 57 | Dictionary of parameters used to configurate the network 58 | :return: 59 | Train op, predictions, or eval op depening on mode 60 | """ 61 | 62 | training = mode == tf.estimator.ModeKeys.TRAIN 63 | # init model class 64 | model = squeezenext.SqueezeNext(self.num_classes, params["block_defs"], params["input_def"], params["groups"],params["seperate_relus"]) 65 | # create model inside the argscope of the model 66 | with slim.arg_scope(squeezenext.squeeze_next_arg_scope(training)): 67 | predictions,endpoints = model(features["image"], training) 68 | 69 | # output predictions 70 | if mode == tf.estimator.ModeKeys.PREDICT: 71 | _,top_5 = tf.nn.top_k(predictions,k=5) 72 | predictions = { 73 | 'top_1': tf.argmax(predictions, -1), 74 | 'top_5': top_5, 75 | 'probabilities': tf.nn.softmax(predictions), 76 | 'logits': predictions, 77 | } 78 | return tf.estimator.EstimatorSpec(mode, predictions=predictions) 79 | 80 | # create loss (should be equal to caffe softmaxwithloss) 81 | loss = tf.losses.softmax_cross_entropy(tf.squeeze(labels["class_vec"],axis=1), predictions) 82 | 83 | # create histogram of class spread 84 | tf.summary.histogram("classes",labels["class_idx"]) 85 | 86 | if training: 87 | # init poly optimizer 88 | optimizer = PolyOptimizer(params) 89 | # define train op 90 | train_op = optimizer.optimize(loss, training, params["total_steps"]) 91 | 92 | # if params["output_train_images"] is true output images during training 93 | if params["output_train_images"]: 94 | tf.summary.image("training", features["image"]) 95 | stats_hook = tools.stats.ModelStats("squeezenext", params["model_dir"],self.batch_size) 96 | # setup fine tune scaffold 97 | scaffold = tf.train.Scaffold(init_op=None, 98 | init_fn=tools.fine_tune.init_weights("squeezenext", params["fine_tune_ckpt"])) 99 | 100 | # create estimator training spec, which also outputs the model_stats of the model to params["model_dir"] 101 | return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, training_hooks=[stats_hook],scaffold=scaffold) 102 | 103 | 104 | 105 | if mode == tf.estimator.ModeKeys.EVAL: 106 | # Define the metrics: 107 | metrics_dict = { 108 | 'Recall@1': tf.metrics.accuracy(tf.argmax(predictions, axis=-1), labels["class_idx"][:, 0]), 109 | 'Recall@5': metrics.streaming_sparse_recall_at_k(predictions, tf.cast(labels["class_idx"], tf.int64), 110 | 5) 111 | } 112 | # output eval images 113 | eval_summary_hook = 
tf.train.SummarySaverHook( 114 | save_steps=100, 115 | output_dir=os.path.join(params["model_dir"],"eval"), 116 | summary_op=tf.summary.image("validation", features["image"])) 117 | 118 | #return eval spec 119 | return tf.estimator.EstimatorSpec( 120 | mode, loss=loss, eval_metric_ops=metrics_dict, 121 | evaluation_hooks=[eval_summary_hook]) 122 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import tensorflow as tf 4 | import multiprocessing 5 | 6 | def caffe_center_crop(image_encoded,image_size,training,resize_size=256): 7 | """ 8 | Emulates the center crop function used in caffe 9 | :param image_encoded: 10 | Jpeg string 11 | :param image_size: 12 | Output width and height 13 | :param training: 14 | Whether or not the model is training 15 | :param resize_size: 16 | Size to which to resize the decoded image before center croping. Default size is 256 17 | to match the size used in this script: 18 | https://github.com/BVLC/caffe/blob/master/examples/imagenet/create_imagenet.sh 19 | :return: 20 | Image of size [image_size,image_size,3] 21 | """ 22 | # decode resize and shape jpeg image 23 | image = tf.image.decode_jpeg(image_encoded,channels=3) 24 | image = tf.image.resize_images(image, [resize_size, resize_size]) 25 | image = tf.reshape(image, [resize_size, resize_size,3]) 26 | # when training do random crop and random flip during eval do center crop 27 | if training: 28 | image = tf.random_crop(image,[image_size,image_size,3]) 29 | image = tf.image.random_flip_left_right(image) 30 | else: 31 | crop_min = tf.abs(resize_size / 2 - (image_size / 2)) 32 | crop_max = crop_min+image_size 33 | image = image[crop_min:crop_max,crop_min:crop_max,:] 34 | return image 35 | 36 | def _parse_function(example_proto, image_size, num_classes,training,mean_value=(123,117,104),method="crop"): 37 | """ 38 | Parses tf-records created with build_imagenet_data.py 39 | :param example_proto: 40 | Single example from tf record 41 | :param image_size: 42 | Output image size 43 | :param num_classes: 44 | Number of classes in dataset 45 | :param training: 46 | Whether or not the model is training 47 | :param mean_value: 48 | Imagenet mean to subtract from the output iamge 49 | :param method: 50 | How to generate the input image 51 | :return: 52 | Features dict containing image, and labels dict containing class index and one hot vector 53 | """ 54 | 55 | # Schema of fields to parse 56 | schema = { 57 | 'image/encoded': tf.FixedLenFeature([], dtype=tf.string, 58 | default_value=''), 59 | 'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64, 60 | default_value=-1), 61 | } 62 | 63 | 64 | image_size = tf.cast(image_size,tf.int32) 65 | mean_value = tf.cast(tf.stack(mean_value),tf.float32) 66 | 67 | # Parse example using schema 68 | parsed_features = tf.parse_single_example(example_proto, schema) 69 | jpeg_image = parsed_features["image/encoded"] 70 | # generate correctly sized image using one of 2 methods 71 | if method == "crop": 72 | image = caffe_center_crop(jpeg_image,image_size,training) 73 | elif method == "resize": 74 | image = tf.image.decode_jpeg(jpeg_image) 75 | image = tf.image.resize_images(image, [image_size, image_size]) 76 | else: 77 | raise("unknown image process method") 78 | # subtract mean 79 | image = image - mean_value 80 | 81 | # subtract 1 from class index as background class 0 is not used 82 | label_idx = 
tf.cast(parsed_features['image/class/label'], dtype=tf.int32)-1 83 | 84 | # create one hot vector 85 | label_vec = tf.one_hot(label_idx, num_classes) 86 | 87 | return {"image": tf.reshape(image,[image_size,image_size,3])}, {"class_idx": label_idx, "class_vec": label_vec} 88 | 89 | 90 | class ReadTFRecords(object): 91 | def __init__(self, image_size, batch_size, num_classes): 92 | self.image_size = image_size 93 | self.batch_size = batch_size 94 | self.num_classes = num_classes 95 | 96 | def __call__(self, glob_pattern,training=True): 97 | """ 98 | Read tf records matching a glob pattern 99 | :param glob_pattern: 100 | glob pattern eg. "/usr/local/share/Datasets/Imagenet/train-*.tfrecords" 101 | :param training: 102 | Whether or not to shuffle the data for training and evaluation 103 | :return: 104 | Iterator generating one example of batch size for each training step 105 | """ 106 | threads = multiprocessing.cpu_count() 107 | with tf.name_scope("tf_record_reader"): 108 | # generate file list 109 | files = tf.data.Dataset.list_files(glob_pattern, shuffle=training) 110 | 111 | # parallel fetch tfrecords dataset using the file list in parallel 112 | dataset = files.apply(tf.contrib.data.parallel_interleave( 113 | lambda filename: tf.data.TFRecordDataset(filename), cycle_length=threads)) 114 | 115 | # shuffle and repeat examples for better randomness and allow training beyond one epoch 116 | dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(32*self.batch_size)) 117 | 118 | # map the parse function to each example individually in threads*2 parallel calls 119 | dataset = dataset.map(map_func=lambda example: _parse_function(example, self.image_size, self.num_classes,training=training), 120 | num_parallel_calls=threads) 121 | 122 | # batch the examples 123 | dataset = dataset.batch(batch_size=self.batch_size) 124 | 125 | #prefetch batch 126 | dataset = dataset.prefetch(buffer_size=32) 127 | 128 | return dataset.make_one_shot_iterator().get_next() 129 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## _SqueezeNext Tensorflow:_ A tensorflow Implementation of SqueezeNext 2 | This repository contains an unofficial tensorflow implementation of SqueezeNext, a hardware-aware neural network design. 3 | 4 | @article{DBLP:journals/corr/abs-1803-10615, 5 | author = {Amir Gholami and 6 | Kiseok Kwon and 7 | Bichen Wu and 8 | Zizheng Tai and 9 | Xiangyu Yue and 10 | Peter H. Jin and 11 | Sicheng Zhao and 12 | Kurt Keutzer}, 13 | title = {SqueezeNext: Hardware-Aware Neural Network Design}, 14 | journal = {CoRR}, 15 | volume = {abs/1803.10615}, 16 | year = {2018}, 17 | url = {http://arxiv.org/abs/1803.10615}, 18 | archivePrefix = {arXiv}, 19 | eprint = {1803.10615}, 20 | timestamp = {Wed, 11 Apr 2018 17:54:17 +0200}, 21 | biburl = {https://dblp.org/rec/bib/journals/corr/abs-1803-10615}, 22 | bibsource = {dblp computer science bibliography, https://dblp.org} 23 | } 24 | ## Pretrained model: 25 | Using the data from the paper, original caffe version on github and other sources I tried to recreate the 1.0-SqueezeNext-23 model as closely as possible. The model 26 | achieved a 56% top 1 accuracy on validation set and a 80% top 5 accuracy on the validation set. This is about 3% under the reported results. 
Causes for this 27 | could be that the network was trained with a batch size of 256 instead of 1024, and because of this the number of steps required for 120 epochs increased fourfold. 28 | The learning rate schedule was modified to account for the lower batch size and the increased number of steps. 29 | 30 | This configuration (stored in the v_1_0_SqNxt_23 config) can be downloaded from here [v_1_0_SqNxt_23_mod](https://drive.google.com/file/d/1FsNIrUSo-m8Td20Xk6N13RICcACqsU9L/view?usp=sharing). 31 | 32 | 33 | 34 | 35 | ## Installation: 36 | This implementation was made using version 1.8 of the tensorflow api. Earlier versions are untested, and may not work due to the 37 | use of some recently added functions for data loading and processing. The code was made for python 2.7. 38 | 39 | - Make sure tensorflow 1.8 or higher is installed by running: 40 | ```Shell 41 | python -c 'import tensorflow as tf; print(tf.__version__)' # for Python 2 42 | ``` 43 | and verify that the output is 1.8.0 or above. 44 | 45 | - Clone this repository: 46 | 47 | ```Shell 48 | git clone https://github.com/Timen/squeezenext-tensorflow.git 49 | ``` 50 | - Install requirements: 51 | ```Shell 52 | pip install -r requirements.txt 53 | ``` 54 | 55 | ## Preparing the Dataset: 56 | SqueezeNext, like most other classifiers, is trained with the ImageNet dataset (http://www.image-net.org/). One can download the 57 | data from the aforementioned website, however this can be rather slow, so I recommend downloading the dataset using the torrents 58 | available on (http://academictorrents.com/), namely: 59 | 60 | [Training Images](http://academictorrents.com/details/a306397ccf9c2ead27155983c254227c0fd938e2/tech), 61 | [Validation Images](http://academictorrents.com/details/5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5/tech&dllist=1) 62 | and [Bounding Box Annotations](http://academictorrents.com/details/28202f4f8dde5c9b26d406f5522f8763713e605b/tech&dllist=1) 63 | 64 | Please note one should still abide by the original license agreement of the ImageNet dataset. After downloading these files, perform the following steps to prepare the dataset. 65 | 66 | - Create a directory used for processing and storing the dataset. 67 | Please note you should have at least around 500 GB of free space available on the drive where you are processing the dataset. 68 | Once this directory is created, copy the 3 files downloaded earlier to its root, then from that directory execute 69 | the following command: 70 | ```Shell 71 | export DATA_DIR=$(pwd) 72 | 73 | ``` 74 | 75 | 76 | - Execute the following command from this project's root folder: 77 | ```Shell 78 | bash datasets/process_downloaded_imagenet.sh $DATA_DIR 79 | 80 | ``` 81 | where $DATA_DIR is the root of the directory created to hold the 3 downloaded files. 82 | 83 | - Wait for processing to finish. 84 | The script process_downloaded_imagenet.sh will automatically extract the tarballs and process all the data into tf-records. 85 | The whole process can take between 2 and 5 hours, depending on how fast the hard drive and cpu are. 86 | 87 | ## Training: 88 | After installation and dataset preparation, one only needs to execute the run_train.sh script to start training, by executing 89 | the following command from the project's root folder: 90 | 91 | ```Shell 92 | bash run_train.sh 93 | ``` 94 | This will start training the 1.0 v1 version of SqueezeNext for 120 epochs with batch size 256. With a GTX1080Ti this training 95 | will take up to 4 days. The exact flags used are shown below.
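These values are taken verbatim from run_train.sh and can be edited there directly (for example --batch_size and --num_epochs):
```Shell
PYTHONPATH="./" python train.py \
    --model_dir $TRAIN_DIR \
    --configuration "v_1_0_SqNxt_23" \
    --batch_size 256 \
    --num_epochs 120 \
    --training_file_pattern $DATA_DIR"tf-records/train-*" \
    --validation_file_pattern $DATA_DIR"tf-records/validation-*"
```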
If your gpu has a smaller memory capacity than a GTX1080Ti, you will probably need to lower the batch size 96 | to be able to run the training. 97 | 98 | 99 | ## Prediction: 100 | Prediction is done using the predict.py script; to run it, give it a path to a jpeg image and pass the directory containing 101 | a trained model in the model_dir argument. 102 | 103 | ```Shell 104 | python predict.py ./tabby_cat.jpg --model_dir ?TRAIN_DIR from the run_train.sh or pretrained model directory? 105 | ``` 106 | 107 | This script will load the image and run the classifier on it; the output is the top 5 human readable class labels. 108 | 109 | ## Modifying the hyperparameters: 110 | The batch size, number of epochs and some other settings regarding epoch size, file locations etc. can be passed as command 111 | line arguments to the train.py script. 112 | 113 | Switching between specific configurations, such as the grouped convolution and the non grouped 114 | convolution versions of squeezenext, is done by selecting which config file from the configs folder to use. This is done 115 | by passing the file name without the .py as the command line argument --configuration. It is easy to add your own configuration: just 116 | copy one of the other configs and rename the file to something new (keep in mind it will be imported in python, so stick to numbers, letters 117 | and underscores). You can then change the parameters in the file to customize your own config and pass the new file name as the --configuration parameter (the python scripts in configs are automatically imported). 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /squeezenext_architecture.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import tensorflow as tf 4 | 5 | slim = tf.contrib.slim 6 | 7 | import tensorflow_extentions as tfe 8 | 9 | 10 | def squeezenext_unit(inputs, filters, stride,height_first_order, groups,seperate_relus): 11 | """ 12 | Squeezenext unit according to: 13 | https://arxiv.org/pdf/1803.10615.pdf 14 | 15 | :param inputs: 16 | Input tensor 17 | :param filters: 18 | Number of filters at the output of this unit 19 | :param stride: 20 | Input stride 21 | :param height_first_order: 22 | Whether to first perform the separable convolution in the vertical direction or the horizontal direction 23 | :param groups: 24 | Number of groups for some of the convolutions (the choice of which ones differs from the paper but matches: 25 | https://github.com/amirgholami/SqueezeNext/blob/master/1.0-G-SqNxt-23/train_val.prototxt) 26 | :return: 27 | Output tensor, not(height_first_order) 28 | """ 29 | input_channels = inputs.get_shape().as_list()[-1] 30 | shortcut = inputs 31 | out_activation = tf.nn.relu if bool(seperate_relus) else None 32 | # shortcut convolution, only executed if the number of input channels differs from the number of output channels or 33 | # the stride is greater than 1.
34 | if input_channels != filters or stride != 1: 35 | shortcut = slim.conv2d(shortcut, filters, [1, 1], stride=stride, activation_fn=out_activation) 36 | 37 | # input 1x1 reduction convolutions 38 | block = tfe.grouped_convolution(inputs, filters / 2, [1, 1], groups, stride=stride) 39 | block = slim.conv2d(block, block.get_shape().as_list()[-1] / 2, [1, 1]) 40 | 41 | # separable convolutions 42 | if height_first_order: 43 | input_channels_seperated = block.get_shape().as_list()[-1] 44 | block = tfe.grouped_convolution(block, input_channels_seperated * 2, [3, 1], groups) 45 | block = tfe.grouped_convolution(block, block.get_shape().as_list()[-1], [1, 3], groups) 46 | 47 | else: 48 | input_channels_seperated = block.get_shape().as_list()[-1] 49 | block = tfe.grouped_convolution(block, input_channels_seperated * 2, [1, 3], groups) 50 | block = tfe.grouped_convolution(block, block.get_shape().as_list()[-1], [3, 1], groups) 51 | # switch order for the next unit 52 | height_first_order = not height_first_order 53 | 54 | # output convolutions 55 | block = slim.conv2d(block, block.get_shape().as_list()[-1] * 2, [1, 1],activation_fn=out_activation) 56 | assert block.get_shape().as_list()[-1] == filters, "Block output channels not equal to number of specified filters" 57 | 58 | 59 | return tf.nn.relu(block + shortcut),height_first_order 60 | 61 | 62 | class SqueezeNext(object): 63 | """Base class for building the SqueezeNext Model.""" 64 | 65 | def __init__(self, num_classes, block_defs, input_def,groups,seperate_relus): 66 | self.num_classes = num_classes 67 | self.block_defs = block_defs 68 | self.input_def = input_def 69 | self.groups = groups 70 | self.seperate_relus = seperate_relus 71 | 72 | 73 | def __call__(self, inputs, training,height_first_order = True): 74 | """Add operations to classify a batch of input images. 75 | 76 | Args: 77 | inputs: A Tensor representing a batch of input images. 78 | training: A boolean. Set to True to add operations required only when 79 | training the classifier. 80 | 81 | Returns: 82 | A logits Tensor with shape [<batch_size>, self.num_classes].
83 | """ 84 | 85 | with tf.variable_scope("squeezenext"): 86 | input_filters, input_kernel,input_stride = self.input_def 87 | endpoints = {} 88 | 89 | # input convolution and pooling 90 | net = slim.conv2d(inputs, input_filters, input_kernel, stride=input_stride,scope="input_conv",padding="VALID") 91 | endpoints["input_conv"] = net 92 | net = slim.max_pool2d(net, [3, 3], stride=2) 93 | endpoints["max_pool"] = net 94 | 95 | # create block based network 96 | for block_idx,block_def in enumerate(self.block_defs): 97 | 98 | filters,units,stride = block_def 99 | with tf.variable_scope("block_{}".format(block_idx)): 100 | # create seperate units inside a block 101 | for unit_idx in range(0,units): 102 | with tf.variable_scope("unit_{}".format(unit_idx)): 103 | if unit_idx != 0: 104 | # perform striding only in first unit of a block 105 | net,height_first_order = squeezenext_unit(net,filters,1,height_first_order,self.groups,self.seperate_relus) 106 | else: 107 | net,height_first_order = squeezenext_unit(net, filters, stride,height_first_order,self.groups,self.seperate_relus) 108 | endpoints["block_{}".format(block_idx)+"/"+"unit_{}".format(unit_idx)]=net 109 | # output conv and pooling 110 | net = slim.conv2d(net, 128, [1,1],scope="output_conv") 111 | endpoints["output_conv"] = net 112 | net = tf.squeeze(slim.avg_pool2d(net,net.get_shape().as_list()[1:3],scope="avg_pool_out", padding="VALID"),axis=[1,2]) 113 | endpoints["avg_pool_out"] = net 114 | 115 | # Fully connected output without biases 116 | output = slim.fully_connected(net,self.num_classes,activation_fn=None,normalizer_fn=None, biases_initializer=None) 117 | endpoints["output"] = output 118 | 119 | return output,endpoints 120 | 121 | 122 | 123 | def squeeze_next_arg_scope(is_training, 124 | weight_decay=0.0001): 125 | """ 126 | Setup slim arg scope according to paper and github project 127 | :param is_training: 128 | Whether or not the network is training 129 | :param weight_decay: 130 | Weight decay of the convolutional layers 131 | :return: 132 | Slim arg scope 133 | """ 134 | batch_norm_params = { 135 | 'is_training': is_training, 136 | 'center': True, 137 | 'scale': True, 138 | 'decay': 0.999, 139 | 'epsilon': 1e-5, 140 | 'fused': True, 141 | } 142 | 143 | # Use xavier an l2 decay 144 | weights_init = tf.contrib.layers.xavier_initializer() 145 | regularizer = tf.contrib.layers.l2_regularizer(weight_decay) 146 | 147 | 148 | with slim.arg_scope([slim.conv2d,tfe.grouped_convolution], 149 | weights_initializer=weights_init, 150 | normalizer_fn=slim.batch_norm, 151 | normalizer_params=batch_norm_params, 152 | # No biases in the convolutions (are already included in batch_norm) 153 | biases_initializer=None, 154 | weights_regularizer=regularizer): 155 | with slim.arg_scope([slim.batch_norm], **batch_norm_params) as sc: 156 | return sc 157 | -------------------------------------------------------------------------------- /tensorflow_extentions/grouped_convolution.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.layers.python.layers import utils 3 | import collections 4 | 5 | slim = tf.contrib.slim 6 | 7 | 8 | def grouped_convolution2D(inputs, filters, padding, num_groups, 9 | strides=None, 10 | dilation_rate=None): 11 | """ 12 | Performs a grouped convolution by applying a normal convolution to each of the seperate groups 13 | :param inputs: 14 | Input of the shape [,H,W,inC] 15 | :param filters: 16 | [H,W,inC/num_groups,outC] 17 | :param padding: 18 
| What padding to use 19 | :param num_groups: 20 | Number of seperate groups 21 | :param strides: 22 | Stride 23 | :param dilation_rate: 24 | Dilation rate 25 | :return: 26 | Output of shape [,H/stride,W/stride,outC] 27 | """ 28 | # Split input and outputs along their last dimension 29 | input_list = tf.split(inputs, num_groups, axis=-1) 30 | filter_list = tf.split(filters, num_groups, axis=-1) 31 | output_list = [] 32 | 33 | # Perform a normal convolution on each split of the input and filters 34 | for conv_idx, (input_tensor, filter_tensor) in enumerate(zip(input_list, filter_list)): 35 | output_list.append(tf.nn.convolution( 36 | input_tensor, 37 | filter_tensor, 38 | padding, 39 | strides=strides, 40 | dilation_rate=dilation_rate, 41 | name="grouped_convolution" + "_{}".format(conv_idx) 42 | )) 43 | # Concatenate ouptputs along their last dimentsion 44 | outputs = tf.concat(output_list, axis=-1) 45 | 46 | return outputs 47 | 48 | 49 | @slim.add_arg_scope 50 | def grouped_convolution(inputs, 51 | num_outputs, 52 | kernel_size, 53 | groups, 54 | stride=1, 55 | padding='SAME', 56 | rate=1, 57 | activation_fn=tf.nn.relu, 58 | normalizer_fn=None, 59 | normalizer_params=None, 60 | weights_initializer=tf.contrib.layers.xavier_initializer(), 61 | weights_regularizer=None, 62 | biases_initializer=tf.initializers.zeros(), 63 | biases_regularizer=None, 64 | reuse=None, 65 | trainable=True, 66 | scope=None, 67 | outputs_collections=None): 68 | """Adds an 2-D grouped convolution followed by an optional batch_norm layer. 69 | `convolution` creates a variable called `weights`, representing the 70 | convolutional kernel, that is convolved (actually cross-correlated) with the 71 | `inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is 72 | provided (such as `batch_norm`), it is then applied. Otherwise, if 73 | `normalizer_fn` is None and a `biases_initializer` is provided then a `biases` 74 | variable would be created and added the activations. Finally, if 75 | `activation_fn` is not `None`, it is applied to the activations as well. 76 | Performs atrous convolution with input stride/dilation rate equal to `rate` 77 | if a value > 1 for any dimension of `rate` is specified. In this case 78 | `stride` values != 1 are not supported. 79 | Args: 80 | inputs: A Tensor of rank N+2 of shape 81 | `[batch_size] + input_spatial_shape + [in_channels]` if data_format does 82 | not start with "NC" (default), or 83 | `[batch_size, in_channels] + input_spatial_shape` if data_format starts 84 | with "NC". 85 | num_outputs: Integer, the number of output filters. 86 | kernel_size: A sequence of N positive integers specifying the spatial 87 | dimensions of the filters. Can be a single integer to specify the same 88 | value for all spatial dimensions. 89 | groups: Number of groups to split the input up in before applying convolutions to the 90 | seperate groups. If groups==1 return normal slim.conv2d. 91 | stride: A sequence of N positive integers specifying the stride at which to 92 | compute output. Can be a single integer to specify the same value for all 93 | spatial dimensions. Specifying any `stride` value != 1 is incompatible 94 | with specifying any `rate` value != 1. 95 | padding: One of `"VALID"` or `"SAME"`. 96 | rate: A sequence of N positive integers specifying the dilation rate to use 97 | for atrous convolution. Can be a single integer to specify the same 98 | value for all spatial dimensions. Specifying any `rate` value != 1 is 99 | incompatible with specifying any `stride` value != 1. 
100 | activation_fn: Activation function. The default value is a ReLU function. 101 | Explicitly set it to None to skip it and maintain a linear activation. 102 | normalizer_fn: Normalization function to use instead of `biases`. If 103 | `normalizer_fn` is provided then `biases_initializer` and 104 | `biases_regularizer` are ignored and `biases` are not created nor added. 105 | default set to None for no normalizer function 106 | normalizer_params: Normalization function parameters. 107 | weights_initializer: An initializer for the weights. 108 | weights_regularizer: Optional regularizer for the weights. 109 | biases_initializer: An initializer for the biases. If None skip biases. 110 | biases_regularizer: Optional regularizer for the biases. 111 | reuse: Whether or not the layer and its variables should be reused. To be 112 | able to reuse the layer scope must be given. 113 | outputs_collections: Collection to add the outputs. 114 | trainable: If `True` also add variables to the graph collection 115 | `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). 116 | scope: Optional scope for `variable_scope`. 117 | Returns: 118 | A tensor representing the output of the operation. 119 | Raises: 120 | ValueError: If `data_format` is invalid. 121 | ValueError: Both 'rate' and `stride` are not uniformly 1. 122 | ValueError: If 'groups'<1. 123 | """ 124 | # if no group size specified or less than/equal to zero return a normal convolution 125 | if groups == 1: 126 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, padding=padding, 127 | activation_fn=activation_fn, 128 | normalizer_fn=normalizer_fn, 129 | normalizer_params=normalizer_params, 130 | weights_initializer=weights_initializer, 131 | weights_regularizer=weights_regularizer, 132 | biases_initializer=biases_initializer, 133 | biases_regularizer=biases_regularizer, 134 | reuse=reuse, 135 | trainable=trainable, 136 | scope=scope) 137 | if groups < 1: 138 | raise ValueError("Specify a number of groups greater than zero, groups given is {}".format(groups)) 139 | 140 | input_channels = inputs.get_shape().as_list()[-1] 141 | 142 | # check if the number of groups and corresponding group_size is an integer division of the input and output channels 143 | lowest_channels = min(input_channels, num_outputs) 144 | assert lowest_channels % groups == 0, "the remainder of min(input_channels,output_channels)/groups should be zero" 145 | assert max(input_channels, 146 | num_outputs) % groups == 0, "the remainder of max(input_channels,output_channels)/groups=({}) " \ 147 | "should be zero".format( 148 | groups) 149 | 150 | with tf.variable_scope(scope, 'Group_Conv', [inputs], reuse=reuse) as sc: 151 | # define weight shape 152 | if isinstance(kernel_size, collections.Iterable): 153 | weights_shape = list(kernel_size) + [input_channels/groups] + [num_outputs] 154 | else: 155 | weights_shape = [kernel_size, kernel_size, input_channels/groups, num_outputs] 156 | 157 | # create weights variable 158 | weights = slim.variable('weights', 159 | shape=weights_shape, 160 | initializer=weights_initializer, 161 | regularizer=weights_regularizer, 162 | trainable=trainable) 163 | strides = [stride, stride] 164 | dilation_rate = [rate, rate] 165 | # perform grouped convolution 166 | outputs = grouped_convolution2D(inputs, weights, padding, groups, 167 | strides=strides, 168 | dilation_rate=dilation_rate) 169 | if biases_initializer is not None: 170 | biases = slim.variable('biases', 171 | shape=[num_outputs], 172 | initializer=biases_initializer, 173 | 
regularizer=biases_regularizer, 174 | trainable=trainable) 175 | outputs = tf.nn.bias_add(outputs, biases) 176 | if normalizer_fn is not None: 177 | normalizer_params = normalizer_params or {} 178 | outputs = normalizer_fn(outputs, **normalizer_params) 179 | 180 | if activation_fn is not None: 181 | outputs = activation_fn(outputs) 182 | return utils.collect_named_outputs(outputs_collections, sc.name, outputs) 183 | -------------------------------------------------------------------------------- /datasets/process_bounding_boxes.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Process the ImageNet Challenge bounding boxes for TensorFlow model training. 17 | 18 | This script is called as 19 | 20 | process_bounding_boxes.py [synsets-file] 21 | 22 | Where is a directory containing the downloaded and unpacked bounding box 23 | data. If [synsets-file] is supplied, then only the bounding boxes whose 24 | synstes are contained within this file are returned. Note that the 25 | [synsets-file] file contains synset ids, one per line. 26 | 27 | The script dumps out a CSV text file in which each line contains an entry. 28 | n00007846_64193.JPEG,0.0060,0.2620,0.7545,0.9940 29 | 30 | The entry can be read as: 31 | , , , , 32 | 33 | The bounding box for contains two points (xmin, ymin) and 34 | (xmax, ymax) specifying the lower-left corner and upper-right corner of a 35 | bounding box in *relative* coordinates. 36 | 37 | The user supplies a directory where the XML files reside. The directory 38 | structure in the directory is assumed to look like this: 39 | 40 | /nXXXXXXXX/nXXXXXXXX_YYYY.xml 41 | 42 | Each XML file contains a bounding box annotation. The script: 43 | 44 | (1) Parses the XML file and extracts the filename, label and bounding box info. 45 | 46 | (2) The bounding box is specified in the XML files as integer (xmin, ymin) and 47 | (xmax, ymax) *relative* to image size displayed to the human annotator. The 48 | size of the image displayed to the human annotator is stored in the XML file 49 | as integer (height, width). 50 | 51 | Note that the displayed size will differ from the actual size of the image 52 | downloaded from image-net.org. To make the bounding box annotation useable, 53 | we convert bounding box to floating point numbers relative to displayed 54 | height and width of the image. 55 | 56 | Note that each XML file might contain N bounding box annotations. 57 | 58 | Note that the points are all clamped at a range of [0.0, 1.0] because some 59 | human annotations extend outside the range of the supplied image. 60 | 61 | See details here: http://image-net.org/download-bboxes 62 | 63 | (3) By default, the script outputs all valid bounding boxes. 
If a 64 | [synsets-file] is supplied, only the subset of bounding boxes associated 65 | with those synsets are outputted. Importantly, one can supply a list of 66 | synsets in the ImageNet Challenge and output the list of bounding boxes 67 | associated with the training images of the ILSVRC. 68 | 69 | We use these bounding boxes to inform the random distortion of images 70 | supplied to the network. 71 | 72 | If you run this script successfully, you will see the following output 73 | to stderr: 74 | > Finished processing 544546 XML files. 75 | > Skipped 0 XML files not in ImageNet Challenge. 76 | > Skipped 0 bounding boxes not in ImageNet Challenge. 77 | > Wrote 615299 bounding boxes from 544546 annotated images. 78 | """ 79 | 80 | from __future__ import absolute_import 81 | from __future__ import division 82 | from __future__ import print_function 83 | 84 | import glob 85 | import os.path 86 | import sys 87 | import xml.etree.ElementTree as ET 88 | 89 | 90 | class BoundingBox(object): 91 | pass 92 | 93 | 94 | def GetItem(name, root, index=0): 95 | count = 0 96 | for item in root.iter(name): 97 | if count == index: 98 | return item.text 99 | count += 1 100 | # Failed to find "index" occurrence of item. 101 | return -1 102 | 103 | 104 | def GetInt(name, root, index=0): 105 | return int(GetItem(name, root, index)) 106 | 107 | 108 | def FindNumberBoundingBoxes(root): 109 | index = 0 110 | while True: 111 | if GetInt('xmin', root, index) == -1: 112 | break 113 | index += 1 114 | return index 115 | 116 | 117 | def ProcessXMLAnnotation(xml_file): 118 | """Process a single XML file containing a bounding box.""" 119 | # pylint: disable=broad-except 120 | try: 121 | tree = ET.parse(xml_file) 122 | except Exception: 123 | print('Failed to parse: ' + xml_file, file=sys.stderr) 124 | return None 125 | # pylint: enable=broad-except 126 | root = tree.getroot() 127 | 128 | num_boxes = FindNumberBoundingBoxes(root) 129 | boxes = [] 130 | 131 | for index in xrange(num_boxes): 132 | box = BoundingBox() 133 | # Grab the 'index' annotation. 134 | box.xmin = GetInt('xmin', root, index) 135 | box.ymin = GetInt('ymin', root, index) 136 | box.xmax = GetInt('xmax', root, index) 137 | box.ymax = GetInt('ymax', root, index) 138 | 139 | box.width = GetInt('width', root) 140 | box.height = GetInt('height', root) 141 | box.filename = GetItem('filename', root) + '.JPEG' 142 | box.label = GetItem('name', root) 143 | 144 | xmin = float(box.xmin) / float(box.width) 145 | xmax = float(box.xmax) / float(box.width) 146 | ymin = float(box.ymin) / float(box.height) 147 | ymax = float(box.ymax) / float(box.height) 148 | 149 | # Some images contain bounding box annotations that 150 | # extend outside of the supplied image. See, e.g. 151 | # n03127925/n03127925_147.xml 152 | # Additionally, for some bounding boxes, the min > max 153 | # or the box is entirely outside of the image. 
154 | min_x = min(xmin, xmax) 155 | max_x = max(xmin, xmax) 156 | box.xmin_scaled = min(max(min_x, 0.0), 1.0) 157 | box.xmax_scaled = min(max(max_x, 0.0), 1.0) 158 | 159 | min_y = min(ymin, ymax) 160 | max_y = max(ymin, ymax) 161 | box.ymin_scaled = min(max(min_y, 0.0), 1.0) 162 | box.ymax_scaled = min(max(max_y, 0.0), 1.0) 163 | 164 | boxes.append(box) 165 | 166 | return boxes 167 | 168 | if __name__ == '__main__': 169 | if len(sys.argv) < 2 or len(sys.argv) > 3: 170 | print('Invalid usage\n' 171 | 'usage: process_bounding_boxes.py [synsets-file]', 172 | file=sys.stderr) 173 | sys.exit(-1) 174 | 175 | xml_files = glob.glob(sys.argv[1] + '/*/*.xml') 176 | print('Identified %d XML files in %s' % (len(xml_files), sys.argv[1]), 177 | file=sys.stderr) 178 | 179 | if len(sys.argv) == 3: 180 | labels = set([l.strip() for l in open(sys.argv[2]).readlines()]) 181 | print('Identified %d synset IDs in %s' % (len(labels), sys.argv[2]), 182 | file=sys.stderr) 183 | else: 184 | labels = None 185 | 186 | skipped_boxes = 0 187 | skipped_files = 0 188 | saved_boxes = 0 189 | saved_files = 0 190 | for file_index, one_file in enumerate(xml_files): 191 | # Example: <...>/n06470073/n00141669_6790.xml 192 | label = os.path.basename(os.path.dirname(one_file)) 193 | 194 | # Determine if the annotation is from an ImageNet Challenge label. 195 | if labels is not None and label not in labels: 196 | skipped_files += 1 197 | continue 198 | 199 | bboxes = ProcessXMLAnnotation(one_file) 200 | assert bboxes is not None, 'No bounding boxes found in ' + one_file 201 | 202 | found_box = False 203 | for bbox in bboxes: 204 | if labels is not None: 205 | if bbox.label != label: 206 | # Note: There is a slight bug in the bounding box annotation data. 207 | # Many of the dog labels have the human label 'Scottish_deerhound' 208 | # instead of the synset ID 'n02092002' in the bbox.label field. As a 209 | # simple hack to overcome this issue, we only exclude bbox labels 210 | # *which are synset ID's* that do not match original synset label for 211 | # the XML file. 212 | if bbox.label in labels: 213 | skipped_boxes += 1 214 | continue 215 | 216 | # Guard against improperly specified boxes. 217 | if (bbox.xmin_scaled >= bbox.xmax_scaled or 218 | bbox.ymin_scaled >= bbox.ymax_scaled): 219 | skipped_boxes += 1 220 | continue 221 | 222 | # Note bbox.filename occasionally contains '%s' in the name. This is 223 | # data set noise that is fixed by just using the basename of the XML file. 224 | image_filename = os.path.splitext(os.path.basename(one_file))[0] 225 | print('%s.JPEG,%.4f,%.4f,%.4f,%.4f' % 226 | (image_filename, 227 | bbox.xmin_scaled, bbox.ymin_scaled, 228 | bbox.xmax_scaled, bbox.ymax_scaled)) 229 | 230 | saved_boxes += 1 231 | found_box = True 232 | if found_box: 233 | saved_files += 1 234 | else: 235 | skipped_files += 1 236 | 237 | if not file_index % 5000: 238 | print('--> processed %d of %d XML files.' % 239 | (file_index + 1, len(xml_files)), 240 | file=sys.stderr) 241 | print('--> skipped %d boxes and %d XML files.' % 242 | (skipped_boxes, skipped_files), file=sys.stderr) 243 | 244 | print('Finished processing %d XML files.' % len(xml_files), file=sys.stderr) 245 | print('Skipped %d XML files not in ImageNet Challenge.' % skipped_files, 246 | file=sys.stderr) 247 | print('Skipped %d bounding boxes not in ImageNet Challenge.' % skipped_boxes, 248 | file=sys.stderr) 249 | print('Wrote %d bounding boxes from %d annotated images.' 
% 250 | (saved_boxes, saved_files), 251 | file=sys.stderr) 252 | print('Finished.', file=sys.stderr) -------------------------------------------------------------------------------- /datasets/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /datasets/imagenet_lsvrc_2015_synsets.txt: -------------------------------------------------------------------------------- 1 | n01440764 2 | n01443537 3 | n01484850 4 | n01491361 5 | n01494475 6 | n01496331 7 | n01498041 8 | n01514668 9 | n01514859 10 | n01518878 11 | n01530575 12 | n01531178 13 | n01532829 14 | n01534433 15 | n01537544 16 | n01558993 17 | n01560419 18 | n01580077 19 | n01582220 20 | n01592084 21 | n01601694 22 | n01608432 23 | n01614925 24 | n01616318 25 | n01622779 26 | n01629819 27 | n01630670 28 | n01631663 29 | n01632458 30 | n01632777 31 | n01641577 32 | n01644373 33 | n01644900 34 | n01664065 35 | n01665541 36 | n01667114 37 | n01667778 38 | n01669191 39 | n01675722 40 | n01677366 41 | n01682714 42 | n01685808 43 | n01687978 44 | n01688243 45 | n01689811 46 | n01692333 47 | n01693334 48 | n01694178 49 | n01695060 50 | n01697457 51 | n01698640 52 | n01704323 53 | n01728572 54 | n01728920 55 | n01729322 56 | n01729977 57 | n01734418 58 | n01735189 59 | n01737021 60 | n01739381 61 | n01740131 62 | n01742172 63 | n01744401 64 | n01748264 65 | n01749939 66 | n01751748 67 | n01753488 68 | n01755581 69 | n01756291 70 | n01768244 71 | n01770081 72 | n01770393 73 | n01773157 74 | n01773549 75 | n01773797 76 | n01774384 77 | n01774750 78 | n01775062 79 | n01776313 80 | n01784675 81 | n01795545 82 | n01796340 83 | n01797886 84 | n01798484 85 | n01806143 86 | n01806567 87 | n01807496 88 | n01817953 89 | n01818515 90 | n01819313 91 | n01820546 92 | n01824575 93 | n01828970 94 | n01829413 95 | n01833805 96 | n01843065 97 | n01843383 98 | n01847000 99 | n01855032 100 | n01855672 101 | n01860187 102 | n01871265 103 | n01872401 104 | n01873310 105 | n01877812 106 | n01882714 107 | n01883070 108 | n01910747 109 | n01914609 110 | n01917289 111 | n01924916 112 | n01930112 113 | n01943899 114 | n01944390 115 | n01945685 116 | n01950731 117 | n01955084 118 | n01968897 119 | n01978287 120 | n01978455 121 | n01980166 122 | n01981276 123 | n01983481 124 | n01984695 125 | n01985128 126 | n01986214 127 | n01990800 128 | n02002556 129 | n02002724 130 | n02006656 131 | n02007558 132 | n02009229 133 | n02009912 134 | n02011460 135 | n02012849 136 | n02013706 137 | n02017213 138 | n02018207 139 | n02018795 140 | n02025239 141 | n02027492 142 | n02028035 143 | 
n02033041 144 | n02037110 145 | n02051845 146 | n02056570 147 | n02058221 148 | n02066245 149 | n02071294 150 | n02074367 151 | n02077923 152 | n02085620 153 | n02085782 154 | n02085936 155 | n02086079 156 | n02086240 157 | n02086646 158 | n02086910 159 | n02087046 160 | n02087394 161 | n02088094 162 | n02088238 163 | n02088364 164 | n02088466 165 | n02088632 166 | n02089078 167 | n02089867 168 | n02089973 169 | n02090379 170 | n02090622 171 | n02090721 172 | n02091032 173 | n02091134 174 | n02091244 175 | n02091467 176 | n02091635 177 | n02091831 178 | n02092002 179 | n02092339 180 | n02093256 181 | n02093428 182 | n02093647 183 | n02093754 184 | n02093859 185 | n02093991 186 | n02094114 187 | n02094258 188 | n02094433 189 | n02095314 190 | n02095570 191 | n02095889 192 | n02096051 193 | n02096177 194 | n02096294 195 | n02096437 196 | n02096585 197 | n02097047 198 | n02097130 199 | n02097209 200 | n02097298 201 | n02097474 202 | n02097658 203 | n02098105 204 | n02098286 205 | n02098413 206 | n02099267 207 | n02099429 208 | n02099601 209 | n02099712 210 | n02099849 211 | n02100236 212 | n02100583 213 | n02100735 214 | n02100877 215 | n02101006 216 | n02101388 217 | n02101556 218 | n02102040 219 | n02102177 220 | n02102318 221 | n02102480 222 | n02102973 223 | n02104029 224 | n02104365 225 | n02105056 226 | n02105162 227 | n02105251 228 | n02105412 229 | n02105505 230 | n02105641 231 | n02105855 232 | n02106030 233 | n02106166 234 | n02106382 235 | n02106550 236 | n02106662 237 | n02107142 238 | n02107312 239 | n02107574 240 | n02107683 241 | n02107908 242 | n02108000 243 | n02108089 244 | n02108422 245 | n02108551 246 | n02108915 247 | n02109047 248 | n02109525 249 | n02109961 250 | n02110063 251 | n02110185 252 | n02110341 253 | n02110627 254 | n02110806 255 | n02110958 256 | n02111129 257 | n02111277 258 | n02111500 259 | n02111889 260 | n02112018 261 | n02112137 262 | n02112350 263 | n02112706 264 | n02113023 265 | n02113186 266 | n02113624 267 | n02113712 268 | n02113799 269 | n02113978 270 | n02114367 271 | n02114548 272 | n02114712 273 | n02114855 274 | n02115641 275 | n02115913 276 | n02116738 277 | n02117135 278 | n02119022 279 | n02119789 280 | n02120079 281 | n02120505 282 | n02123045 283 | n02123159 284 | n02123394 285 | n02123597 286 | n02124075 287 | n02125311 288 | n02127052 289 | n02128385 290 | n02128757 291 | n02128925 292 | n02129165 293 | n02129604 294 | n02130308 295 | n02132136 296 | n02133161 297 | n02134084 298 | n02134418 299 | n02137549 300 | n02138441 301 | n02165105 302 | n02165456 303 | n02167151 304 | n02168699 305 | n02169497 306 | n02172182 307 | n02174001 308 | n02177972 309 | n02190166 310 | n02206856 311 | n02219486 312 | n02226429 313 | n02229544 314 | n02231487 315 | n02233338 316 | n02236044 317 | n02256656 318 | n02259212 319 | n02264363 320 | n02268443 321 | n02268853 322 | n02276258 323 | n02277742 324 | n02279972 325 | n02280649 326 | n02281406 327 | n02281787 328 | n02317335 329 | n02319095 330 | n02321529 331 | n02325366 332 | n02326432 333 | n02328150 334 | n02342885 335 | n02346627 336 | n02356798 337 | n02361337 338 | n02363005 339 | n02364673 340 | n02389026 341 | n02391049 342 | n02395406 343 | n02396427 344 | n02397096 345 | n02398521 346 | n02403003 347 | n02408429 348 | n02410509 349 | n02412080 350 | n02415577 351 | n02417914 352 | n02422106 353 | n02422699 354 | n02423022 355 | n02437312 356 | n02437616 357 | n02441942 358 | n02442845 359 | n02443114 360 | n02443484 361 | n02444819 362 | n02445715 363 | n02447366 364 | n02454379 365 | 
n02457408 366 | n02480495 367 | n02480855 368 | n02481823 369 | n02483362 370 | n02483708 371 | n02484975 372 | n02486261 373 | n02486410 374 | n02487347 375 | n02488291 376 | n02488702 377 | n02489166 378 | n02490219 379 | n02492035 380 | n02492660 381 | n02493509 382 | n02493793 383 | n02494079 384 | n02497673 385 | n02500267 386 | n02504013 387 | n02504458 388 | n02509815 389 | n02510455 390 | n02514041 391 | n02526121 392 | n02536864 393 | n02606052 394 | n02607072 395 | n02640242 396 | n02641379 397 | n02643566 398 | n02655020 399 | n02666196 400 | n02667093 401 | n02669723 402 | n02672831 403 | n02676566 404 | n02687172 405 | n02690373 406 | n02692877 407 | n02699494 408 | n02701002 409 | n02704792 410 | n02708093 411 | n02727426 412 | n02730930 413 | n02747177 414 | n02749479 415 | n02769748 416 | n02776631 417 | n02777292 418 | n02782093 419 | n02783161 420 | n02786058 421 | n02787622 422 | n02788148 423 | n02790996 424 | n02791124 425 | n02791270 426 | n02793495 427 | n02794156 428 | n02795169 429 | n02797295 430 | n02799071 431 | n02802426 432 | n02804414 433 | n02804610 434 | n02807133 435 | n02808304 436 | n02808440 437 | n02814533 438 | n02814860 439 | n02815834 440 | n02817516 441 | n02823428 442 | n02823750 443 | n02825657 444 | n02834397 445 | n02835271 446 | n02837789 447 | n02840245 448 | n02841315 449 | n02843684 450 | n02859443 451 | n02860847 452 | n02865351 453 | n02869837 454 | n02870880 455 | n02871525 456 | n02877765 457 | n02879718 458 | n02883205 459 | n02892201 460 | n02892767 461 | n02894605 462 | n02895154 463 | n02906734 464 | n02909870 465 | n02910353 466 | n02916936 467 | n02917067 468 | n02927161 469 | n02930766 470 | n02939185 471 | n02948072 472 | n02950826 473 | n02951358 474 | n02951585 475 | n02963159 476 | n02965783 477 | n02966193 478 | n02966687 479 | n02971356 480 | n02974003 481 | n02977058 482 | n02978881 483 | n02979186 484 | n02980441 485 | n02981792 486 | n02988304 487 | n02992211 488 | n02992529 489 | n02999410 490 | n03000134 491 | n03000247 492 | n03000684 493 | n03014705 494 | n03016953 495 | n03017168 496 | n03018349 497 | n03026506 498 | n03028079 499 | n03032252 500 | n03041632 501 | n03042490 502 | n03045698 503 | n03047690 504 | n03062245 505 | n03063599 506 | n03063689 507 | n03065424 508 | n03075370 509 | n03085013 510 | n03089624 511 | n03095699 512 | n03100240 513 | n03109150 514 | n03110669 515 | n03124043 516 | n03124170 517 | n03125729 518 | n03126707 519 | n03127747 520 | n03127925 521 | n03131574 522 | n03133878 523 | n03134739 524 | n03141823 525 | n03146219 526 | n03160309 527 | n03179701 528 | n03180011 529 | n03187595 530 | n03188531 531 | n03196217 532 | n03197337 533 | n03201208 534 | n03207743 535 | n03207941 536 | n03208938 537 | n03216828 538 | n03218198 539 | n03220513 540 | n03223299 541 | n03240683 542 | n03249569 543 | n03250847 544 | n03255030 545 | n03259280 546 | n03271574 547 | n03272010 548 | n03272562 549 | n03290653 550 | n03291819 551 | n03297495 552 | n03314780 553 | n03325584 554 | n03337140 555 | n03344393 556 | n03345487 557 | n03347037 558 | n03355925 559 | n03372029 560 | n03376595 561 | n03379051 562 | n03384352 563 | n03388043 564 | n03388183 565 | n03388549 566 | n03393912 567 | n03394916 568 | n03400231 569 | n03404251 570 | n03417042 571 | n03424325 572 | n03425413 573 | n03443371 574 | n03444034 575 | n03445777 576 | n03445924 577 | n03447447 578 | n03447721 579 | n03450230 580 | n03452741 581 | n03457902 582 | n03459775 583 | n03461385 584 | n03467068 585 | n03476684 586 | n03476991 587 | 
n03478589 588 | n03481172 589 | n03482405 590 | n03483316 591 | n03485407 592 | n03485794 593 | n03492542 594 | n03494278 595 | n03495258 596 | n03496892 597 | n03498962 598 | n03527444 599 | n03529860 600 | n03530642 601 | n03532672 602 | n03534580 603 | n03535780 604 | n03538406 605 | n03544143 606 | n03584254 607 | n03584829 608 | n03590841 609 | n03594734 610 | n03594945 611 | n03595614 612 | n03598930 613 | n03599486 614 | n03602883 615 | n03617480 616 | n03623198 617 | n03627232 618 | n03630383 619 | n03633091 620 | n03637318 621 | n03642806 622 | n03649909 623 | n03657121 624 | n03658185 625 | n03661043 626 | n03662601 627 | n03666591 628 | n03670208 629 | n03673027 630 | n03676483 631 | n03680355 632 | n03690938 633 | n03691459 634 | n03692522 635 | n03697007 636 | n03706229 637 | n03709823 638 | n03710193 639 | n03710637 640 | n03710721 641 | n03717622 642 | n03720891 643 | n03721384 644 | n03724870 645 | n03729826 646 | n03733131 647 | n03733281 648 | n03733805 649 | n03742115 650 | n03743016 651 | n03759954 652 | n03761084 653 | n03763968 654 | n03764736 655 | n03769881 656 | n03770439 657 | n03770679 658 | n03773504 659 | n03775071 660 | n03775546 661 | n03776460 662 | n03777568 663 | n03777754 664 | n03781244 665 | n03782006 666 | n03785016 667 | n03786901 668 | n03787032 669 | n03788195 670 | n03788365 671 | n03791053 672 | n03792782 673 | n03792972 674 | n03793489 675 | n03794056 676 | n03796401 677 | n03803284 678 | n03804744 679 | n03814639 680 | n03814906 681 | n03825788 682 | n03832673 683 | n03837869 684 | n03838899 685 | n03840681 686 | n03841143 687 | n03843555 688 | n03854065 689 | n03857828 690 | n03866082 691 | n03868242 692 | n03868863 693 | n03871628 694 | n03873416 695 | n03874293 696 | n03874599 697 | n03876231 698 | n03877472 699 | n03877845 700 | n03884397 701 | n03887697 702 | n03888257 703 | n03888605 704 | n03891251 705 | n03891332 706 | n03895866 707 | n03899768 708 | n03902125 709 | n03903868 710 | n03908618 711 | n03908714 712 | n03916031 713 | n03920288 714 | n03924679 715 | n03929660 716 | n03929855 717 | n03930313 718 | n03930630 719 | n03933933 720 | n03935335 721 | n03937543 722 | n03938244 723 | n03942813 724 | n03944341 725 | n03947888 726 | n03950228 727 | n03954731 728 | n03956157 729 | n03958227 730 | n03961711 731 | n03967562 732 | n03970156 733 | n03976467 734 | n03976657 735 | n03977966 736 | n03980874 737 | n03982430 738 | n03983396 739 | n03991062 740 | n03992509 741 | n03995372 742 | n03998194 743 | n04004767 744 | n04005630 745 | n04008634 746 | n04009552 747 | n04019541 748 | n04023962 749 | n04026417 750 | n04033901 751 | n04033995 752 | n04037443 753 | n04039381 754 | n04040759 755 | n04041544 756 | n04044716 757 | n04049303 758 | n04065272 759 | n04067472 760 | n04069434 761 | n04070727 762 | n04074963 763 | n04081281 764 | n04086273 765 | n04090263 766 | n04099969 767 | n04111531 768 | n04116512 769 | n04118538 770 | n04118776 771 | n04120489 772 | n04125021 773 | n04127249 774 | n04131690 775 | n04133789 776 | n04136333 777 | n04141076 778 | n04141327 779 | n04141975 780 | n04146614 781 | n04147183 782 | n04149813 783 | n04152593 784 | n04153751 785 | n04154565 786 | n04162706 787 | n04179913 788 | n04192698 789 | n04200800 790 | n04201297 791 | n04204238 792 | n04204347 793 | n04208210 794 | n04209133 795 | n04209239 796 | n04228054 797 | n04229816 798 | n04235860 799 | n04238763 800 | n04239074 801 | n04243546 802 | n04251144 803 | n04252077 804 | n04252225 805 | n04254120 806 | n04254680 807 | n04254777 808 | n04258138 809 | 
n04259630 810 | n04263257 811 | n04264628 812 | n04265275 813 | n04266014 814 | n04270147 815 | n04273569 816 | n04275548 817 | n04277352 818 | n04285008 819 | n04286575 820 | n04296562 821 | n04310018 822 | n04311004 823 | n04311174 824 | n04317175 825 | n04325704 826 | n04326547 827 | n04328186 828 | n04330267 829 | n04332243 830 | n04335435 831 | n04336792 832 | n04344873 833 | n04346328 834 | n04347754 835 | n04350905 836 | n04355338 837 | n04355933 838 | n04356056 839 | n04357314 840 | n04366367 841 | n04367480 842 | n04370456 843 | n04371430 844 | n04371774 845 | n04372370 846 | n04376876 847 | n04380533 848 | n04389033 849 | n04392985 850 | n04398044 851 | n04399382 852 | n04404412 853 | n04409515 854 | n04417672 855 | n04418357 856 | n04423845 857 | n04428191 858 | n04429376 859 | n04435653 860 | n04442312 861 | n04443257 862 | n04447861 863 | n04456115 864 | n04458633 865 | n04461696 866 | n04462240 867 | n04465501 868 | n04467665 869 | n04476259 870 | n04479046 871 | n04482393 872 | n04483307 873 | n04485082 874 | n04486054 875 | n04487081 876 | n04487394 877 | n04493381 878 | n04501370 879 | n04505470 880 | n04507155 881 | n04509417 882 | n04515003 883 | n04517823 884 | n04522168 885 | n04523525 886 | n04525038 887 | n04525305 888 | n04532106 889 | n04532670 890 | n04536866 891 | n04540053 892 | n04542943 893 | n04548280 894 | n04548362 895 | n04550184 896 | n04552348 897 | n04553703 898 | n04554684 899 | n04557648 900 | n04560804 901 | n04562935 902 | n04579145 903 | n04579432 904 | n04584207 905 | n04589890 906 | n04590129 907 | n04591157 908 | n04591713 909 | n04592741 910 | n04596742 911 | n04597913 912 | n04599235 913 | n04604644 914 | n04606251 915 | n04612504 916 | n04613696 917 | n06359193 918 | n06596364 919 | n06785654 920 | n06794110 921 | n06874185 922 | n07248320 923 | n07565083 924 | n07579787 925 | n07583066 926 | n07584110 927 | n07590611 928 | n07613480 929 | n07614500 930 | n07615774 931 | n07684084 932 | n07693725 933 | n07695742 934 | n07697313 935 | n07697537 936 | n07711569 937 | n07714571 938 | n07714990 939 | n07715103 940 | n07716358 941 | n07716906 942 | n07717410 943 | n07717556 944 | n07718472 945 | n07718747 946 | n07720875 947 | n07730033 948 | n07734744 949 | n07742313 950 | n07745940 951 | n07747607 952 | n07749582 953 | n07753113 954 | n07753275 955 | n07753592 956 | n07754684 957 | n07760859 958 | n07768694 959 | n07802026 960 | n07831146 961 | n07836838 962 | n07860988 963 | n07871810 964 | n07873807 965 | n07875152 966 | n07880968 967 | n07892512 968 | n07920052 969 | n07930864 970 | n07932039 971 | n09193705 972 | n09229709 973 | n09246464 974 | n09256479 975 | n09288635 976 | n09332890 977 | n09399592 978 | n09421951 979 | n09428293 980 | n09468604 981 | n09472597 982 | n09835506 983 | n10148035 984 | n10565667 985 | n11879895 986 | n11939491 987 | n12057211 988 | n12144580 989 | n12267677 990 | n12620546 991 | n12768682 992 | n12985857 993 | n12998815 994 | n13037406 995 | n13040303 996 | n13044778 997 | n13052670 998 | n13054560 999 | n13133613 1000 | n15075141 1001 | -------------------------------------------------------------------------------- /datasets/build_imagenet_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Converts ImageNet data to TFRecords file format with Example protos. 17 | 18 | The raw ImageNet data set is expected to reside in JPEG files located in the 19 | following directory structure. 20 | 21 | data_dir/n01440764/ILSVRC2012_val_00000293.JPEG 22 | data_dir/n01440764/ILSVRC2012_val_00000543.JPEG 23 | ... 24 | 25 | where 'n01440764' is the unique synset label associated with 26 | these images. 27 | 28 | The training data set consists of 1000 sub-directories (i.e. labels) 29 | each containing 1200 JPEG images for a total of 1.2M JPEG images. 30 | 31 | The evaluation data set consists of 1000 sub-directories (i.e. labels) 32 | each containing 50 JPEG images for a total of 50K JPEG images. 33 | 34 | This TensorFlow script converts the training and evaluation data into 35 | a sharded data set consisting of 1024 and 128 TFRecord files, respectively. 36 | 37 | train_directory/train-00000-of-01024 38 | train_directory/train-00001-of-01024 39 | ... 40 | train_directory/train-01023-of-01024 41 | 42 | and 43 | 44 | validation_directory/validation-00000-of-00128 45 | validation_directory/validation-00001-of-00128 46 | ... 47 | validation_directory/validation-00127-of-00128 48 | 49 | Each validation TFRecord file contains ~390 records. Each training TFREcord 50 | file contains ~1250 records. Each record within the TFRecord file is a 51 | serialized Example proto. The Example proto contains the following fields: 52 | 53 | image/encoded: string containing JPEG encoded image in RGB colorspace 54 | image/height: integer, image height in pixels 55 | image/width: integer, image width in pixels 56 | image/colorspace: string, specifying the colorspace, always 'RGB' 57 | image/channels: integer, specifying the number of channels, always 3 58 | image/format: string, specifying the format, always 'JPEG' 59 | 60 | image/filename: string containing the basename of the image file 61 | e.g. 'n01440764_10026.JPEG' or 'ILSVRC2012_val_00000293.JPEG' 62 | image/class/label: integer specifying the index in a classification layer. 63 | The label ranges from [1, 1000] where 0 is not used. 64 | image/class/synset: string specifying the unique ID of the label, 65 | e.g. 'n01440764' 66 | image/class/text: string specifying the human-readable version of the label 67 | e.g. 'red fox, Vulpes vulpes' 68 | 69 | image/object/bbox/xmin: list of integers specifying the 0+ human annotated 70 | bounding boxes 71 | image/object/bbox/xmax: list of integers specifying the 0+ human annotated 72 | bounding boxes 73 | image/object/bbox/ymin: list of integers specifying the 0+ human annotated 74 | bounding boxes 75 | image/object/bbox/ymax: list of integers specifying the 0+ human annotated 76 | bounding boxes 77 | image/object/bbox/label: integer specifying the index in a classification 78 | layer. The label ranges from [1, 1000] where 0 is not used. Note this is 79 | always identical to the image label. 
80 | 81 | Note that the length of xmin is identical to the length of xmax, ymin and ymax 82 | for each example. 83 | 84 | Running this script using 16 threads may take around ~2.5 hours on an HP Z420. 85 | """ 86 | from __future__ import absolute_import 87 | from __future__ import division 88 | from __future__ import print_function 89 | 90 | from datetime import datetime 91 | import os 92 | import random 93 | import sys 94 | import threading 95 | 96 | import numpy as np 97 | import six 98 | import tensorflow as tf 99 | 100 | tf.app.flags.DEFINE_string('train_directory', '/tmp/', 101 | 'Training data directory') 102 | tf.app.flags.DEFINE_string('validation_directory', '/tmp/', 103 | 'Validation data directory') 104 | tf.app.flags.DEFINE_string('output_directory', '/tmp/', 105 | 'Output data directory') 106 | 107 | tf.app.flags.DEFINE_integer('train_shards', 1024, 108 | 'Number of shards in training TFRecord files.') 109 | tf.app.flags.DEFINE_integer('validation_shards', 128, 110 | 'Number of shards in validation TFRecord files.') 111 | 112 | tf.app.flags.DEFINE_integer('num_threads', 32, 113 | 'Number of threads to preprocess the images.') 114 | 115 | # The labels file contains a list of valid labels are held in this file. 116 | # Assumes that the file contains entries as such: 117 | # n01440764 118 | # n01443537 119 | # n01484850 120 | # where each line corresponds to a label expressed as a synset. We map 121 | # each synset contained in the file to an integer (based on the alphabetical 122 | # ordering). See below for details. 123 | tf.app.flags.DEFINE_string('labels_file', 124 | 'imagenet_lsvrc_2015_synsets.txt', 125 | 'Labels file') 126 | 127 | # This file containing mapping from synset to human-readable label. 128 | # Assumes each line of the file looks like: 129 | # 130 | # n02119247 black fox 131 | # n02119359 silver fox 132 | # n02119477 red fox, Vulpes fulva 133 | # 134 | # where each line corresponds to a unique mapping. Note that each line is 135 | # formatted as \t. 136 | tf.app.flags.DEFINE_string('imagenet_metadata_file', 137 | 'imagenet_metadata.txt', 138 | 'ImageNet metadata file') 139 | 140 | # This file is the output of process_bounding_box.py 141 | # Assumes each line of the file looks like: 142 | # 143 | # n00007846_64193.JPEG,0.0060,0.2620,0.7545,0.9940 144 | # 145 | # where each line corresponds to one bounding box annotation associated 146 | # with an image. Each line can be parsed as: 147 | # 148 | # , , , , 149 | # 150 | # Note that there might exist mulitple bounding box annotations associated 151 | # with an image file. 
152 | tf.app.flags.DEFINE_string('bounding_box_file', 153 | './imagenet_2012_bounding_boxes.csv', 154 | 'Bounding box file') 155 | 156 | FLAGS = tf.app.flags.FLAGS 157 | 158 | 159 | def _int64_feature(value): 160 | """Wrapper for inserting int64 features into Example proto.""" 161 | if not isinstance(value, list): 162 | value = [value] 163 | return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) 164 | 165 | 166 | def _float_feature(value): 167 | """Wrapper for inserting float features into Example proto.""" 168 | if not isinstance(value, list): 169 | value = [value] 170 | return tf.train.Feature(float_list=tf.train.FloatList(value=value)) 171 | 172 | 173 | def _bytes_feature(value): 174 | """Wrapper for inserting bytes features into Example proto.""" 175 | if six.PY3 and isinstance(value, six.text_type): 176 | value = six.binary_type(value, encoding='utf-8') 177 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 178 | 179 | 180 | def _convert_to_example(filename, image_buffer, label, synset, human, bbox, 181 | height, width): 182 | """Build an Example proto for an example. 183 | 184 | Args: 185 | filename: string, path to an image file, e.g., '/path/to/example.JPG' 186 | image_buffer: string, JPEG encoding of RGB image 187 | label: integer, identifier for the ground truth for the network 188 | synset: string, unique WordNet ID specifying the label, e.g., 'n02323233' 189 | human: string, human-readable label, e.g., 'red fox, Vulpes vulpes' 190 | bbox: list of bounding boxes; each box is a list of integers 191 | specifying [xmin, ymin, xmax, ymax]. All boxes are assumed to belong to 192 | the same label as the image label. 193 | height: integer, image height in pixels 194 | width: integer, image width in pixels 195 | Returns: 196 | Example proto 197 | """ 198 | xmin = [] 199 | ymin = [] 200 | xmax = [] 201 | ymax = [] 202 | for b in bbox: 203 | assert len(b) == 4 204 | # pylint: disable=expression-not-assigned 205 | [l.append(point) for l, point in zip([xmin, ymin, xmax, ymax], b)] 206 | # pylint: enable=expression-not-assigned 207 | 208 | colorspace = 'RGB' 209 | channels = 3 210 | image_format = 'JPEG' 211 | 212 | example = tf.train.Example(features=tf.train.Features(feature={ 213 | 'image/height': _int64_feature(height), 214 | 'image/width': _int64_feature(width), 215 | 'image/colorspace': _bytes_feature(colorspace), 216 | 'image/channels': _int64_feature(channels), 217 | 'image/class/label': _int64_feature(label), 218 | 'image/class/synset': _bytes_feature(synset), 219 | 'image/class/text': _bytes_feature(human), 220 | 'image/object/bbox/xmin': _float_feature(xmin), 221 | 'image/object/bbox/xmax': _float_feature(xmax), 222 | 'image/object/bbox/ymin': _float_feature(ymin), 223 | 'image/object/bbox/ymax': _float_feature(ymax), 224 | 'image/object/bbox/label': _int64_feature([label] * len(xmin)), 225 | 'image/format': _bytes_feature(image_format), 226 | 'image/filename': _bytes_feature(os.path.basename(filename)), 227 | 'image/encoded': _bytes_feature(image_buffer)})) 228 | return example 229 | 230 | 231 | class ImageCoder(object): 232 | """Helper class that provides TensorFlow image coding utilities.""" 233 | 234 | def __init__(self): 235 | # Create a single Session to run all image coding calls. 236 | self._sess = tf.Session() 237 | 238 | # Initializes function that converts PNG to JPEG data. 
239 | self._png_data = tf.placeholder(dtype=tf.string) 240 | image = tf.image.decode_png(self._png_data, channels=3) 241 | self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100) 242 | 243 | # Initializes function that converts CMYK JPEG data to RGB JPEG data. 244 | self._cmyk_data = tf.placeholder(dtype=tf.string) 245 | image = tf.image.decode_jpeg(self._cmyk_data, channels=0) 246 | self._cmyk_to_rgb = tf.image.encode_jpeg(image, format='rgb', quality=100) 247 | 248 | # Initializes function that decodes RGB JPEG data. 249 | self._decode_jpeg_data = tf.placeholder(dtype=tf.string) 250 | self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3) 251 | 252 | def png_to_jpeg(self, image_data): 253 | return self._sess.run(self._png_to_jpeg, 254 | feed_dict={self._png_data: image_data}) 255 | 256 | def cmyk_to_rgb(self, image_data): 257 | return self._sess.run(self._cmyk_to_rgb, 258 | feed_dict={self._cmyk_data: image_data}) 259 | 260 | def decode_jpeg(self, image_data): 261 | image = self._sess.run(self._decode_jpeg, 262 | feed_dict={self._decode_jpeg_data: image_data}) 263 | assert len(image.shape) == 3 264 | assert image.shape[2] == 3 265 | return image 266 | 267 | 268 | def _is_png(filename): 269 | """Determine if a file contains a PNG format image. 270 | 271 | Args: 272 | filename: string, path of the image file. 273 | 274 | Returns: 275 | boolean indicating if the image is a PNG. 276 | """ 277 | # File list from: 278 | # https://groups.google.com/forum/embed/?place=forum/torch7#!topic/torch7/fOSTXHIESSU 279 | return 'n02105855_2933.JPEG' in filename 280 | 281 | 282 | def _is_cmyk(filename): 283 | """Determine if file contains a CMYK JPEG format image. 284 | 285 | Args: 286 | filename: string, path of the image file. 287 | 288 | Returns: 289 | boolean indicating if the image is a JPEG encoded with CMYK color space. 290 | """ 291 | # File list from: 292 | # https://github.com/cytsai/ilsvrc-cmyk-image-list 293 | blacklist = ['n01739381_1309.JPEG', 'n02077923_14822.JPEG', 294 | 'n02447366_23489.JPEG', 'n02492035_15739.JPEG', 295 | 'n02747177_10752.JPEG', 'n03018349_4028.JPEG', 296 | 'n03062245_4620.JPEG', 'n03347037_9675.JPEG', 297 | 'n03467068_12171.JPEG', 'n03529860_11437.JPEG', 298 | 'n03544143_17228.JPEG', 'n03633091_5218.JPEG', 299 | 'n03710637_5125.JPEG', 'n03961711_5286.JPEG', 300 | 'n04033995_2932.JPEG', 'n04258138_17003.JPEG', 301 | 'n04264628_27969.JPEG', 'n04336792_7448.JPEG', 302 | 'n04371774_5854.JPEG', 'n04596742_4225.JPEG', 303 | 'n07583066_647.JPEG', 'n13037406_4650.JPEG'] 304 | return filename.split('/')[-1] in blacklist 305 | 306 | 307 | def _process_image(filename, coder): 308 | """Process a single image file. 309 | 310 | Args: 311 | filename: string, path to an image file e.g., '/path/to/example.JPG'. 312 | coder: instance of ImageCoder to provide TensorFlow image coding utils. 313 | Returns: 314 | image_buffer: string, JPEG encoding of RGB image. 315 | height: integer, image height in pixels. 316 | width: integer, image width in pixels. 317 | """ 318 | # Read the image file. 319 | with tf.gfile.FastGFile(filename, 'rb') as f: 320 | image_data = f.read() 321 | 322 | # Clean the dirty data. 323 | if _is_png(filename): 324 | # 1 image is a PNG. 325 | print('Converting PNG to JPEG for %s' % filename) 326 | image_data = coder.png_to_jpeg(image_data) 327 | elif _is_cmyk(filename): 328 | # 22 JPEG images are in CMYK colorspace. 
329 | print('Converting CMYK to RGB for %s' % filename) 330 | image_data = coder.cmyk_to_rgb(image_data) 331 | 332 | # Decode the RGB JPEG. 333 | image = coder.decode_jpeg(image_data) 334 | 335 | # Check that image converted to RGB 336 | assert len(image.shape) == 3 337 | height = image.shape[0] 338 | width = image.shape[1] 339 | assert image.shape[2] == 3 340 | 341 | return image_data, height, width 342 | 343 | 344 | def _process_image_files_batch(coder, thread_index, ranges, name, filenames, 345 | synsets, labels, humans, bboxes, num_shards): 346 | """Processes and saves list of images as TFRecord in 1 thread. 347 | 348 | Args: 349 | coder: instance of ImageCoder to provide TensorFlow image coding utils. 350 | thread_index: integer, unique batch to run index is within [0, len(ranges)). 351 | ranges: list of pairs of integers specifying ranges of each batches to 352 | analyze in parallel. 353 | name: string, unique identifier specifying the data set 354 | filenames: list of strings; each string is a path to an image file 355 | synsets: list of strings; each string is a unique WordNet ID 356 | labels: list of integer; each integer identifies the ground truth 357 | humans: list of strings; each string is a human-readable label 358 | bboxes: list of bounding boxes for each image. Note that each entry in this 359 | list might contain from 0+ entries corresponding to the number of bounding 360 | box annotations for the image. 361 | num_shards: integer number of shards for this data set. 362 | """ 363 | # Each thread produces N shards where N = int(num_shards / num_threads). 364 | # For instance, if num_shards = 128, and the num_threads = 2, then the first 365 | # thread would produce shards [0, 64). 366 | num_threads = len(ranges) 367 | assert not num_shards % num_threads 368 | num_shards_per_batch = int(num_shards / num_threads) 369 | 370 | shard_ranges = np.linspace(ranges[thread_index][0], 371 | ranges[thread_index][1], 372 | num_shards_per_batch + 1).astype(int) 373 | num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0] 374 | 375 | counter = 0 376 | for s in range(num_shards_per_batch): 377 | # Generate a sharded version of the file name, e.g. 'train-00002-of-00010' 378 | shard = thread_index * num_shards_per_batch + s 379 | output_filename = '%s-%.5d-of-%.5d' % (name, shard, num_shards) 380 | output_file = os.path.join(FLAGS.output_directory, output_filename) 381 | writer = tf.python_io.TFRecordWriter(output_file) 382 | 383 | shard_counter = 0 384 | files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int) 385 | for i in files_in_shard: 386 | filename = filenames[i] 387 | label = labels[i] 388 | synset = synsets[i] 389 | human = humans[i] 390 | bbox = bboxes[i] 391 | 392 | image_buffer, height, width = _process_image(filename, coder) 393 | 394 | example = _convert_to_example(filename, image_buffer, label, 395 | synset, human, bbox, 396 | height, width) 397 | writer.write(example.SerializeToString()) 398 | shard_counter += 1 399 | counter += 1 400 | 401 | if not counter % 1000: 402 | print('%s [thread %d]: Processed %d of %d images in thread batch.' % 403 | (datetime.now(), thread_index, counter, num_files_in_thread)) 404 | sys.stdout.flush() 405 | 406 | writer.close() 407 | print('%s [thread %d]: Wrote %d images to %s' % 408 | (datetime.now(), thread_index, shard_counter, output_file)) 409 | sys.stdout.flush() 410 | shard_counter = 0 411 | print('%s [thread %d]: Wrote %d images to %d shards.' 
% 412 | (datetime.now(), thread_index, counter, num_files_in_thread)) 413 | sys.stdout.flush() 414 | 415 | 416 | def _process_image_files(name, filenames, synsets, labels, humans, 417 | bboxes, num_shards): 418 | """Process and save list of images as TFRecord of Example protos. 419 | 420 | Args: 421 | name: string, unique identifier specifying the data set 422 | filenames: list of strings; each string is a path to an image file 423 | synsets: list of strings; each string is a unique WordNet ID 424 | labels: list of integer; each integer identifies the ground truth 425 | humans: list of strings; each string is a human-readable label 426 | bboxes: list of bounding boxes for each image. Note that each entry in this 427 | list might contain from 0+ entries corresponding to the number of bounding 428 | box annotations for the image. 429 | num_shards: integer number of shards for this data set. 430 | """ 431 | assert len(filenames) == len(synsets) 432 | assert len(filenames) == len(labels) 433 | assert len(filenames) == len(humans) 434 | assert len(filenames) == len(bboxes) 435 | 436 | # Break all images into batches with a [ranges[i][0], ranges[i][1]]. 437 | spacing = np.linspace(0, len(filenames), FLAGS.num_threads + 1).astype(np.int) 438 | ranges = [] 439 | threads = [] 440 | for i in range(len(spacing) - 1): 441 | ranges.append([spacing[i], spacing[i + 1]]) 442 | 443 | # Launch a thread for each batch. 444 | print('Launching %d threads for spacings: %s' % (FLAGS.num_threads, ranges)) 445 | sys.stdout.flush() 446 | 447 | # Create a mechanism for monitoring when all threads are finished. 448 | coord = tf.train.Coordinator() 449 | 450 | # Create a generic TensorFlow-based utility for converting all image codings. 451 | coder = ImageCoder() 452 | 453 | threads = [] 454 | for thread_index in range(len(ranges)): 455 | args = (coder, thread_index, ranges, name, filenames, 456 | synsets, labels, humans, bboxes, num_shards) 457 | t = threading.Thread(target=_process_image_files_batch, args=args) 458 | t.start() 459 | threads.append(t) 460 | 461 | # Wait for all the threads to terminate. 462 | coord.join(threads) 463 | print('%s: Finished writing all %d images in data set.' % 464 | (datetime.now(), len(filenames))) 465 | sys.stdout.flush() 466 | 467 | 468 | def _find_image_files(data_dir, labels_file): 469 | """Build a list of all images files and labels in the data set. 470 | 471 | Args: 472 | data_dir: string, path to the root directory of images. 473 | 474 | Assumes that the ImageNet data set resides in JPEG files located in 475 | the following directory structure. 476 | 477 | data_dir/n01440764/ILSVRC2012_val_00000293.JPEG 478 | data_dir/n01440764/ILSVRC2012_val_00000543.JPEG 479 | 480 | where 'n01440764' is the unique synset label associated with these images. 481 | 482 | labels_file: string, path to the labels file. 483 | 484 | The list of valid labels are held in this file. Assumes that the file 485 | contains entries as such: 486 | n01440764 487 | n01443537 488 | n01484850 489 | where each line corresponds to a label expressed as a synset. We map 490 | each synset contained in the file to an integer (based on the alphabetical 491 | ordering) starting with the integer 1 corresponding to the synset 492 | contained in the first line. 493 | 494 | The reason we start the integer labels at 1 is to reserve label 0 as an 495 | unused background class. 496 | 497 | Returns: 498 | filenames: list of strings; each string is a path to an image file. 
499 | synsets: list of strings; each string is a unique WordNet ID. 500 | labels: list of integer; each integer identifies the ground truth. 501 | """ 502 | print('Determining list of input files and labels from %s.' % data_dir) 503 | challenge_synsets = [l.strip() for l in 504 | tf.gfile.FastGFile(labels_file, 'r').readlines()] 505 | 506 | labels = [] 507 | filenames = [] 508 | synsets = [] 509 | 510 | # Leave label index 0 empty as a background class. 511 | label_index = 1 512 | 513 | # Construct the list of JPEG files and labels. 514 | for synset in challenge_synsets: 515 | jpeg_file_path = '%s/%s/*.JPEG' % (data_dir, synset) 516 | matching_files = tf.gfile.Glob(jpeg_file_path) 517 | 518 | labels.extend([label_index] * len(matching_files)) 519 | synsets.extend([synset] * len(matching_files)) 520 | filenames.extend(matching_files) 521 | 522 | if not label_index % 100: 523 | print('Finished finding files in %d of %d classes.' % ( 524 | label_index, len(challenge_synsets))) 525 | label_index += 1 526 | 527 | # Shuffle the ordering of all image files in order to guarantee 528 | # random ordering of the images with respect to label in the 529 | # saved TFRecord files. Make the randomization repeatable. 530 | shuffled_index = list(range(len(filenames))) 531 | random.seed(12345) 532 | random.shuffle(shuffled_index) 533 | 534 | filenames = [filenames[i] for i in shuffled_index] 535 | synsets = [synsets[i] for i in shuffled_index] 536 | labels = [labels[i] for i in shuffled_index] 537 | 538 | print('Found %d JPEG files across %d labels inside %s.' % 539 | (len(filenames), len(challenge_synsets), data_dir)) 540 | return filenames, synsets, labels 541 | 542 | 543 | def _find_human_readable_labels(synsets, synset_to_human): 544 | """Build a list of human-readable labels. 545 | 546 | Args: 547 | synsets: list of strings; each string is a unique WordNet ID. 548 | synset_to_human: dict of synset to human labels, e.g., 549 | 'n02119022' --> 'red fox, Vulpes vulpes' 550 | 551 | Returns: 552 | List of human-readable strings corresponding to each synset. 553 | """ 554 | humans = [] 555 | for s in synsets: 556 | assert s in synset_to_human, ('Failed to find: %s' % s) 557 | humans.append(synset_to_human[s]) 558 | return humans 559 | 560 | 561 | def _find_image_bounding_boxes(filenames, image_to_bboxes): 562 | """Find the bounding boxes for a given image file. 563 | 564 | Args: 565 | filenames: list of strings; each string is a path to an image file. 566 | image_to_bboxes: dictionary mapping image file names to a list of 567 | bounding boxes. This list contains 0+ bounding boxes. 568 | Returns: 569 | List of bounding boxes for each image. Note that each entry in this 570 | list might contain from 0+ entries corresponding to the number of bounding 571 | box annotations for the image. 572 | """ 573 | num_image_bbox = 0 574 | bboxes = [] 575 | for f in filenames: 576 | basename = os.path.basename(f) 577 | if basename in image_to_bboxes: 578 | bboxes.append(image_to_bboxes[basename]) 579 | num_image_bbox += 1 580 | else: 581 | bboxes.append([]) 582 | print('Found %d images with bboxes out of %d images' % ( 583 | num_image_bbox, len(filenames))) 584 | return bboxes 585 | 586 | 587 | def _process_dataset(name, directory, num_shards, synset_to_human, 588 | image_to_bboxes): 589 | """Process a complete data set and save it as a TFRecord. 590 | 591 | Args: 592 | name: string, unique identifier specifying the data set. 593 | directory: string, root path to the data set. 
594 | num_shards: integer number of shards for this data set. 595 | synset_to_human: dict of synset to human labels, e.g., 596 | 'n02119022' --> 'red fox, Vulpes vulpes' 597 | image_to_bboxes: dictionary mapping image file names to a list of 598 | bounding boxes. This list contains 0+ bounding boxes. 599 | """ 600 | filenames, synsets, labels = _find_image_files(directory, FLAGS.labels_file) 601 | humans = _find_human_readable_labels(synsets, synset_to_human) 602 | bboxes = _find_image_bounding_boxes(filenames, image_to_bboxes) 603 | _process_image_files(name, filenames, synsets, labels, 604 | humans, bboxes, num_shards) 605 | 606 | 607 | def _build_synset_lookup(imagenet_metadata_file): 608 | """Build lookup for synset to human-readable label. 609 | 610 | Args: 611 | imagenet_metadata_file: string, path to file containing mapping from 612 | synset to human-readable label. 613 | 614 | Assumes each line of the file looks like: 615 | 616 | n02119247 black fox 617 | n02119359 silver fox 618 | n02119477 red fox, Vulpes fulva 619 | 620 | where each line corresponds to a unique mapping. Note that each line is 621 | formatted as <synset>\t<human readable label>. 622 | 623 | Returns: 624 | Dictionary of synset to human labels, such as: 625 | 'n02119022' --> 'red fox, Vulpes vulpes' 626 | """ 627 | lines = tf.gfile.FastGFile(imagenet_metadata_file, 'r').readlines() 628 | synset_to_human = {} 629 | for l in lines: 630 | if l: 631 | parts = l.strip().split('\t') 632 | assert len(parts) == 2 633 | synset = parts[0] 634 | human = parts[1] 635 | synset_to_human[synset] = human 636 | return synset_to_human 637 | 638 | 639 | def _build_bounding_box_lookup(bounding_box_file): 640 | """Build a lookup from image file to bounding boxes. 641 | 642 | Args: 643 | bounding_box_file: string, path to file with bounding box annotations. 644 | 645 | Assumes each line of the file looks like: 646 | 647 | n00007846_64193.JPEG,0.0060,0.2620,0.7545,0.9940 648 | 649 | where each line corresponds to one bounding box annotation associated 650 | with an image. Each line can be parsed as: 651 | 652 | <JPEG file name>, <xmin>, <ymin>, <xmax>, <ymax> 653 | 654 | Note that there might exist multiple bounding box annotations associated 655 | with an image file. This file is the output of process_bounding_boxes.py. 656 | 657 | Returns: 658 | Dictionary mapping image file names to a list of bounding boxes. This list 659 | contains 0+ bounding boxes. 660 | """ 661 | lines = tf.gfile.FastGFile(bounding_box_file, 'r').readlines() 662 | images_to_bboxes = {} 663 | num_bbox = 0 664 | num_image = 0 665 | for l in lines: 666 | if l: 667 | parts = l.split(',') 668 | assert len(parts) == 5, ('Failed to parse: %s' % l) 669 | filename = parts[0] 670 | xmin = float(parts[1]) 671 | ymin = float(parts[2]) 672 | xmax = float(parts[3]) 673 | ymax = float(parts[4]) 674 | box = [xmin, ymin, xmax, ymax] 675 | 676 | if filename not in images_to_bboxes: 677 | images_to_bboxes[filename] = [] 678 | num_image += 1 679 | images_to_bboxes[filename].append(box) 680 | num_bbox += 1 681 | 682 | print('Successfully read %d bounding boxes ' 683 | 'across %d images.'
% (num_bbox, num_image)) 684 | return images_to_bboxes 685 | 686 | 687 | def main(unused_argv): 688 | assert not FLAGS.train_shards % FLAGS.num_threads, ( 689 | 'Please make the FLAGS.num_threads commensurate with FLAGS.train_shards') 690 | assert not FLAGS.validation_shards % FLAGS.num_threads, ( 691 | 'Please make the FLAGS.num_threads commensurate with ' 692 | 'FLAGS.validation_shards') 693 | print('Saving results to %s' % FLAGS.output_directory) 694 | 695 | # Build a map from synset to human-readable label. 696 | synset_to_human = _build_synset_lookup(FLAGS.imagenet_metadata_file) 697 | image_to_bboxes = _build_bounding_box_lookup(FLAGS.bounding_box_file) 698 | 699 | # Run it! 700 | _process_dataset('validation', FLAGS.validation_directory, 701 | FLAGS.validation_shards, synset_to_human, image_to_bboxes) 702 | _process_dataset('train', FLAGS.train_directory, FLAGS.train_shards, 703 | synset_to_human, image_to_bboxes) 704 | 705 | 706 | if __name__ == '__main__': 707 | tf.app.run() --------------------------------------------------------------------------------
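Example invocation (a minimal sketch): build_imagenet_data.py is driven entirely by the flags defined above, and main() asserts that --num_threads evenly divides both --train_shards and --validation_shards. The directory layout under DATA_DIR and the metadata/CSV locations below are illustrative assumptions, not defaults of this script.

    python datasets/build_imagenet_data.py \
      --train_directory="${DATA_DIR}/train" \
      --validation_directory="${DATA_DIR}/validation" \
      --output_directory="${DATA_DIR}/tf-records" \
      --labels_file=datasets/imagenet_lsvrc_2015_synsets.txt \
      --imagenet_metadata_file="${DATA_DIR}/imagenet_metadata.txt" \
      --bounding_box_file="${DATA_DIR}/imagenet_2012_bounding_boxes.csv" \
      --train_shards=1024 \
      --validation_shards=128 \
      --num_threads=16

With 16 threads, both shard counts divide evenly (1024/16 and 128/16), and the module docstring's estimate of roughly 2.5 hours for a full run applies.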