├── __init__.py
├── nets
│   ├── __init__.py
│   ├── .DS_Store
│   ├── inception.py
│   ├── nets_factory.py
│   ├── inception_utils.py
│   ├── inception_v1.py
│   └── mobilenet_v1.py
├── datasets
│   ├── __init__.py
│   ├── .DS_Store
│   ├── dataset_factory.py
│   ├── dataset_utils.py
│   ├── utils.py
│   ├── format_market_train.py
│   ├── make_filename_list.py
│   ├── convert_to_tfrecords.py
│   └── reid.py
├── deployment
│   ├── __init__.py
│   ├── .DS_Store
│   └── model_deploy.py
├── preprocessing
│   ├── __init__.py
│   ├── .DS_Store
│   ├── preprocessing_factory.py
│   ├── reid_preprocessing.py
│   └── inception_preprocessing.py
├── DML.png
├── .DS_Store
├── scripts
│   ├── .DS_Store
│   ├── format_and_convert_market.sh
│   ├── evaluate_ind_mobilenet_on_market.sh
│   ├── evaluate_dml_mobilenet_on_market.sh
│   ├── train_ind_mobilenet_on_market.sh
│   └── train_dml_mobilenet_on_market.sh
├── .idea
│   ├── markdown-navigator
│   │   └── profiles_settings.xml
│   ├── vcs.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── Deep-Mutual-Learning.iml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── markdown-navigator.xml
│   └── workspace.xml
├── LICENSE
├── format_and_convert_data.py
├── README.md
├── eval_image_classifier.py
├── train_image_classifier.py
├── eval_models.py
└── train_models.py
/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /deployment/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /DML.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/DML.png -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/.DS_Store -------------------------------------------------------------------------------- /nets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/nets/.DS_Store -------------------------------------------------------------------------------- /scripts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/scripts/.DS_Store -------------------------------------------------------------------------------- /datasets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/datasets/.DS_Store -------------------------------------------------------------------------------- /deployment/.DS_Store: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/deployment/.DS_Store -------------------------------------------------------------------------------- /preprocessing/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/preprocessing/.DS_Store -------------------------------------------------------------------------------- /.idea/markdown-navigator/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/Deep-Mutual-Learning.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 11 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 13 | -------------------------------------------------------------------------------- /scripts/format_and_convert_market.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script performs the following operations: 4 | # 1. Format the Market-1501 training images with consecutive labels. 5 | # 2. Convert the Market-1501 images into TFRecords. 6 | # 7 | # Usage: 8 | # cd Deep-Mutual-Learning 9 | # ./scripts/format_and_convert_market.sh 10 | 11 | 12 | # Where the Market-1501 images are saved to. 13 | IMAGE_DIR=/path/to/Market-1501/images 14 | 15 | # Where the TFRecord data will be saved to. 16 | TF_DIR=/path/to/market-1501/tfrecords 17 | 18 | 19 | echo "Building the TFRecords of market1501..." 20 | 21 | for split in bounding_box_train bounding_box_test gt_bbox query; do 22 | echo "Processing ${split} ..." 23 | python format_and_convert_data.py \ 24 | --image_dir="$IMAGE_DIR/$split" \ 25 | --output_dir=${TF_DIR} \ 26 | --dataset_name="market1501" \ 27 | --split_name="$split" 28 | 29 | done 30 | 31 | echo "Finished converting all the splits!" 
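# After this loop, ${TF_DIR} is expected to contain two files per split:
#   <split>.txt       (image filename + label list written by make_filename_list.py)
#   <split>.tfrecord  (PNG-encoded 160x64 images written by convert_to_tfrecords.py)
# Existing .tfrecord files are left untouched; delete them to force re-conversion.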
32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 YingZhangDUT 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_v1 import inception_v1 23 | from nets.inception_v1 import inception_v1_arg_scope 24 | from nets.inception_v1 import inception_v1_base 25 | # pylint: enable=unused-import 26 | -------------------------------------------------------------------------------- /scripts/evaluate_ind_mobilenet_on_market.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script performs the following operations: 4 | # Evaluate the MobileNet trained independently on Market-1501 5 | # 6 | # Usage: 7 | # cd Deep-Mutual-Learning 8 | # ./scripts/evaluate_ind_mobilenet_on_market.sh 9 | 10 | 11 | # Where the TFRecords are saved to. 12 | DATASET_DIR=/path/to/market-1501/tfrecords 13 | 14 | # Where the checkpoints are saved to. 15 | DATASET_NAME=market1501 16 | SAVE_NAME=market1501_ind_mobilenet 17 | CKPT_DIR=${SAVE_NAME}/checkpoint 18 | 19 | # Where the results will be saved to. 
20 | RESULT_DIR=${SAVE_NAME}/results 21 | 22 | # Model setting 23 | MODEL_NAME=mobilenet_v1 24 | 25 | # Run evaluation. 26 | for split in query bounding_box_test gt_bbox; do 27 | python eval_image_classifier.py \ 28 | --dataset_name=${DATASET_NAME}\ 29 | --split_name="$split" \ 30 | --dataset_dir=${DATASET_DIR} \ 31 | --checkpoint_dir=${CKPT_DIR} \ 32 | --eval_dir=${RESULT_DIR} \ 33 | --model_name=${MODEL_NAME} \ 34 | --preprocessing_name=reid \ 35 | --num_classes=751 \ 36 | --batch_size=1 \ 37 | --num_networks=1 38 | done 39 | -------------------------------------------------------------------------------- /scripts/evaluate_dml_mobilenet_on_market.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script performs the following operations: 4 | # Evaluate the MobileNets trained with DML on Market-1501 5 | # 6 | # Usage: 7 | # cd Deep-Mutual-Learning 8 | # ./scripts/evaluate_dml_mobilenet_on_market.sh 9 | 10 | 11 | # Where the TFRecords are saved to. 12 | DATASET_DIR=/path/to/market-1501/tfrecords 13 | 14 | # Where the checkpoints are saved to. 15 | DATASET_NAME=market1501 16 | SAVE_NAME=market1501_dml_mobilenet 17 | CKPT_DIR=${SAVE_NAME}/checkpoint 18 | 19 | # Where the results will be saved to. 20 | RESULT_DIR=${SAVE_NAME}/results 21 | 22 | # Model setting 23 | MODEL_NAME=mobilenet_v1,mobilenet_v1 24 | 25 | # Run evaluation. 26 | for split in query bounding_box_test gt_bbox; do 27 | python eval_image_classifier.py \ 28 | --dataset_name=${DATASET_NAME}\ 29 | --split_name="$split" \ 30 | --dataset_dir=${DATASET_DIR} \ 31 | --checkpoint_dir=${CKPT_DIR} \ 32 | --eval_dir=${RESULT_DIR} \ 33 | --model_name=${MODEL_NAME} \ 34 | --preprocessing_name=reid \ 35 | --num_classes=751 \ 36 | --batch_size=1 \ 37 | --num_networks=2 38 | done 39 | -------------------------------------------------------------------------------- /scripts/train_ind_mobilenet_on_market.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script performs the following operations: 4 | # Training 1 MobileNet independently on Market-1501 5 | # 6 | # Usage: 7 | # cd Deep-Mutual-Learning 8 | # ./scripts/train_ind_mobilenet_on_market.sh 9 | 10 | 11 | # Where the TFRecords are saved to. 12 | DATASET_DIR=/path/to/market-1501/tfrecords 13 | 14 | # Where the checkpoint and logs will be saved to. 15 | DATASET_NAME=market1501 16 | SAVE_NAME=market1501_ind_mobilenet 17 | CKPT_DIR=${SAVE_NAME}/checkpoint 18 | LOG_DIR=${SAVE_NAME}/logs 19 | 20 | # Model setting 21 | MODEL_NAME=mobilenet_v1 22 | SPLIT_NAME=bounding_box_train 23 | 24 | # Run training. 
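# Note: this independent baseline differs from scripts/train_dml_mobilenet_on_market.sh
# only in SAVE_NAME, MODEL_NAME (a single mobilenet_v1 instead of two) and
# --num_networks=1; all other hyper-parameters are identical, so the two runs are
# directly comparable.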
25 | python train_image_classifier.py \ 26 | --dataset_name=${DATASET_NAME}\ 27 | --split_name=${SPLIT_NAME} \ 28 | --dataset_dir=${DATASET_DIR} \ 29 | --checkpoint_dir=${CKPT_DIR} \ 30 | --log_dir=${LOG_DIR} \ 31 | --model_name=${MODEL_NAME} \ 32 | --preprocessing_name=reid \ 33 | --max_number_of_steps=200000 \ 34 | --ckpt_steps=5000 \ 35 | --batch_size=16 \ 36 | --num_classes=751 \ 37 | --optimizer=adam \ 38 | --learning_rate=0.0002 \ 39 | --adam_beta1=0.5 \ 40 | --opt_epsilon=1e-8 \ 41 | --label_smoothing=0.1 \ 42 | --num_networks=1 43 | 44 | -------------------------------------------------------------------------------- /scripts/train_dml_mobilenet_on_market.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # This script performs the following operations: 4 | # Training 2 MobileNets with DML on Market-1501 5 | # 6 | # Usage: 7 | # cd Deep-Mutual-Learning 8 | # ./scripts/train_dml_mobilenet_on_market.sh 9 | 10 | 11 | # Where the TFRecords are saved to. 12 | DATASET_DIR=/path/to/market-1501/tfrecords 13 | 14 | # Where the checkpoint and logs will be saved to. 15 | DATASET_NAME=market1501 16 | SAVE_NAME=market1501_dml_mobilenet 17 | CKPT_DIR=${SAVE_NAME}/checkpoint 18 | LOG_DIR=${SAVE_NAME}/logs 19 | 20 | # Model setting 21 | MODEL_NAME=mobilenet_v1,mobilenet_v1 22 | SPLIT_NAME=bounding_box_train 23 | 24 | # Run training. 25 | python train_image_classifier.py \ 26 | --dataset_name=${DATASET_NAME}\ 27 | --split_name=${SPLIT_NAME} \ 28 | --dataset_dir=${DATASET_DIR} \ 29 | --checkpoint_dir=${CKPT_DIR} \ 30 | --log_dir=${LOG_DIR} \ 31 | --model_name=${MODEL_NAME} \ 32 | --preprocessing_name=reid \ 33 | --max_number_of_steps=200000 \ 34 | --ckpt_steps=5000 \ 35 | --batch_size=16 \ 36 | --num_classes=751 \ 37 | --optimizer=adam \ 38 | --learning_rate=0.0002 \ 39 | --adam_beta1=0.5 \ 40 | --opt_epsilon=1e-8 \ 41 | --label_smoothing=0.1 \ 42 | --num_networks=2 43 | 44 | -------------------------------------------------------------------------------- /datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provide dataset given split name. 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | from datasets import reid 10 | 11 | # provider functions might vary on different datasets 12 | datasets_map = { 13 | 'market1501': reid, 14 | } 15 | 16 | 17 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None): 18 | """Given a dataset name and a split_name returns a Dataset. 19 | 20 | Args: 21 | name: String, the name of the dataset. 22 | split_name: A train/test split name. 23 | dataset_dir: The directory where the dataset files are stored. 24 | file_pattern: The file pattern to use for matching the dataset source files. 25 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default 26 | reader defined by each dataset is used. 27 | 28 | Returns: 29 | A `Dataset` class. 30 | 31 | Raises: 32 | ValueError: If the dataset `name` is unknown. 
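    Example:
      dataset = get_dataset('market1501', 'query', '/path/to/market-1501/tfrecords')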
33 | """ 34 | if name not in datasets_map: 35 | raise ValueError('Name of dataset unknown %s' % name) 36 | return datasets_map[name].get_split( 37 | split_name, 38 | dataset_dir, 39 | file_pattern, 40 | reader) 41 | -------------------------------------------------------------------------------- /format_and_convert_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Format Market-1501 training images and convert all the splits into TFRecords 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import tensorflow as tf 10 | from datasets import convert_to_tfrecords 11 | from datasets import format_market_train 12 | from datasets import make_filename_list 13 | from datasets.utils import * 14 | 15 | FLAGS = tf.app.flags.FLAGS 16 | tf.app.flags.DEFINE_string('image_dir', None, 'path to the raw images') 17 | tf.app.flags.DEFINE_string('output_dir', None, 'path to the list and tfrecords ') 18 | tf.app.flags.DEFINE_string('split_name', None, 'split name') 19 | 20 | 21 | def main(_): 22 | 23 | mkdir_if_missing(FLAGS.output_dir) 24 | 25 | if FLAGS.split_name == 'bounding_box_train': 26 | format_market_train.run(image_dir=FLAGS.image_dir) 27 | 28 | make_filename_list.run(image_dir=FLAGS.image_dir, 29 | output_dir=FLAGS.output_dir, 30 | split_name=FLAGS.split_name) 31 | 32 | convert_to_tfrecords.run(image_dir=FLAGS.image_dir, 33 | output_dir=FLAGS.output_dir, 34 | split_name=FLAGS.split_name) 35 | 36 | 37 | if __name__ == '__main__': 38 | tf.app.run() 39 | 40 | -------------------------------------------------------------------------------- /datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains utilities for converting datasets. 3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import tensorflow as tf 9 | 10 | 11 | def int64_feature(values): 12 | """Returns a TF-Feature of int64s. 13 | 14 | Args: 15 | values: A scalar or list of values. 16 | 17 | Returns: 18 | a TF-Feature. 19 | """ 20 | if not isinstance(values, (tuple, list)): 21 | values = [values] 22 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 23 | 24 | 25 | def bytes_feature(values): 26 | """Returns a TF-Feature of bytes. 27 | 28 | Args: 29 | values: A string. 30 | 31 | Returns: 32 | a TF-Feature. 33 | """ 34 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) 35 | 36 | 37 | def image_to_tfexample(image_data, class_id, filename, height, width, image_format): 38 | return tf.train.Example(features=tf.train.Features(feature={ 39 | 'image/encoded': bytes_feature(image_data), 40 | 'image/label': int64_feature(class_id), 41 | 'image/filename': bytes_feature(filename), 42 | 'image/height': int64_feature(height), 43 | 'image/width': int64_feature(width), 44 | 'image/format': bytes_feature(image_format), 45 | })) 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep-Mutual-Learning 2 | 3 | TensorFlow implementation of **[Deep Mutual Learning](https://drive.google.com/file/d/1Deg9xXqPKAlxRgmWbggavftTvJPqJeyp/view)** accepted by CVPR 2018. 
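For orientation, the snippet below is a minimal, illustrative TensorFlow sketch of the two-network mutual-learning objective described in the paper (supervised cross-entropy plus a KL mimicry term for each network). The function and tensor names are assumptions for this example and are not the exact code in `train_models.py`.

```
import tensorflow as tf

def dml_losses(logits_a, logits_b, onehot_labels, label_smoothing=0.1):
    """Two-network DML objective: cross-entropy plus a KL mimicry term each."""
    probs_a = tf.nn.softmax(logits_a)
    probs_b = tf.nn.softmax(logits_b)
    # Supervised cross-entropy for each network (label smoothing as in the scripts).
    ce_a = tf.losses.softmax_cross_entropy(onehot_labels, logits_a,
                                           label_smoothing=label_smoothing)
    ce_b = tf.losses.softmax_cross_entropy(onehot_labels, logits_b,
                                           label_smoothing=label_smoothing)
    # Mimicry terms: KL(p_b || p_a) guides network A, KL(p_a || p_b) guides network B.
    kl_b_to_a = tf.reduce_mean(tf.reduce_sum(
        probs_b * (tf.log(probs_b + 1e-8) - tf.log(probs_a + 1e-8)), axis=1))
    kl_a_to_b = tf.reduce_mean(tf.reduce_sum(
        probs_a * (tf.log(probs_a + 1e-8) - tf.log(probs_b + 1e-8)), axis=1))
    # In DML each combined loss is minimised only w.r.t. its own network's
    # variables (e.g. via the optimizer's var_list), not the peer's.
    return ce_a + kl_b_to_a, ce_b + kl_a_to_b
```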
4 | 5 | 6 | ## Introduction 7 | Deep mutual learning provides a simple but effective way to improve the generalisation ability of a network by training collaboratively with a cohort of other networks. 8 | 9 | ![DML](DML.png "Deep Mutual Learning") 10 | 11 | ## Requirements 12 | 13 | 1. TensorFlow 1.3.1 14 | 2. CUDA 8.0 and cuDNN 6.0 15 | 3. Matlab 16 | 17 | ## Usage 18 | 19 | ### Data Preparation 20 | 1. Please download the [Market-1501 Dataset](http://www.liangzheng.com.cn/Project/project_reid.html) 21 | 22 | 2. Convert the image data into TFRecords 23 | ``` 24 | sh scripts/format_and_convert_market.sh 25 | ``` 26 | 27 | ### Training 28 | 1. Train MobileNets with DML 29 | ``` 30 | sh scripts/train_dml_mobilenet_on_market.sh 31 | ``` 32 | 33 | 2. Train MobileNet independently 34 | ``` 35 | sh scripts/train_ind_mobilenet_on_market.sh 36 | ``` 37 | 38 | ### Testing 39 | 1. Extract features of the test images 40 | ``` 41 | sh scripts/evaluate_dml_mobilenet_on_market.sh 42 | ``` 43 | 44 | 2. Evaluate the performance with Matlab [code](https://github.com/zhunzhong07/person-re-ranking/tree/master/evaluation) 45 | 46 | 47 | ## Citation 48 | If you find DML useful in your research, please kindly cite our paper: 49 | 50 | ``` 51 | @inproceedings{ying2018DML, 52 | author = {Ying Zhang and Tao Xiang and Timothy M. Hospedales and Huchuan Lu}, 53 | title = {Deep Mutual Learning}, 54 | booktitle = {CVPR}, 55 | year = {2018}} 56 | ``` -------------------------------------------------------------------------------- /preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains a factory for building image preprocessing functions. 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import tensorflow as tf 10 | 11 | from preprocessing import inception_preprocessing 12 | from preprocessing import reid_preprocessing 13 | 14 | slim = tf.contrib.slim 15 | 16 | 17 | def get_preprocessing(name, is_training=False): 18 | """Returns preprocessing_fn(image, height, width, **kwargs). 19 | 20 | Args: 21 | name: The name of the preprocessing function. 22 | is_training: `True` if the model is being used for training and `False` 23 | otherwise. 24 | 25 | Returns: 26 | preprocessing_fn: A function that preprocesses a single image (pre-batch). 27 | It has the following signature: 28 | image = preprocessing_fn(image, output_height, output_width, ...). 29 | 30 | Raises: 31 | ValueError: If Preprocessing `name` is not recognized. 32 | """ 33 | preprocessing_fn_map = { 34 | 'inception_v1': inception_preprocessing, 35 | 'mobilenet_v1': inception_preprocessing, 36 | 'reid': reid_preprocessing, 37 | } 38 | 39 | if name not in preprocessing_fn_map: 40 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 41 | 42 | def preprocessing_fn(image, output_height, output_width, **kwargs): 43 | return preprocessing_fn_map[name].preprocess_image( 44 | image, output_height, output_width, is_training=is_training, **kwargs) 45 | 46 | return preprocessing_fn 47 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Contains utilities for general usage.
3 | """ 4 | 5 | import os 6 | import os.path as osp 7 | import json 8 | import codecs 9 | import cPickle 10 | 11 | 12 | def mkdir_if_missing(d): 13 | if not osp.isdir(d): 14 | os.makedirs(d) 15 | 16 | 17 | def pickle(data, file_path): 18 | with open(file_path, 'wb') as f: 19 | cPickle.dump(data, f, cPickle.HIGHEST_PROTOCOL) 20 | 21 | 22 | def unpickle(file_path): 23 | with open(file_path, 'rb') as f: 24 | data = cPickle.load(f) 25 | return data 26 | 27 | 28 | def read_list(file_path, coding=None): 29 | if coding is None: 30 | with open(file_path, 'r') as f: 31 | arr = [line.strip() for line in f.readlines()] 32 | else: 33 | with codecs.open(file_path, 'r', coding) as f: 34 | arr = [line.strip() for line in f.readlines()] 35 | return arr 36 | 37 | 38 | def write_list(arr, file_path, coding=None): 39 | if coding is None: 40 | arr = ['{}'.format(item) for item in arr] 41 | with open(file_path, 'w') as f: 42 | f.write('\n'.join(arr)) 43 | else: 44 | with codecs.open(file_path, 'w', coding) as f: 45 | f.write(u'\n'.join(arr)) 46 | 47 | 48 | def read_kv(file_path, coding=None): 49 | arr = read_list(file_path, coding) 50 | if len(arr) == 0: 51 | return [], [] 52 | return zip(*map(str.split, arr)) 53 | 54 | 55 | def write_kv(k, v, file_path, coding=None): 56 | arr = zip(k, v) 57 | arr = [' '.join(item) for item in arr] 58 | write_list(arr, file_path, coding) 59 | 60 | 61 | def read_json(file_path): 62 | with open(file_path, 'r') as f: 63 | obj = json.load(f) 64 | return obj 65 | 66 | 67 | def write_json(obj, file_path): 68 | with open(file_path, 'w') as f: 69 | json.dump(obj, f, indent=4, separators=(',', ': ')) 70 | -------------------------------------------------------------------------------- /datasets/format_market_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Format Market-1501 training images with consecutive labels. 3 | 4 | This code modifies the data preparation method of 5 | "Learning Deep Feature Representations with Domain Guided Dropout for Person Re-identification". 6 | 7 | """ 8 | 9 | import shutil 10 | from glob import glob 11 | from datasets.utils import * 12 | 13 | 14 | def _format_train_data(in_dir, output_dir): 15 | # cam_0 to cam_5 16 | for i in xrange(6): 17 | mkdir_if_missing(osp.join(output_dir, 'cam_' + str(i))) 18 | # pdb.set_trace() 19 | images = glob(osp.join(in_dir, '*.jpg')) 20 | images.sort() 21 | identities = [] 22 | prev_pid = -1 23 | for name in images: 24 | name = osp.basename(name) 25 | p_id = int(name[0:4]) 26 | c_id = int(name[6]) - 1 27 | if prev_pid != p_id: 28 | identities.append([]) 29 | prev_cid = -1 30 | p_images = identities[-1] 31 | if prev_cid != c_id: 32 | p_images.append([]) 33 | v_images = p_images[-1] 34 | file_name = 'cam_{}/cam_{:02d}_{:05d}_{:05d}.jpg'.format(c_id, c_id, len(identities)-1, len(v_images)) 35 | shutil.copy(osp.join(in_dir, name), 36 | osp.join(output_dir, file_name)) 37 | v_images.append(file_name) 38 | prev_pid = p_id 39 | prev_cid = c_id 40 | # Save meta information into a json file 41 | meta = {'name': 'market1501', 'shot': 'multiple', 'num_cameras': 6} 42 | meta['identities'] = identities 43 | write_json(meta, osp.join(output_dir, 'meta.json')) 44 | num_images = len(images) 45 | num_classes = len(identities) 46 | print("Training data has %d images of %d classes" % (num_images, num_classes)) 47 | 48 | 49 | def run(image_dir): 50 | """Format the datasets with consecutive labels. 51 | 52 | Args: 53 | image_dir: The dataset directory where the raw images are stored. 
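        Note: the original folder is first renamed to `<image_dir>_raw`, and the
        relabelled copies (together with a `meta.json` index) are written back
        into a newly created `image_dir`.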
54 | 55 | """ 56 | in_dir = image_dir + "_raw" 57 | os.rename(image_dir, in_dir) 58 | mkdir_if_missing(image_dir) 59 | _format_train_data(in_dir, image_dir) 60 | -------------------------------------------------------------------------------- /datasets/make_filename_list.py: -------------------------------------------------------------------------------- 1 | """ 2 | Make a list for image files. This provides guidance for checking the data preparation. 3 | """ 4 | import numpy as np 5 | from glob import glob 6 | from utils import * 7 | 8 | 9 | def _save(file_label_list, file_path): 10 | content = ['{} {}'.format(x, y) for x, y in file_label_list] 11 | write_list(content, file_path) 12 | 13 | 14 | def _get_train_list(files): 15 | ret = [] 16 | for views in files: 17 | for v in views: 18 | for f in v: 19 | # camID = int(osp.basename(f)[4:6]) 20 | label = int(osp.basename(f)[7:12]) 21 | ret.append((f, label)) 22 | return np.asarray(ret) 23 | 24 | 25 | def _make_train_list(image_dir, output_dir, split_name): 26 | meta = read_json(osp.join(image_dir, 'meta.json')) 27 | identities = np.asarray(meta['identities']) 28 | images = _get_train_list(identities) 29 | _save(images, os.path.join(output_dir, '%s.txt' % split_name)) 30 | 31 | 32 | def _get_test_list(files): 33 | ret = [] 34 | for f in files: 35 | if osp.basename(f)[:2] == '-1': 36 | # camID = int(osp.basename(f)[4]) - 1 37 | label = int(osp.basename(f)[:2]) 38 | else: 39 | # camID = int(osp.basename(f)[6]) - 1 40 | label = int(osp.basename(f)[:4]) 41 | ret.append((osp.basename(f), label)) 42 | return np.asarray(ret) 43 | 44 | 45 | def _make_test_list(image_dir, output_dir, split_name): 46 | files = sorted(glob(osp.join(image_dir, '*.jpg'))) 47 | images = _get_test_list(files) 48 | _save(images, os.path.join(output_dir, '%s.txt' % split_name)) 49 | 50 | 51 | def run(image_dir, output_dir, split_name): 52 | """Make list file for images. 53 | 54 | Args: 55 | image_dir: The image directory where the raw images are stored. 56 | output_dir: The directory where the lists and tfrecords are stored. 57 | split_name: The split name of dataset. 58 | """ 59 | if split_name == 'bounding_box_train': 60 | _make_train_list(image_dir, output_dir, split_name) 61 | else: 62 | _make_test_list(image_dir, output_dir, split_name) 63 | -------------------------------------------------------------------------------- /eval_image_classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic evaluation script that evaluates a model using a given dataset. 
3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import tensorflow as tf 10 | import eval_models 11 | from datasets.utils import * 12 | 13 | slim = tf.contrib.slim 14 | 15 | tf.app.flags.DEFINE_string('dataset_name', 'market1501', 16 | 'The name of the dataset to load.') 17 | 18 | tf.app.flags.DEFINE_string('split_name', 'test', 19 | 'The name of the train/test split.') 20 | 21 | tf.app.flags.DEFINE_string('dataset_dir', None, 22 | 'The directory where the dataset files are stored.') 23 | 24 | tf.app.flags.DEFINE_string('checkpoint_dir', None, 25 | 'The directory where the model was written to or an absolute path to a ' 26 | 'checkpoint file.') 27 | 28 | tf.app.flags.DEFINE_string('eval_dir', 'results', 29 | 'Directory where the results are saved to.') 30 | 31 | tf.app.flags.DEFINE_string('model_name', 'mobilenet_v1', 32 | 'The name of the architecture to evaluate.') 33 | 34 | tf.app.flags.DEFINE_integer('num_networks', 2, 35 | 'Number of Networks') 36 | 37 | tf.app.flags.DEFINE_integer('num_classes', 751, 38 | 'The number of classes.') 39 | 40 | tf.app.flags.DEFINE_integer('batch_size', 1, 41 | 'The number of samples in each batch.') 42 | 43 | tf.app.flags.DEFINE_string('preprocessing_name', None, 44 | 'The name of the preprocessing to use. If left ' 45 | 'as `None`, then the model_name flag is used.') 46 | 47 | tf.app.flags.DEFINE_integer('num_preprocessing_threads', 1, 48 | 'The number of threads used to create the batches.') 49 | 50 | tf.app.flags.DEFINE_float('moving_average_decay', 0.9999, 51 | 'The decay to use for the moving average.' 52 | 'If left as None, then moving averages are not used.') 53 | 54 | ######################### 55 | 56 | FLAGS = tf.app.flags.FLAGS 57 | 58 | 59 | def main(_): 60 | # create folders 61 | mkdir_if_missing(FLAGS.eval_dir) 62 | # test 63 | eval_models.evaluate() 64 | 65 | 66 | if __name__ == '__main__': 67 | tf.app.run() 68 | 69 | -------------------------------------------------------------------------------- /nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import functools 21 | 22 | import tensorflow as tf 23 | from nets import inception 24 | from nets import mobilenet_v1 25 | 26 | slim = tf.contrib.slim 27 | 28 | networks_map = {'mobilenet_v1': mobilenet_v1.mobilenet_v1, 29 | 'inception_v1': inception.inception_v1, 30 | } 31 | 32 | arg_scopes_map = {'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope, 33 | 'inception_v1': inception.inception_v1_arg_scope, 34 | } 35 | 36 | 37 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): 38 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 39 | 40 | Args: 41 | name: The name of the network. 42 | num_classes: The number of classes to use for classification. 43 | weight_decay: The l2 coefficient for the model weights. 44 | is_training: `True` if the model is being used for training and `False` 45 | otherwise. 46 | 47 | Returns: 48 | network_fn: A function that applies the model to a batch of images. It has 49 | the following signature: 50 | logits, end_points = network_fn(images) 51 | Raises: 52 | ValueError: If network `name` is not recognized. 53 | """ 54 | if name not in networks_map: 55 | raise ValueError('Name of network unknown %s' % name) 56 | func = networks_map[name] 57 | 58 | @functools.wraps(func) 59 | def network_fn(images, scope=None): 60 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 61 | with slim.arg_scope(arg_scope): 62 | return func(images, num_classes, is_training=is_training, scope=scope) 63 | 64 | if hasattr(func, 'default_image_size'): 65 | network_fn.default_image_size = func.default_image_size 66 | 67 | return network_fn 68 | -------------------------------------------------------------------------------- /datasets/convert_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert Market-1501 to TFRecords of TF-Example protos. 3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import tensorflow as tf 9 | import os 10 | import sys 11 | from scipy import misc 12 | from datasets.dataset_utils import * 13 | 14 | # resize all the re-id images into the same size 15 | _IMAGE_HEIGHT = 160 16 | _IMAGE_WIDTH = 64 17 | _IMAGE_CHANNELS = 3 18 | 19 | 20 | def _add_to_tfrecord(image_dir, list_filename, tfrecord_writer, split_name): 21 | """Loads images and writes files to a TFRecord. 22 | 23 | Args: 24 | image_dir: The image directory where the raw images are stored. 25 | list_filename: The list file of images. 26 | tfrecord_writer: The TFRecord writer to use for writing. 
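        split_name: The split name of the dataset; used only in the progress messages.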
27 | """ 28 | num_images = len(tf.gfile.FastGFile(list_filename, 'r').readlines()) 29 | 30 | shape = (_IMAGE_HEIGHT, _IMAGE_WIDTH, _IMAGE_CHANNELS) 31 | with tf.Graph().as_default(): 32 | image = tf.placeholder(dtype=tf.uint8, shape=shape) 33 | encoded_png = tf.image.encode_png(image) 34 | j = 0 35 | with tf.Session('') as sess: 36 | for line in tf.gfile.FastGFile(list_filename, 'r').readlines(): 37 | sys.stdout.write('\r>> Converting %s image %d/%d' % (split_name, j + 1, num_images)) 38 | sys.stdout.flush() 39 | j += 1 40 | imagename, label = line.split(' ') 41 | label = int(label) 42 | file_path = os.path.join(image_dir, imagename) 43 | image_data = misc.imread(file_path) 44 | image_data = misc.imresize(image_data, [_IMAGE_HEIGHT, _IMAGE_WIDTH]) 45 | png_string = sess.run(encoded_png, feed_dict={image: image_data}) 46 | example = image_to_tfexample(png_string, label, imagename, _IMAGE_HEIGHT, _IMAGE_WIDTH, 'png') 47 | tfrecord_writer.write(example.SerializeToString()) 48 | 49 | 50 | def run(image_dir, output_dir, split_name): 51 | """Convert images to tfrecords. 52 | Args: 53 | image_dir: The image directory where the raw images are stored. 54 | output_dir: The directory where the lists and tfrecords are stored. 55 | split_name: The split name of dataset. 56 | """ 57 | list_filename = os.path.join(output_dir, '%s.txt' % split_name) 58 | tf_filename = os.path.join(output_dir, '%s.tfrecord' % split_name) 59 | 60 | if tf.gfile.Exists(tf_filename): 61 | print('Dataset files already exist. Exiting without re-creating them.') 62 | return 63 | 64 | with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer: 65 | _add_to_tfrecord(image_dir, list_filename, tfrecord_writer, split_name) 66 | 67 | print(" Done! \n") 68 | 69 | -------------------------------------------------------------------------------- /preprocessing/reid_preprocessing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides utilities to preprocess images for re-id. 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import tensorflow as tf 10 | slim = tf.contrib.slim 11 | 12 | 13 | IMAGE_HEIGHT = 160 14 | IMAGE_WIDTH = 64 15 | 16 | 17 | def preprocess_for_train(image, 18 | output_height, 19 | output_width): 20 | """Preprocesses the given image for training. 21 | 22 | Args: 23 | image: A `Tensor` representing an image of arbitrary size. 24 | output_height: The height of the image after preprocessing. 25 | output_width: The width of the image after preprocessing. 26 | 27 | Returns: 28 | A preprocessed image. 29 | """ 30 | if image.dtype != tf.float32: 31 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 32 | 33 | image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, 3]) 34 | image = tf.image.resize_images(image, [output_height, output_width]) 35 | image = tf.image.random_flip_left_right(image) 36 | tf.summary.image('cropped_resized_image', 37 | tf.expand_dims(image, 0)) 38 | image = tf.subtract(image, 0.5) 39 | image = tf.multiply(image, 2.0) 40 | return image 41 | 42 | 43 | def preprocess_for_eval(image, output_height, output_width): 44 | """Preprocesses the given image for evaluation. 45 | 46 | Args: 47 | image: A `Tensor` representing an image of arbitrary size. 48 | output_height: The height of the image after preprocessing. 49 | output_width: The width of the image after preprocessing. 50 | 51 | Returns: 52 | A preprocessed image. 
53 | """ 54 | if image.dtype != tf.float32: 55 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 56 | image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, 3]) 57 | image = tf.image.resize_images(image, [output_height, output_width]) 58 | image.set_shape([output_height, output_width, 3]) 59 | image = tf.subtract(image, 0.5) 60 | image = tf.multiply(image, 2.0) 61 | return image 62 | 63 | 64 | def preprocess_image(image, output_height, output_width, is_training=False): 65 | """Preprocesses the given image. 66 | 67 | Args: 68 | image: A `Tensor` representing an image of arbitrary size. 69 | output_height: The height of the image after preprocessing. 70 | output_width: The width of the image after preprocessing. 71 | is_training: `True` if we're preprocessing the image for training and 72 | `False` otherwise. 73 | 74 | Returns: 75 | A preprocessed image. 76 | """ 77 | if is_training: 78 | return preprocess_for_train(image, output_height, output_width) 79 | else: 80 | return preprocess_for_eval(image, output_height, output_width) 81 | -------------------------------------------------------------------------------- /nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001): 36 | """Defines the default arg scope for inception models. 37 | 38 | Args: 39 | weight_decay: The weight decay to use for regularizing the model. 40 | use_batch_norm: "If `True`, batch_norm is applied after each convolution. 41 | batch_norm_decay: Decay for batch norm moving average. 42 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 43 | in batch norm. 44 | 45 | Returns: 46 | An `arg_scope` to use for the inception models. 47 | """ 48 | batch_norm_params = { 49 | # Decay for the moving averages. 50 | 'decay': batch_norm_decay, 51 | # epsilon to prevent 0s in variance. 52 | 'epsilon': batch_norm_epsilon, 53 | # collection containing update_ops. 54 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 55 | } 56 | if use_batch_norm: 57 | normalizer_fn = slim.batch_norm 58 | normalizer_params = batch_norm_params 59 | else: 60 | normalizer_fn = None 61 | normalizer_params = {} 62 | # Set weight_decay for weights in Conv and FC layers. 
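    # The outer scope applies the L2 regularizer to both conv and fully-connected
    # weights; the nested conv-only scope additionally sets the variance-scaling
    # initializer, ReLU activation and the optional batch-norm normalizer, and is
    # the scope object returned to callers.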
63 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 64 | weights_regularizer=slim.l2_regularizer(weight_decay)): 65 | with slim.arg_scope( 66 | [slim.conv2d], 67 | weights_initializer=slim.variance_scaling_initializer(), 68 | activation_fn=tf.nn.relu, 69 | normalizer_fn=normalizer_fn, 70 | normalizer_params=normalizer_params) as sc: 71 | return sc 72 | -------------------------------------------------------------------------------- /datasets/reid.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provides data given split name 3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import tensorflow as tf 10 | 11 | FLAGS = tf.app.flags.FLAGS 12 | 13 | slim = tf.contrib.slim 14 | 15 | _FILE_PATTERN = '%s.tfrecord' 16 | 17 | _SPLITS_NAMES = ['bounding_box_train', 'bounding_box_test', 'gt_bbox', 'query'] 18 | 19 | _ITEMS_TO_DESCRIPTIONS = { 20 | 'image': 'A color image of varying height and width.', 21 | 'label': 'The label id of the image, integer between 0 and num_classes', 22 | 'filename': 'The name of an image', 23 | } 24 | 25 | 26 | def get_num_examples(split_name): 27 | list_file = os.path.join(FLAGS.dataset_dir, '%s.txt' % split_name) 28 | num_examples = len(tf.gfile.FastGFile(list_file, 'r').readlines()) 29 | 30 | return num_examples 31 | 32 | 33 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 34 | """Get a dataset tuple. 35 | 36 | Args: 37 | split_name: A train/test split name. 38 | dataset_dir: The base directory of the dataset sources. 39 | file_pattern: The file pattern to use when matching the dataset sources. 40 | It is assumed that the pattern contains a '%s' string so that the split 41 | name can be inserted. 42 | reader: The TensorFlow reader type. 43 | 44 | Returns: 45 | A `Dataset` namedtuple. 46 | 47 | Raises: 48 | ValueError: if `split_name` is not a valid train/test split. 49 | """ 50 | if split_name not in _SPLITS_NAMES: 51 | raise ValueError('split name %s was not recognized.' % split_name) 52 | 53 | if not file_pattern: 54 | file_pattern = _FILE_PATTERN 55 | 56 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 57 | 58 | # Allowing None in the signature so that dataset_factory can use the default. 
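    # The feature keys decoded below must match what datasets/convert_to_tfrecords.py
    # writes via image_to_tfexample: image/encoded, image/format, image/label and
    # image/filename.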
59 | if reader is None: 60 | reader = tf.TFRecordReader 61 | 62 | keys_to_features = { 63 | 'image/encoded': tf.FixedLenFeature( 64 | (), tf.string, default_value=''), 65 | 'image/format': tf.FixedLenFeature( 66 | (), tf.string, default_value='png'), 67 | 'image/label': tf.FixedLenFeature( 68 | [], dtype=tf.int64, default_value=-1), 69 | 'image/filename': tf.FixedLenFeature( 70 | [], dtype=tf.string, default_value=''), 71 | } 72 | 73 | items_to_handlers = { 74 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'), 75 | 'label': slim.tfexample_decoder.Tensor('image/label'), 76 | 'filename': slim.tfexample_decoder.Tensor('image/filename'), 77 | } 78 | 79 | decoder = slim.tfexample_decoder.TFExampleDecoder( 80 | keys_to_features, items_to_handlers) 81 | 82 | num_examples = get_num_examples(split_name) 83 | num_classes = FLAGS.num_classes 84 | 85 | return slim.dataset.Dataset( 86 | data_sources=file_pattern, 87 | reader=reader, 88 | decoder=decoder, 89 | num_samples=num_examples, 90 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 91 | num_classes=num_classes) 92 | -------------------------------------------------------------------------------- /.idea/markdown-navigator.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 36 | 37 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /train_image_classifier.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic training script that trains a model using a given dataset. 3 | 4 | This code modifies the "TensorFlow-Slim image classification model library", 5 | Please visit https://github.com/tensorflow/models/tree/master/research/slim 6 | for more detailed usage. 7 | 8 | """ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import tensorflow as tf 15 | import train_models 16 | from datasets.utils import * 17 | 18 | slim = tf.contrib.slim 19 | 20 | ######################### 21 | # Training Directories # 22 | ######################### 23 | 24 | tf.app.flags.DEFINE_string('dataset_name', 'market1501', 25 | 'The name of the dataset to load.') 26 | 27 | tf.app.flags.DEFINE_string('split_name', 'bounding_box_train', 28 | 'The name of the data split.') 29 | 30 | tf.app.flags.DEFINE_string('dataset_dir', None, 31 | 'The directory where the dataset files are stored.') 32 | 33 | tf.app.flags.DEFINE_string('checkpoint_dir', 'checkpoint', 34 | 'Directory name to save the checkpoints [checkpoint]') 35 | 36 | tf.app.flags.DEFINE_string('log_dir', 'logs', 37 | 'Directory name to save the logs') 38 | 39 | 40 | ######################### 41 | # Model Settings # 42 | ######################### 43 | 44 | tf.app.flags.DEFINE_string('model_name', 'mobilenet_v1', 45 | 'The name of the architecture to train.') 46 | 47 | tf.app.flags.DEFINE_string('preprocessing_name', None, 48 | 'The name of the preprocessing to use. 
If left as `None`, ' 49 | 'then the model_name flag is used.') 50 | 51 | tf.app.flags.DEFINE_float('weight_decay', 0.00004, 52 | 'The weight decay on the model weights.') 53 | 54 | tf.app.flags.DEFINE_float('label_smoothing', 0.0, 55 | 'The amount of label smoothing.') 56 | 57 | tf.app.flags.DEFINE_integer('batch_size', 16, 58 | 'The number of samples in each batch.') 59 | 60 | tf.app.flags.DEFINE_integer('max_number_of_steps', 200000, 61 | 'The maximum number of training steps.') 62 | 63 | tf.app.flags.DEFINE_integer('ckpt_steps', 5000, 64 | 'How many steps to save checkpoints.') 65 | 66 | tf.app.flags.DEFINE_integer('num_classes', 751, 67 | 'The number of classes.') 68 | 69 | tf.app.flags.DEFINE_integer('num_networks', 2, 70 | 'The number of networks in DML.') 71 | 72 | tf.app.flags.DEFINE_integer('num_gpus', 1, 73 | 'The number of GPUs.') 74 | 75 | ######################### 76 | # Optimization Settings # 77 | ######################### 78 | 79 | tf.app.flags.DEFINE_string('optimizer', 'adam', 80 | 'The name of the optimizer, one of "adadelta", "adagrad", "adam",' 81 | '"ftrl", "momentum", "sgd" or "rmsprop".') 82 | 83 | tf.app.flags.DEFINE_float('learning_rate', 0.0002, 84 | 'Initial learning rate.') 85 | 86 | tf.app.flags.DEFINE_float('adam_beta1', 0.5, 87 | 'The exponential decay rate for the 1st moment estimates.') 88 | 89 | tf.app.flags.DEFINE_float('adam_beta2', 0.999, 90 | 'The exponential decay rate for the 2nd moment estimates.') 91 | 92 | tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 93 | 'Epsilon term for the optimizer.') 94 | 95 | 96 | ######################### 97 | # Default Settings # 98 | ######################### 99 | tf.app.flags.DEFINE_integer('num_clones', 1, 100 | 'Number of model clones to deploy.') 101 | 102 | tf.app.flags.DEFINE_boolean('clone_on_cpu', False, 103 | 'Use CPUs to deploy clones.') 104 | 105 | tf.app.flags.DEFINE_integer('worker_replicas', 1, 106 | 'Number of worker replicas.') 107 | 108 | tf.app.flags.DEFINE_integer('num_ps_tasks', 0, 109 | 'The number of parameter servers. If the value is 0, then the parameters ' 110 | 'are handled locally by the worker.') 111 | 112 | tf.app.flags.DEFINE_integer('task', 0, 113 | 'Task id of the replica running the training.') 114 | 115 | tf.app.flags.DEFINE_float('moving_average_decay', 0.9999, 116 | 'The decay to use for the moving average.' 117 | 'If left as None, then moving averages are not used.') 118 | 119 | tf.app.flags.DEFINE_integer('input_queue_memory_factor', 16, 120 | """Size of the queue of preprocessed images. """) 121 | 122 | tf.app.flags.DEFINE_integer('num_readers', 4, 123 | 'The number of parallel readers that read data from the dataset.') 124 | 125 | tf.app.flags.DEFINE_integer('num_preprocessing_threads', 4, 126 | 'The number of threads used to create the batches.') 127 | 128 | tf.app.flags.DEFINE_boolean('log_device_placement', False, 129 | """Whether to log device placement.""") 130 | 131 | 132 | FLAGS = tf.app.flags.FLAGS 133 | 134 | 135 | def main(_): 136 | # create folders 137 | mkdir_if_missing(FLAGS.checkpoint_dir) 138 | mkdir_if_missing(FLAGS.log_dir) 139 | # training 140 | train_models.train() 141 | 142 | 143 | if __name__ == '__main__': 144 | tf.app.run() 145 | -------------------------------------------------------------------------------- /eval_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic evaluation script that evaluates a model using a given dataset. 
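It restores the latest checkpoint from --checkpoint_dir, extracts the 'PreLogits'
features of every image in the chosen split, and saves the features and labels as
.mat files in --eval_dir for the Matlab evaluation code.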
3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import tensorflow as tf 10 | from datasets import dataset_factory 11 | from nets import nets_factory 12 | from preprocessing import preprocessing_factory 13 | import math 14 | from datetime import datetime 15 | import numpy as np 16 | import os.path 17 | import sys 18 | import scipy.io as sio 19 | 20 | slim = tf.contrib.slim 21 | 22 | FLAGS = tf.app.flags.FLAGS 23 | 24 | 25 | def _extract_once(features, labels, filenames, num_examples, saver): 26 | """Extract Features. 27 | """ 28 | config = tf.ConfigProto() 29 | config.gpu_options.per_process_gpu_memory_fraction = 0.2 30 | with tf.device('/cpu:0'): 31 | with tf.Session(config=config) as sess: 32 | ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) 33 | if ckpt and ckpt.model_checkpoint_path: 34 | if os.path.isabs(ckpt.model_checkpoint_path): 35 | saver.restore(sess, ckpt.model_checkpoint_path) 36 | else: 37 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path) 38 | saver.restore(sess, os.path.join(FLAGS.checkpoint_dir, ckpt_name)) 39 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 40 | print('Succesfully loaded model from %s at step=%s.' % 41 | (ckpt.model_checkpoint_path, global_step)) 42 | else: 43 | print('No checkpoint file found') 44 | return 45 | 46 | # Start the queue runners. 47 | coord = tf.train.Coordinator() 48 | try: 49 | threads = [] 50 | for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS): 51 | threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True)) 52 | # num_examples = get_num_examples() 53 | num_iter = int(math.ceil(num_examples / FLAGS.batch_size)) 54 | # Counts the number of correct predictions. 55 | step = 0 56 | all_features = [] 57 | all_labels = [] 58 | print("Current Path: %s" % os.getcwd()) 59 | print('%s: starting extracting features on (%s).' 
% (datetime.now(), FLAGS.split_name)) 60 | while step < num_iter and not coord.should_stop(): 61 | step += 1 62 | sys.stdout.write('\r>> Extracting %s image %d/%d' % (FLAGS.split_name, step, num_examples)) 63 | sys.stdout.flush() 64 | eval_features, eval_labels, eval_filenames = sess.run([features, labels, filenames]) 65 | # print('Filename:%s, Camid:%d, Label:%d' % (eval_filenames, eval_camids, eval_labels)) 66 | concat_features = np.concatenate(eval_features, axis=3) 67 | eval_features = np.reshape(concat_features, [concat_features.shape[0], -1]) 68 | all_features.append(eval_features) 69 | all_labels.append(eval_labels) 70 | 71 | # save features and labels 72 | np_features = np.asarray(all_features) 73 | np_features = np.reshape(np_features, [len(all_features), -1]) 74 | np_labels = np.asarray(all_labels) 75 | np_labels = np.reshape(np_labels, len(all_labels)) 76 | feature_filename = "%s/%s_features.mat" % (FLAGS.eval_dir, FLAGS.split_name) 77 | sio.savemat(feature_filename, {'feature': np_features}) 78 | label_filename = "%s/%s_labels.mat" % (FLAGS.eval_dir, FLAGS.split_name) 79 | sio.savemat(label_filename, {'label': np_labels}) 80 | print("Done!\n") 81 | 82 | except Exception as e: 83 | coord.request_stop(e) 84 | 85 | coord.request_stop() 86 | coord.join(threads, stop_grace_period_secs=10) 87 | 88 | 89 | def evaluate(): 90 | if not FLAGS.dataset_dir: 91 | raise ValueError('You must supply the dataset directory with --dataset_dir') 92 | 93 | tf.logging.set_verbosity(tf.logging.INFO) 94 | with tf.Graph().as_default(): 95 | tf_global_step = slim.get_or_create_global_step() 96 | 97 | ###################### 98 | # Select the dataset # 99 | ###################### 100 | dataset = dataset_factory.get_dataset( 101 | FLAGS.dataset_name, FLAGS.split_name, FLAGS.dataset_dir) 102 | 103 | #################### 104 | # Select the model # 105 | #################### 106 | network_fn = {} 107 | model_names = [net.strip() for net in FLAGS.model_name.split(',')] 108 | for i in range(FLAGS.num_networks): 109 | network_fn["{0}".format(i)] = nets_factory.get_network_fn( 110 | model_names[i], 111 | num_classes=dataset.num_classes, 112 | is_training=False) 113 | 114 | ############################################################## 115 | # Create a dataset provider that loads data from the dataset # 116 | ############################################################## 117 | provider = slim.dataset_data_provider.DatasetDataProvider( 118 | dataset, 119 | shuffle=False, 120 | common_queue_capacity=2 * FLAGS.batch_size, 121 | common_queue_min=FLAGS.batch_size) 122 | [image, label, filename] = provider.get(['image', 'label', 'filename']) 123 | 124 | ##################################### 125 | # Select the preprocessing function # 126 | ##################################### 127 | preprocessing_name = FLAGS.preprocessing_name 128 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 129 | preprocessing_name, 130 | is_training=False) 131 | 132 | eval_image_size = network_fn['0'].default_image_size 133 | 134 | image = image_preprocessing_fn(image, eval_image_size, eval_image_size) 135 | 136 | images, labels, filenames = tf.train.batch( 137 | [image, label, filename], 138 | batch_size=FLAGS.batch_size, 139 | num_threads=FLAGS.num_preprocessing_threads, 140 | capacity=5 * FLAGS.batch_size) 141 | 142 | #################### 143 | # Define the model # 144 | #################### 145 | net_endpoints, net_features = {}, {} 146 | all_features = [] 147 | for i in range(FLAGS.num_networks): 148 | _, 
net_endpoints["{0}".format(i)] = network_fn["{0}".format(i)](images, scope=('dmlnet_%d' % i)) 149 | net_features["{0}".format(i)] = net_endpoints["{0}".format(i)]['PreLogits'] 150 | all_features.append(net_features["{0}".format(i)]) 151 | 152 | if FLAGS.moving_average_decay: 153 | variable_averages = tf.train.ExponentialMovingAverage( 154 | FLAGS.moving_average_decay, tf_global_step) 155 | variables_to_restore = variable_averages.variables_to_restore( 156 | slim.get_model_variables()) 157 | variables_to_restore[tf_global_step.op.name] = tf_global_step 158 | else: 159 | variables_to_restore = slim.get_variables_to_restore() 160 | 161 | saver = tf.train.Saver(variables_to_restore) 162 | _extract_once(all_features, labels, filenames, dataset.num_samples, saver) 163 | -------------------------------------------------------------------------------- /preprocessing/inception_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images for the Inception networks.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from tensorflow.python.ops import control_flow_ops 24 | 25 | 26 | def apply_with_random_selector(x, func, num_cases): 27 | """Computes func(x, sel), with sel sampled from [0...num_cases-1]. 28 | 29 | Args: 30 | x: input Tensor. 31 | func: Python function to apply. 32 | num_cases: Python int32, number of cases to sample sel from. 33 | 34 | Returns: 35 | The result of func(x, sel), where func receives the value of the 36 | selector as a python integer, but sel is sampled dynamically. 37 | """ 38 | sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32) 39 | # Pass the real x only to one of the func calls. 40 | return control_flow_ops.merge([ 41 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case) 42 | for case in range(num_cases)])[0] 43 | 44 | 45 | def distort_color(image, color_ordering=0, fast_mode=True, scope=None): 46 | """Distort the color of a Tensor image. 47 | 48 | Each color distortion is non-commutative and thus ordering of the color ops 49 | matters. Ideally we would randomly permute the ordering of the color ops. 50 | Rather then adding that level of complication, we select a distinct ordering 51 | of color ops for each preprocessing thread. 52 | 53 | Args: 54 | image: 3-D Tensor containing single image in [0, 1]. 55 | color_ordering: Python int, a type of distortion (valid values: 0-3). 56 | fast_mode: Avoids slower ops (random_hue and random_contrast) 57 | scope: Optional scope for name_scope. 
58 | Returns: 59 | 3-D Tensor color-distorted image on range [0, 1] 60 | Raises: 61 | ValueError: if color_ordering not in [0, 3] 62 | """ 63 | with tf.name_scope(scope, 'distort_color', [image]): 64 | if fast_mode: 65 | if color_ordering == 0: 66 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 67 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 68 | else: 69 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 70 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 71 | else: 72 | if color_ordering == 0: 73 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 74 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 75 | image = tf.image.random_hue(image, max_delta=0.2) 76 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 77 | elif color_ordering == 1: 78 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 79 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 80 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 81 | image = tf.image.random_hue(image, max_delta=0.2) 82 | elif color_ordering == 2: 83 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 84 | image = tf.image.random_hue(image, max_delta=0.2) 85 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 86 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 87 | elif color_ordering == 3: 88 | image = tf.image.random_hue(image, max_delta=0.2) 89 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 90 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 91 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 92 | else: 93 | raise ValueError('color_ordering must be in [0, 3]') 94 | 95 | # The random_* ops do not necessarily clamp. 96 | return tf.clip_by_value(image, 0.0, 1.0) 97 | 98 | 99 | def distorted_bounding_box_crop(image, 100 | bbox, 101 | min_object_covered=0.1, 102 | aspect_ratio_range=(0.75, 1.33), 103 | area_range=(0.05, 1.0), 104 | max_attempts=100, 105 | scope=None): 106 | """Generates cropped_image using a one of the bboxes randomly distorted. 107 | 108 | See `tf.image.sample_distorted_bounding_box` for more documentation. 109 | 110 | Args: 111 | image: 3-D Tensor of image (it will be converted to floats in [0, 1]). 112 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 113 | where each coordinate is [0, 1) and the coordinates are arranged 114 | as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the whole 115 | image. 116 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped 117 | area of the image must contain at least this fraction of any bounding box 118 | supplied. 119 | aspect_ratio_range: An optional list of `floats`. The cropped area of the 120 | image must have an aspect ratio = width / height within this range. 121 | area_range: An optional list of `floats`. The cropped area of the image 122 | must contain a fraction of the supplied image within in this range. 123 | max_attempts: An optional `int`. Number of attempts at generating a cropped 124 | region of the image of the specified constraints. After `max_attempts` 125 | failures, return the entire image. 126 | scope: Optional scope for name_scope. 
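# (Explanatory note: tf.slice with tensor-valued begin/size leaves the static
#  shape unknown; set_shape below only re-annotates static shape metadata and
#  does not reshape the data, so downstream ops know the crop has 3 channels.)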
127 | Returns: 128 | A tuple, a 3-D Tensor cropped_image and the distorted bbox 129 | """ 130 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]): 131 | # Each bounding box has shape [1, num_boxes, box coords] and 132 | # the coordinates are ordered [ymin, xmin, ymax, xmax]. 133 | 134 | # A large fraction of image datasets contain a human-annotated bounding 135 | # box delineating the region of the image containing the object of interest. 136 | # We choose to create a new bounding box for the object which is a randomly 137 | # distorted version of the human-annotated bounding box that obeys an 138 | # allowed range of aspect ratios, sizes and overlap with the human-annotated 139 | # bounding box. If no box is supplied, then we assume the bounding box is 140 | # the entire image. 141 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 142 | tf.shape(image), 143 | bounding_boxes=bbox, 144 | min_object_covered=min_object_covered, 145 | aspect_ratio_range=aspect_ratio_range, 146 | area_range=area_range, 147 | max_attempts=max_attempts, 148 | use_image_if_no_bounding_boxes=True) 149 | bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box 150 | 151 | # Crop the image to the specified bounding box. 152 | cropped_image = tf.slice(image, bbox_begin, bbox_size) 153 | return cropped_image, distort_bbox 154 | 155 | 156 | def preprocess_for_train(image, height, width, bbox, 157 | fast_mode=True, 158 | scope=None): 159 | """Distort one image for training a network. 160 | 161 | Distorting images provides a useful technique for augmenting the data 162 | set during training in order to make the network invariant to aspects 163 | of the image that do not effect the label. 164 | 165 | Additionally it would create image_summaries to display the different 166 | transformations applied to the image. 167 | 168 | Args: 169 | image: 3-D Tensor of image. If dtype is tf.float32 then the range should be 170 | [0, 1], otherwise it would converted to tf.float32 assuming that the range 171 | is [0, MAX], where MAX is largest positive representable number for 172 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details). 173 | height: integer 174 | width: integer 175 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 176 | where each coordinate is [0, 1) and the coordinates are arranged 177 | as [ymin, xmin, ymax, xmax]. 178 | fast_mode: Optional boolean, if True avoids slower transformations (i.e. 179 | bi-cubic resizing, random_hue or random_contrast). 180 | scope: Optional scope for name_scope. 181 | Returns: 182 | 3-D float Tensor of distorted image used for training with range [-1, 1]. 183 | """ 184 | with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]): 185 | if bbox is None: 186 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], 187 | dtype=tf.float32, 188 | shape=[1, 1, 4]) 189 | if image.dtype != tf.float32: 190 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 191 | # Each bounding box has shape [1, num_boxes, box coords] and 192 | # the coordinates are ordered [ymin, xmin, ymax, xmax]. 193 | image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), 194 | bbox) 195 | tf.summary.image('image_with_bounding_boxes', image_with_box) 196 | 197 | distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox) 198 | # Restore the shape since the dynamic slice based upon the bbox_size loses 199 | # the third dimension. 
200 | distorted_image.set_shape([None, None, 3]) 201 | image_with_distorted_box = tf.image.draw_bounding_boxes( 202 | tf.expand_dims(image, 0), distorted_bbox) 203 | tf.summary.image('images_with_distorted_bounding_box', 204 | image_with_distorted_box) 205 | 206 | # This resizing operation may distort the images because the aspect 207 | # ratio is not respected. We select a resize method in a round robin 208 | # fashion based on the thread number. 209 | # Note that ResizeMethod contains 4 enumerated resizing methods. 210 | 211 | # We select only 1 case for fast_mode bilinear. 212 | num_resize_cases = 1 if fast_mode else 4 213 | distorted_image = apply_with_random_selector( 214 | distorted_image, 215 | lambda x, method: tf.image.resize_images(x, [height, width], method=method), 216 | num_cases=num_resize_cases) 217 | 218 | tf.summary.image('cropped_resized_image', 219 | tf.expand_dims(distorted_image, 0)) 220 | 221 | # Randomly flip the image horizontally. 222 | distorted_image = tf.image.random_flip_left_right(distorted_image) 223 | 224 | # Randomly distort the colors. There are 4 ways to do it. 225 | distorted_image = apply_with_random_selector( 226 | distorted_image, 227 | lambda x, ordering: distort_color(x, ordering, fast_mode), 228 | num_cases=4) 229 | 230 | tf.summary.image('final_distorted_image', 231 | tf.expand_dims(distorted_image, 0)) 232 | distorted_image = tf.subtract(distorted_image, 0.5) 233 | distorted_image = tf.multiply(distorted_image, 2.0) 234 | return distorted_image 235 | 236 | 237 | def preprocess_for_eval(image, height, width, 238 | central_fraction=0.875, scope=None): 239 | """Prepare one image for evaluation. 240 | 241 | If height and width are specified it would output an image with that size by 242 | applying resize_bilinear. 243 | 244 | If central_fraction is specified it would cropt the central fraction of the 245 | input image. 246 | 247 | Args: 248 | image: 3-D Tensor of image. If dtype is tf.float32 then the range should be 249 | [0, 1], otherwise it would converted to tf.float32 assuming that the range 250 | is [0, MAX], where MAX is largest positive representable number for 251 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details) 252 | height: integer 253 | width: integer 254 | central_fraction: Optional Float, fraction of the image to crop. 255 | scope: Optional scope for name_scope. 256 | Returns: 257 | 3-D float Tensor of prepared image. 258 | """ 259 | with tf.name_scope(scope, 'eval_image', [image, height, width]): 260 | if image.dtype != tf.float32: 261 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 262 | # Crop the central region of the image with an area containing 87.5% of 263 | # the original image. 264 | if central_fraction: 265 | image = tf.image.central_crop(image, central_fraction=central_fraction) 266 | 267 | if height and width: 268 | # Resize the image to the specified height and width. 269 | image = tf.expand_dims(image, 0) 270 | image = tf.image.resize_bilinear(image, [height, width], 271 | align_corners=False) 272 | image = tf.squeeze(image, [0]) 273 | image = tf.subtract(image, 0.5) 274 | image = tf.multiply(image, 2.0) 275 | return image 276 | 277 | 278 | def preprocess_image(image, height, width, 279 | is_training=False, 280 | bbox=None, 281 | fast_mode=True): 282 | """Pre-process one image for training or evaluation. 283 | 284 | Args: 285 | image: 3-D Tensor [height, width, channels] with the image. 286 | height: integer, image expected height. 287 | width: integer, image expected width. 
288 | is_training: Boolean. If true it would transform an image for train, 289 | otherwise it would transform it for evaluation. 290 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 291 | where each coordinate is [0, 1) and the coordinates are arranged as 292 | [ymin, xmin, ymax, xmax]. 293 | fast_mode: Optional boolean, if True avoids slower transformations. 294 | 295 | Returns: 296 | 3-D float Tensor containing an appropriately scaled image 297 | 298 | Raises: 299 | ValueError: if user does not provide bounding box 300 | """ 301 | if is_training: 302 | return preprocess_for_train(image, height, width, bbox, fast_mode) 303 | else: 304 | return preprocess_for_eval(image, height, width) 305 | -------------------------------------------------------------------------------- /train_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic training script that trains a model using a given dataset. 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import tensorflow as tf 10 | from datasets import dataset_factory 11 | from deployment import model_deploy 12 | from nets import nets_factory 13 | from preprocessing import preprocessing_factory 14 | from datasets.utils import * 15 | import numpy as np 16 | 17 | slim = tf.contrib.slim 18 | 19 | FLAGS = tf.app.flags.FLAGS 20 | 21 | 22 | def _average_gradients(tower_grads, catname=None): 23 | """Calculate the average gradient for each shared variable across all towers. 24 | 25 | Note that this function provides a synchronization point across all towers. 26 | 27 | Args: 28 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 29 | is over individual gradients. The inner list is over the gradient 30 | calculation for each tower. 31 | Returns: 32 | List of pairs of (gradient, variable) where the gradient has been averaged 33 | across all towers. 34 | """ 35 | average_grads = [] 36 | for grad_and_vars in zip(*tower_grads): 37 | # Note that each grad_and_vars looks like the following: 38 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 39 | grads = [] 40 | for g, _ in grad_and_vars: 41 | # Add 0 dimension to the gradients to represent the tower. 42 | expanded_g = tf.expand_dims(input=g, axis=0) 43 | # print(g) 44 | # Append on a 'tower' dimension which we will average over below. 45 | grads.append(expanded_g) 46 | 47 | # Average over the 'tower' dimension. 48 | grad = tf.concat(axis=0, values=grads, name=catname) 49 | grad = tf.reduce_mean(grad, 0) 50 | 51 | # Keep in mind that the Variables are redundant because they are shared 52 | # across towers. So .. we will just return the first tower's pointer to 53 | # the Variable. 54 | v = grad_and_vars[0][1] 55 | grad_and_var = (grad, v) 56 | average_grads.append(grad_and_var) 57 | return average_grads 58 | 59 | 60 | def kl_loss_compute(logits1, logits2): 61 | """ KL loss 62 | """ 63 | pred1 = tf.nn.softmax(logits1) 64 | pred2 = tf.nn.softmax(logits2) 65 | loss = tf.reduce_mean(tf.reduce_sum(pred2 * tf.log(1e-8 + pred2 / (pred1 + 1e-8)), 1)) 66 | 67 | return loss 68 | 69 | 70 | def _tower_loss(network_fn, images, labels): 71 | """Calculate the total loss on a single tower running the reid model.""" 72 | # Build inference Graph. 
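# Per-network DML objective assembled below (cf. kl_loss_compute above):
#   L_i = CE(onehot_labels, softmax(logits_i)) + sum_{j != i} KL( softmax(logits_j) || softmax(logits_i) )
# e.g. with two networks: L_0 = CE_0 + KL(p_1 || p_0) and L_1 = CE_1 + KL(p_0 || p_1),
# so each network is additionally pulled toward its peer's class posterior.
# With FLAGS.num_networks == 1 the KL terms vanish and this reduces to
# independent training with plain cross-entropy (plus the optional aux loss).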
73 | net_logits, net_endpoints, net_raw_loss, net_pred, net_features = {}, {}, {}, {}, {} 74 | for i in range(FLAGS.num_networks): 75 | net_logits["{0}".format(i)], net_endpoints["{0}".format(i)] = \ 76 | network_fn["{0}".format(i)](images, scope=('dmlnet_%d' % i)) 77 | net_raw_loss["{0}".format(i)] = tf.losses.softmax_cross_entropy( 78 | logits=net_logits["{0}".format(i)], onehot_labels=labels, 79 | label_smoothing=FLAGS.label_smoothing, weights=1.0) 80 | net_pred["{0}".format(i)] = net_endpoints["{0}".format(i)]['Predictions'] 81 | 82 | if 'AuxLogits' in net_endpoints["{0}".format(i)]: 83 | net_raw_loss["{0}".format(i)] += tf.losses.softmax_cross_entropy( 84 | logits=net_endpoints["{0}".format(i)]['AuxLogits'], onehot_labels=labels, 85 | label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss') 86 | 87 | # Add KL loss if there are more than one network 88 | net_loss, kl_loss, net_reg_loss, net_total_loss, net_loss_averages, net_loss_averages_op = {}, {}, {}, {}, {}, {} 89 | 90 | for i in range(FLAGS.num_networks): 91 | net_loss["{0}".format(i)] = net_raw_loss["{0}".format(i)] 92 | for j in range(FLAGS.num_networks): 93 | if i != j: 94 | kl_loss["{0}{0}".format(i, j)] = kl_loss_compute(net_logits["{0}".format(i)], net_logits["{0}".format(j)]) 95 | net_loss["{0}".format(i)] += kl_loss["{0}{0}".format(i, j)] 96 | tf.summary.scalar('kl_loss_%d%d' % (i, j), kl_loss["{0}{0}".format(i, j)]) 97 | 98 | net_reg_loss["{0}".format(i)] = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=('dmlnet_%d' % i)) 99 | net_total_loss["{0}".format(i)] = tf.add_n([net_loss["{0}".format(i)]] + 100 | net_reg_loss["{0}".format(i)], 101 | name=('net%d_total_loss' % i)) 102 | 103 | net_loss_averages["{0}".format(i)] = tf.train.ExponentialMovingAverage(0.9, name='net%d_avg' % i) 104 | net_loss_averages_op["{0}".format(i)] = net_loss_averages["{0}".format(i)].apply( 105 | [net_loss["{0}".format(i)]] + [net_total_loss["{0}".format(i)]]) 106 | 107 | tf.summary.scalar('net%d_loss_raw' % i, net_raw_loss["{0}".format(i)]) 108 | tf.summary.scalar('net%d_loss_sum' % i, net_loss["{0}".format(i)]) 109 | tf.summary.scalar('net%d_loss_avg' % i, net_loss_averages["{0}".format(i)].average(net_loss["{0}".format(i)])) 110 | 111 | with tf.control_dependencies([net_loss_averages_op["{0}".format(i)]]): 112 | net_total_loss["{0}".format(i)] = tf.identity(net_total_loss["{0}".format(i)]) 113 | 114 | return net_total_loss, net_pred 115 | 116 | 117 | def train(): 118 | if not FLAGS.dataset_dir: 119 | raise ValueError('You must supply the dataset directory with --dataset_dir') 120 | 121 | tf.logging.set_verbosity(tf.logging.INFO) 122 | with tf.Graph().as_default(): 123 | ####################### 124 | # Config model_deploy # 125 | ####################### 126 | deploy_config = model_deploy.DeploymentConfig( 127 | num_clones=FLAGS.num_clones, 128 | clone_on_cpu=FLAGS.clone_on_cpu, 129 | replica_id=FLAGS.task, 130 | num_replicas=FLAGS.worker_replicas, 131 | num_ps_tasks=FLAGS.num_ps_tasks) 132 | 133 | # Create global_step 134 | with tf.device(deploy_config.variables_device()): 135 | global_step = slim.create_global_step() 136 | 137 | ###################### 138 | # Select the dataset # 139 | ###################### 140 | dataset = dataset_factory.get_dataset( 141 | FLAGS.dataset_name, FLAGS.split_name, FLAGS.dataset_dir) 142 | 143 | ###################### 144 | # Select the network and # 145 | ###################### 146 | network_fn = {} 147 | model_names = [net.strip() for net in FLAGS.model_name.split(',')] 148 | for i in 
range(FLAGS.num_networks): 149 | network_fn["{0}".format(i)] = nets_factory.get_network_fn( 150 | model_names[i], 151 | num_classes=dataset.num_classes, 152 | weight_decay=FLAGS.weight_decay, 153 | is_training=True) 154 | 155 | ######################################### 156 | # Configure the optimization procedure. # 157 | ######################################### 158 | with tf.device(deploy_config.optimizer_device()): 159 | net_opt = {} 160 | for i in range(FLAGS.num_networks): 161 | net_opt["{0}".format(i)] = tf.train.AdamOptimizer(FLAGS.learning_rate, 162 | beta1=FLAGS.adam_beta1, 163 | beta2=FLAGS.adam_beta2, 164 | epsilon=FLAGS.opt_epsilon) 165 | 166 | ##################################### 167 | # Select the preprocessing function # 168 | ##################################### 169 | preprocessing_name = FLAGS.preprocessing_name # or FLAGS.model_name 170 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 171 | preprocessing_name, 172 | is_training=True) 173 | 174 | ############################################################## 175 | # Create a dataset provider that loads data from the dataset # 176 | ############################################################## 177 | with tf.device(deploy_config.inputs_device()): 178 | examples_per_shard = 1024 179 | min_queue_examples = examples_per_shard * FLAGS.input_queue_memory_factor 180 | provider = slim.dataset_data_provider.DatasetDataProvider( 181 | dataset, 182 | num_readers=FLAGS.num_readers, 183 | common_queue_capacity=min_queue_examples + 3 * FLAGS.batch_size, 184 | common_queue_min=min_queue_examples) 185 | [image, label] = provider.get(['image', 'label']) 186 | 187 | train_image_size = network_fn["{0}".format(0)].default_image_size 188 | 189 | image = image_preprocessing_fn(image, train_image_size, train_image_size) 190 | 191 | images, labels = tf.train.batch( 192 | [image, label], 193 | batch_size=FLAGS.batch_size, 194 | num_threads=FLAGS.num_preprocessing_threads, 195 | capacity=2 * FLAGS.num_preprocessing_threads * FLAGS.batch_size) 196 | labels = slim.one_hot_encoding(labels, dataset.num_classes) 197 | batch_queue = slim.prefetch_queue.prefetch_queue( 198 | [images, labels], capacity=16 * deploy_config.num_clones, 199 | num_threads=FLAGS.num_preprocessing_threads) 200 | 201 | images, labels = batch_queue.dequeue() 202 | 203 | images_splits = tf.split(axis=0, num_or_size_splits=FLAGS.num_gpus, value=images) 204 | labels_splits = tf.split(axis=0, num_or_size_splits=FLAGS.num_gpus, value=labels) 205 | 206 | precision, net_tower_grads, net_update_ops, net_var_list, net_grads = {}, {}, {}, {}, {} 207 | 208 | for i in range(FLAGS.num_networks): 209 | net_tower_grads["{0}".format(i)] = [] 210 | 211 | for k in xrange(FLAGS.num_gpus): 212 | with tf.device('/gpu:%d' % k): 213 | with tf.name_scope('tower_%d' % k) as scope: 214 | with tf.variable_scope(tf.get_variable_scope()): 215 | 216 | net_loss, net_pred = _tower_loss(network_fn, images_splits[k], labels_splits[k]) 217 | 218 | truth = tf.argmax(labels_splits[k], axis=1) 219 | 220 | # Reuse variables for the next tower. 221 | tf.get_variable_scope().reuse_variables() 222 | 223 | # Retain the summaries from the final tower. 
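# (`scope` here is the enclosing 'tower_%d' name scope, so only summaries
#  created inside the current tower are collected; since this statement sits
#  inside the per-GPU loop, the collections gathered below reflect the last
#  tower only, which is the usual multi-tower convention.)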
224 | summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope) 225 | var_list = tf.trainable_variables() 226 | 227 | for i in range(FLAGS.num_networks): 228 | predictions = tf.argmax(net_pred["{0}".format(i)], axis=1) 229 | precision["{0}".format(i)] = tf.reduce_mean(tf.to_float(tf.equal(predictions, truth))) 230 | 231 | # Add a summary to track the training precision. 232 | summaries.append(tf.summary.scalar('precision_%d' % i, precision["{0}".format(i)])) 233 | 234 | net_update_ops["{0}".format(i)] = \ 235 | tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=('%sdmlnet_%d' % (scope, i))) 236 | 237 | net_var_list["{0}".format(i)] = \ 238 | [var for var in var_list if 'dmlnet_%d' % i in var.name] 239 | 240 | net_grads["{0}".format(i)] = net_opt["{0}".format(i)].compute_gradients( 241 | net_loss["{0}".format(i)], var_list=net_var_list["{0}".format(i)]) 242 | 243 | net_tower_grads["{0}".format(i)].append(net_grads["{0}".format(i)]) 244 | 245 | # We must calculate the mean of each gradient. Note that this is the 246 | # synchronization point across all towers. 247 | for i in range(FLAGS.num_networks): 248 | net_grads["{0}".format(i)] = _average_gradients(net_tower_grads["{0}".format(i)], 249 | catname=('dmlnet_%d_cat' % i)) 250 | 251 | # Add histograms for histogram and trainable variables. 252 | for i in range(FLAGS.num_networks): 253 | for grad, var in net_grads["{0}".format(i)]: 254 | if grad is not None: 255 | summaries.append(tf.summary.histogram(var.op.name + '/gradients', grad)) 256 | 257 | for var in tf.trainable_variables(): 258 | summaries.append(tf.summary.histogram(var.op.name, var)) 259 | 260 | ################################# 261 | # Configure the moving averages # 262 | ################################# 263 | 264 | if FLAGS.moving_average_decay: 265 | moving_average_variables = {} 266 | all_moving_average_variables = slim.get_model_variables() 267 | variable_averages = tf.train.ExponentialMovingAverage( 268 | FLAGS.moving_average_decay, global_step) 269 | for i in range(FLAGS.num_networks): 270 | moving_average_variables["{0}".format(i)] = \ 271 | [var for var in all_moving_average_variables if 'dmlnet_%d' % i in var.name] 272 | net_update_ops["{0}".format(i)].append( 273 | variable_averages.apply(moving_average_variables["{0}".format(i)])) 274 | 275 | # Apply the gradients to adjust the shared variables. 276 | net_grad_updates, net_train_op = {}, {} 277 | for i in range(FLAGS.num_networks): 278 | net_grad_updates["{0}".format(i)] = net_opt["{0}".format(i)].apply_gradients( 279 | net_grads["{0}".format(i)], global_step=global_step) 280 | net_update_ops["{0}".format(i)].append(net_grad_updates["{0}".format(i)]) 281 | # Group all updates to into a single train op. 282 | net_train_op["{0}".format(i)] = tf.group(*net_update_ops["{0}".format(i)]) 283 | 284 | # Create a saver. 285 | saver = tf.train.Saver(tf.global_variables()) 286 | 287 | # Build the summary operation from the last tower summaries. 288 | summary_op = tf.summary.merge(summaries) 289 | 290 | # Build an initialization operation to run below. 291 | init = tf.global_variables_initializer() 292 | 293 | # Start running operations on the Graph. allow_soft_placement must be set to 294 | # True to build towers on GPU, as some of the ops do not have GPU 295 | # implementations. 296 | sess = tf.Session(config=tf.ConfigProto( 297 | allow_soft_placement=True, 298 | log_device_placement=FLAGS.log_device_placement)) 299 | sess.run(init) 300 | 301 | # Start the queue runners. 
302 | tf.train.start_queue_runners(sess=sess) 303 | 304 | summary_writer = tf.summary.FileWriter( 305 | os.path.join(FLAGS.log_dir), 306 | graph=sess.graph) 307 | 308 | net_loss_value, precision_value = {}, {} 309 | 310 | for step in xrange(FLAGS.max_number_of_steps): 311 | 312 | for i in range(FLAGS.num_networks): 313 | _, net_loss_value["{0}".format(i)], precision_value["{0}".format(i)] = \ 314 | sess.run([net_train_op["{0}".format(i)], net_loss["{0}".format(i)], 315 | precision["{0}".format(i)]]) 316 | assert not np.isnan(net_loss_value["{0}".format(i)]), 'Model diverged with loss = NaN' 317 | 318 | if step % 10 == 0: 319 | format_str = '%s: step %d, net0_loss = %.2f, net0_acc = %.4f' 320 | print(format_str % (FLAGS.dataset_name, step, net_loss_value["{0}".format(0)], 321 | precision_value["{0}".format(0)])) 322 | 323 | if step % 100 == 0: 324 | summary_str = sess.run(summary_op) 325 | summary_writer.add_summary(summary_str, step) 326 | 327 | # Save the model checkpoint periodically. 328 | if step % FLAGS.ckpt_steps == 0 or (step + 1) == FLAGS.max_number_of_steps: 329 | checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'model.ckpt') 330 | saver.save(sess, checkpoint_path, global_step=step) 331 | 332 | -------------------------------------------------------------------------------- /nets/inception_v1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the definition for inception v1 classification network.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import inception_utils 24 | 25 | slim = tf.contrib.slim 26 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 27 | 28 | 29 | def inception_v1_base(inputs, 30 | final_endpoint='Mixed_5c', 31 | scope='InceptionV1'): 32 | """Defines the Inception V1 base architecture. 33 | 34 | This architecture is defined in: 35 | Going deeper with convolutions 36 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, 37 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. 38 | http://arxiv.org/pdf/1409.4842v1.pdf. 39 | 40 | Args: 41 | inputs: a tensor of size [batch_size, height, width, channels]. 42 | final_endpoint: specifies the endpoint to construct the network up to. It 43 | can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 44 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 45 | 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 46 | 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c'] 47 | scope: Optional variable_scope. 48 | 49 | Returns: 50 | A dictionary from components of the network to the corresponding activation. 
51 | 52 | Raises: 53 | ValueError: if final_endpoint is not set to one of the predefined values. 54 | """ 55 | end_points = {} 56 | with tf.variable_scope(scope, 'InceptionV1', [inputs]): 57 | with slim.arg_scope( 58 | [slim.conv2d, slim.fully_connected], 59 | weights_initializer=trunc_normal(0.01)): 60 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], 61 | stride=1, padding='SAME'): 62 | end_point = 'Conv2d_1a_7x7' 63 | net = slim.conv2d(inputs, 64, [7, 7], stride=2, scope=end_point) 64 | end_points[end_point] = net 65 | if final_endpoint == end_point: return net, end_points 66 | end_point = 'MaxPool_2a_3x3' 67 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point) 68 | end_points[end_point] = net 69 | if final_endpoint == end_point: return net, end_points 70 | end_point = 'Conv2d_2b_1x1' 71 | net = slim.conv2d(net, 64, [1, 1], scope=end_point) 72 | end_points[end_point] = net 73 | if final_endpoint == end_point: return net, end_points 74 | end_point = 'Conv2d_2c_3x3' 75 | net = slim.conv2d(net, 192, [3, 3], scope=end_point) 76 | end_points[end_point] = net 77 | if final_endpoint == end_point: return net, end_points 78 | end_point = 'MaxPool_3a_3x3' 79 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point) 80 | end_points[end_point] = net 81 | if final_endpoint == end_point: return net, end_points 82 | 83 | end_point = 'Mixed_3b' 84 | with tf.variable_scope(end_point): 85 | with tf.variable_scope('Branch_0'): 86 | branch_0 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1') 87 | with tf.variable_scope('Branch_1'): 88 | branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1') 89 | branch_1 = slim.conv2d(branch_1, 128, [3, 3], scope='Conv2d_0b_3x3') 90 | with tf.variable_scope('Branch_2'): 91 | branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1') 92 | branch_2 = slim.conv2d(branch_2, 32, [3, 3], scope='Conv2d_0b_3x3') 93 | with tf.variable_scope('Branch_3'): 94 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 95 | branch_3 = slim.conv2d(branch_3, 32, [1, 1], scope='Conv2d_0b_1x1') 96 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 97 | end_points[end_point] = net 98 | if final_endpoint == end_point: return net, end_points 99 | 100 | end_point = 'Mixed_3c' 101 | with tf.variable_scope(end_point): 102 | with tf.variable_scope('Branch_0'): 103 | branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1') 104 | with tf.variable_scope('Branch_1'): 105 | branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1') 106 | branch_1 = slim.conv2d(branch_1, 192, [3, 3], scope='Conv2d_0b_3x3') 107 | with tf.variable_scope('Branch_2'): 108 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1') 109 | branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0b_3x3') 110 | with tf.variable_scope('Branch_3'): 111 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 112 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 113 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 114 | end_points[end_point] = net 115 | if final_endpoint == end_point: return net, end_points 116 | 117 | end_point = 'MaxPool_4a_3x3' 118 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point) 119 | end_points[end_point] = net 120 | if final_endpoint == end_point: return net, end_points 121 | 122 | end_point = 'Mixed_4b' 123 | with tf.variable_scope(end_point): 124 | with tf.variable_scope('Branch_0'): 125 | branch_0 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1') 
126 | with tf.variable_scope('Branch_1'): 127 | branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1') 128 | branch_1 = slim.conv2d(branch_1, 208, [3, 3], scope='Conv2d_0b_3x3') 129 | with tf.variable_scope('Branch_2'): 130 | branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1') 131 | branch_2 = slim.conv2d(branch_2, 48, [3, 3], scope='Conv2d_0b_3x3') 132 | with tf.variable_scope('Branch_3'): 133 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 134 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 135 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 136 | end_points[end_point] = net 137 | if final_endpoint == end_point: return net, end_points 138 | 139 | end_point = 'Mixed_4c' 140 | with tf.variable_scope(end_point): 141 | with tf.variable_scope('Branch_0'): 142 | branch_0 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1') 143 | with tf.variable_scope('Branch_1'): 144 | branch_1 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1') 145 | branch_1 = slim.conv2d(branch_1, 224, [3, 3], scope='Conv2d_0b_3x3') 146 | with tf.variable_scope('Branch_2'): 147 | branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1') 148 | branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3') 149 | with tf.variable_scope('Branch_3'): 150 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 151 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 152 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 153 | end_points[end_point] = net 154 | if final_endpoint == end_point: return net, end_points 155 | 156 | end_point = 'Mixed_4d' 157 | with tf.variable_scope(end_point): 158 | with tf.variable_scope('Branch_0'): 159 | branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1') 160 | with tf.variable_scope('Branch_1'): 161 | branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1') 162 | branch_1 = slim.conv2d(branch_1, 256, [3, 3], scope='Conv2d_0b_3x3') 163 | with tf.variable_scope('Branch_2'): 164 | branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1') 165 | branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3') 166 | with tf.variable_scope('Branch_3'): 167 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 168 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 169 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 170 | end_points[end_point] = net 171 | if final_endpoint == end_point: return net, end_points 172 | 173 | end_point = 'Mixed_4e' 174 | with tf.variable_scope(end_point): 175 | with tf.variable_scope('Branch_0'): 176 | branch_0 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1') 177 | with tf.variable_scope('Branch_1'): 178 | branch_1 = slim.conv2d(net, 144, [1, 1], scope='Conv2d_0a_1x1') 179 | branch_1 = slim.conv2d(branch_1, 288, [3, 3], scope='Conv2d_0b_3x3') 180 | with tf.variable_scope('Branch_2'): 181 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1') 182 | branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3') 183 | with tf.variable_scope('Branch_3'): 184 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 185 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1') 186 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 187 | end_points[end_point] = net 188 | if final_endpoint == end_point: return net, end_points 189 | 190 | end_point = 'Mixed_4f' 191 | with 
tf.variable_scope(end_point): 192 | with tf.variable_scope('Branch_0'): 193 | branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1') 194 | with tf.variable_scope('Branch_1'): 195 | branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1') 196 | branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3') 197 | with tf.variable_scope('Branch_2'): 198 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1') 199 | branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3') 200 | with tf.variable_scope('Branch_3'): 201 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 202 | branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1') 203 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 204 | end_points[end_point] = net 205 | if final_endpoint == end_point: return net, end_points 206 | 207 | end_point = 'MaxPool_5a_2x2' 208 | net = slim.max_pool2d(net, [2, 2], stride=2, scope=end_point) 209 | end_points[end_point] = net 210 | if final_endpoint == end_point: return net, end_points 211 | 212 | end_point = 'Mixed_5b' 213 | with tf.variable_scope(end_point): 214 | with tf.variable_scope('Branch_0'): 215 | branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1') 216 | with tf.variable_scope('Branch_1'): 217 | branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1') 218 | branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3') 219 | with tf.variable_scope('Branch_2'): 220 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1') 221 | branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0a_3x3') 222 | with tf.variable_scope('Branch_3'): 223 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 224 | branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1') 225 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 226 | end_points[end_point] = net 227 | if final_endpoint == end_point: return net, end_points 228 | 229 | end_point = 'Mixed_5c' 230 | with tf.variable_scope(end_point): 231 | with tf.variable_scope('Branch_0'): 232 | branch_0 = slim.conv2d(net, 384, [1, 1], scope='Conv2d_0a_1x1') 233 | with tf.variable_scope('Branch_1'): 234 | branch_1 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1') 235 | branch_1 = slim.conv2d(branch_1, 384, [3, 3], scope='Conv2d_0b_3x3') 236 | with tf.variable_scope('Branch_2'): 237 | branch_2 = slim.conv2d(net, 48, [1, 1], scope='Conv2d_0a_1x1') 238 | branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3') 239 | with tf.variable_scope('Branch_3'): 240 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3') 241 | branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1') 242 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 243 | end_points[end_point] = net 244 | if final_endpoint == end_point: return net, end_points 245 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 246 | 247 | 248 | def inception_v1(inputs, 249 | num_classes=1000, 250 | is_training=True, 251 | dropout_keep_prob=0.8, 252 | prediction_fn=slim.softmax, 253 | spatial_squeeze=True, 254 | reuse=None, 255 | scope='InceptionV1'): 256 | """Defines the Inception V1 architecture. 257 | 258 | This architecture is defined in: 259 | 260 | Going deeper with convolutions 261 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, 262 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. 
263 | http://arxiv.org/pdf/1409.4842v1.pdf. 264 | 265 | The default image size used to train this network is 224x224. 266 | 267 | Args: 268 | inputs: a tensor of size [batch_size, height, width, channels]. 269 | num_classes: number of predicted classes. 270 | is_training: whether is training or not. 271 | dropout_keep_prob: the percentage of activation values that are retained. 272 | prediction_fn: a function to get predictions out of logits. 273 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is 274 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 275 | reuse: whether or not the network and its variables should be reused. To be 276 | able to reuse 'scope' must be given. 277 | scope: Optional variable_scope. 278 | 279 | Returns: 280 | logits: the pre-softmax activations, a tensor of size 281 | [batch_size, num_classes] 282 | end_points: a dictionary from components of the network to the corresponding 283 | activation. 284 | """ 285 | # Final pooling and prediction 286 | with tf.variable_scope(scope, 'InceptionV1', [inputs, num_classes], 287 | reuse=reuse) as scope: 288 | with slim.arg_scope([slim.batch_norm, slim.dropout], 289 | is_training=is_training): 290 | net, end_points = inception_v1_base(inputs, scope=scope) 291 | with tf.variable_scope('Logits'): 292 | net = slim.avg_pool2d(net, [7, 7], stride=1, scope='AvgPool_0a_7x7') 293 | net = slim.dropout(net, dropout_keep_prob, scope='Dropout_0b') 294 | end_points['PreLogits'] = net 295 | if not num_classes: 296 | return net, end_points 297 | 298 | logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 299 | normalizer_fn=None, scope='Conv2d_0c_1x1') 300 | if spatial_squeeze: 301 | logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze') 302 | 303 | end_points['Logits'] = logits 304 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 305 | return logits, end_points 306 | 307 | 308 | inception_v1.default_image_size = 224 309 | 310 | inception_v1_arg_scope = inception_utils.inception_arg_scope 311 | -------------------------------------------------------------------------------- /nets/mobilenet_v1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | """MobileNet v1. 16 | 17 | MobileNet is a general architecture and can be used for multiple use cases. 18 | Depending on the use case, it can use different input layer size and different 19 | head (for example: embeddings, localization and classification). 20 | 21 | As described in https://arxiv.org/abs/1704.04861. 22 | 23 | MobileNets: Efficient Convolutional Neural Networks for 24 | Mobile Vision Applications 25 | Andrew G. 
Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang, 26 | Tobias Weyand, Marco Andreetto, Hartwig Adam 27 | 28 | 100% Mobilenet V1 (base) with input size 224x224: 29 | 30 | See mobilenet_v1() 31 | 32 | Layer params macs 33 | -------------------------------------------------------------------------------- 34 | MobilenetV1/Conv2d_0/Conv2D: 864 10,838,016 35 | MobilenetV1/Conv2d_1_depthwise/depthwise: 288 3,612,672 36 | MobilenetV1/Conv2d_1_pointwise/Conv2D: 2,048 25,690,112 37 | MobilenetV1/Conv2d_2_depthwise/depthwise: 576 1,806,336 38 | MobilenetV1/Conv2d_2_pointwise/Conv2D: 8,192 25,690,112 39 | MobilenetV1/Conv2d_3_depthwise/depthwise: 1,152 3,612,672 40 | MobilenetV1/Conv2d_3_pointwise/Conv2D: 16,384 51,380,224 41 | MobilenetV1/Conv2d_4_depthwise/depthwise: 1,152 903,168 42 | MobilenetV1/Conv2d_4_pointwise/Conv2D: 32,768 25,690,112 43 | MobilenetV1/Conv2d_5_depthwise/depthwise: 2,304 1,806,336 44 | MobilenetV1/Conv2d_5_pointwise/Conv2D: 65,536 51,380,224 45 | MobilenetV1/Conv2d_6_depthwise/depthwise: 2,304 451,584 46 | MobilenetV1/Conv2d_6_pointwise/Conv2D: 131,072 25,690,112 47 | MobilenetV1/Conv2d_7_depthwise/depthwise: 4,608 903,168 48 | MobilenetV1/Conv2d_7_pointwise/Conv2D: 262,144 51,380,224 49 | MobilenetV1/Conv2d_8_depthwise/depthwise: 4,608 903,168 50 | MobilenetV1/Conv2d_8_pointwise/Conv2D: 262,144 51,380,224 51 | MobilenetV1/Conv2d_9_depthwise/depthwise: 4,608 903,168 52 | MobilenetV1/Conv2d_9_pointwise/Conv2D: 262,144 51,380,224 53 | MobilenetV1/Conv2d_10_depthwise/depthwise: 4,608 903,168 54 | MobilenetV1/Conv2d_10_pointwise/Conv2D: 262,144 51,380,224 55 | MobilenetV1/Conv2d_11_depthwise/depthwise: 4,608 903,168 56 | MobilenetV1/Conv2d_11_pointwise/Conv2D: 262,144 51,380,224 57 | MobilenetV1/Conv2d_12_depthwise/depthwise: 4,608 225,792 58 | MobilenetV1/Conv2d_12_pointwise/Conv2D: 524,288 25,690,112 59 | MobilenetV1/Conv2d_13_depthwise/depthwise: 9,216 451,584 60 | MobilenetV1/Conv2d_13_pointwise/Conv2D: 1,048,576 51,380,224 61 | -------------------------------------------------------------------------------- 62 | Total: 3,185,088 567,716,352 63 | 64 | 65 | 75% Mobilenet V1 (base) with input size 128x128: 66 | 67 | See mobilenet_v1_075() 68 | 69 | Layer params macs 70 | -------------------------------------------------------------------------------- 71 | MobilenetV1/Conv2d_0/Conv2D: 648 2,654,208 72 | MobilenetV1/Conv2d_1_depthwise/depthwise: 216 884,736 73 | MobilenetV1/Conv2d_1_pointwise/Conv2D: 1,152 4,718,592 74 | MobilenetV1/Conv2d_2_depthwise/depthwise: 432 442,368 75 | MobilenetV1/Conv2d_2_pointwise/Conv2D: 4,608 4,718,592 76 | MobilenetV1/Conv2d_3_depthwise/depthwise: 864 884,736 77 | MobilenetV1/Conv2d_3_pointwise/Conv2D: 9,216 9,437,184 78 | MobilenetV1/Conv2d_4_depthwise/depthwise: 864 221,184 79 | MobilenetV1/Conv2d_4_pointwise/Conv2D: 18,432 4,718,592 80 | MobilenetV1/Conv2d_5_depthwise/depthwise: 1,728 442,368 81 | MobilenetV1/Conv2d_5_pointwise/Conv2D: 36,864 9,437,184 82 | MobilenetV1/Conv2d_6_depthwise/depthwise: 1,728 110,592 83 | MobilenetV1/Conv2d_6_pointwise/Conv2D: 73,728 4,718,592 84 | MobilenetV1/Conv2d_7_depthwise/depthwise: 3,456 221,184 85 | MobilenetV1/Conv2d_7_pointwise/Conv2D: 147,456 9,437,184 86 | MobilenetV1/Conv2d_8_depthwise/depthwise: 3,456 221,184 87 | MobilenetV1/Conv2d_8_pointwise/Conv2D: 147,456 9,437,184 88 | MobilenetV1/Conv2d_9_depthwise/depthwise: 3,456 221,184 89 | MobilenetV1/Conv2d_9_pointwise/Conv2D: 147,456 9,437,184 90 | MobilenetV1/Conv2d_10_depthwise/depthwise: 3,456 221,184 91 | 
MobilenetV1/Conv2d_10_pointwise/Conv2D: 147,456 9,437,184 92 | MobilenetV1/Conv2d_11_depthwise/depthwise: 3,456 221,184 93 | MobilenetV1/Conv2d_11_pointwise/Conv2D: 147,456 9,437,184 94 | MobilenetV1/Conv2d_12_depthwise/depthwise: 3,456 55,296 95 | MobilenetV1/Conv2d_12_pointwise/Conv2D: 294,912 4,718,592 96 | MobilenetV1/Conv2d_13_depthwise/depthwise: 6,912 110,592 97 | MobilenetV1/Conv2d_13_pointwise/Conv2D: 589,824 9,437,184 98 | -------------------------------------------------------------------------------- 99 | Total: 1,800,144 106,002,432 100 | 101 | """ 102 | 103 | # Tensorflow mandates these. 104 | from __future__ import absolute_import 105 | from __future__ import division 106 | from __future__ import print_function 107 | 108 | from collections import namedtuple 109 | import functools 110 | 111 | import tensorflow as tf 112 | 113 | slim = tf.contrib.slim 114 | 115 | # Conv and DepthSepConv namedtuple define layers of the MobileNet architecture 116 | # Conv defines 3x3 convolution layers 117 | # DepthSepConv defines 3x3 depthwise convolution followed by 1x1 convolution. 118 | # stride is the stride of the convolution 119 | # depth is the number of channels or filters in a layer 120 | Conv = namedtuple('Conv', ['kernel', 'stride', 'depth']) 121 | DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth']) 122 | 123 | # _CONV_DEFS specifies the MobileNet body 124 | _CONV_DEFS = [ 125 | Conv(kernel=[3, 3], stride=2, depth=32), 126 | DepthSepConv(kernel=[3, 3], stride=1, depth=64), 127 | DepthSepConv(kernel=[3, 3], stride=2, depth=128), 128 | DepthSepConv(kernel=[3, 3], stride=1, depth=128), 129 | DepthSepConv(kernel=[3, 3], stride=2, depth=256), 130 | DepthSepConv(kernel=[3, 3], stride=1, depth=256), 131 | DepthSepConv(kernel=[3, 3], stride=2, depth=512), 132 | DepthSepConv(kernel=[3, 3], stride=1, depth=512), 133 | DepthSepConv(kernel=[3, 3], stride=1, depth=512), 134 | DepthSepConv(kernel=[3, 3], stride=1, depth=512), 135 | DepthSepConv(kernel=[3, 3], stride=1, depth=512), 136 | DepthSepConv(kernel=[3, 3], stride=1, depth=512), 137 | DepthSepConv(kernel=[3, 3], stride=2, depth=1024), 138 | DepthSepConv(kernel=[3, 3], stride=1, depth=1024) 139 | ] 140 | 141 | 142 | def mobilenet_v1_base(inputs, 143 | final_endpoint='Conv2d_13_pointwise', 144 | min_depth=8, 145 | depth_multiplier=1.0, 146 | conv_defs=None, 147 | output_stride=None, 148 | scope=None): 149 | """Mobilenet v1. 150 | 151 | Constructs a Mobilenet v1 network from inputs to the given final endpoint. 152 | 153 | Args: 154 | inputs: a tensor of shape [batch_size, height, width, channels]. 155 | final_endpoint: specifies the endpoint to construct the network up to. It 156 | can be one of ['Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise', 157 | 'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5'_pointwise, 158 | 'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise', 159 | 'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise', 160 | 'Conv2d_12_pointwise', 'Conv2d_13_pointwise']. 161 | min_depth: Minimum depth value (number of channels) for all convolution ops. 162 | Enforced when depth_multiplier < 1, and not an active constraint when 163 | depth_multiplier >= 1. 164 | depth_multiplier: Float multiplier for the depth (number of channels) 165 | for all convolution ops. The value must be greater than zero. Typical 166 | usage will be to set this value in (0, 1) to reduce the number of 167 | parameters or computation cost of the model. 
168 | conv_defs: A list of ConvDef namedtuples specifying the net architecture. 169 | output_stride: An integer that specifies the requested ratio of input to 170 | output spatial resolution. If not None, then we invoke atrous convolution 171 | if necessary to prevent the network from reducing the spatial resolution 172 | of the activation maps. Allowed values are 8 (accurate fully convolutional 173 | mode), 16 (fast fully convolutional mode), 32 (classification mode). 174 | scope: Optional variable_scope. 175 | 176 | Returns: 177 | tensor_out: output tensor corresponding to the final_endpoint. 178 | end_points: a set of activations for external use, for example summaries or 179 | losses. 180 | 181 | Raises: 182 | ValueError: if final_endpoint is not set to one of the predefined values, 183 | or depth_multiplier <= 0, or the target output_stride is not 184 | allowed. 185 | """ 186 | depth = lambda d: max(int(d * depth_multiplier), min_depth) 187 | end_points = {} 188 | 189 | # Used to find thinned depths for each layer. 190 | if depth_multiplier <= 0: 191 | raise ValueError('depth_multiplier is not greater than zero.') 192 | 193 | if conv_defs is None: 194 | conv_defs = _CONV_DEFS 195 | 196 | if output_stride is not None and output_stride not in [8, 16, 32]: 197 | raise ValueError('Only allowed output_stride values are 8, 16, 32.') 198 | 199 | with tf.variable_scope(scope, 'MobilenetV1', [inputs]): 200 | with slim.arg_scope([slim.conv2d, slim.separable_conv2d], padding='SAME'): 201 | # The current_stride variable keeps track of the output stride of the 202 | # activations, i.e., the running product of convolution strides up to the 203 | # current network layer. This allows us to invoke atrous convolution 204 | # whenever applying the next convolution would result in the activations 205 | # having output stride larger than the target output_stride. 206 | current_stride = 1 207 | 208 | # The atrous convolution rate parameter. 209 | rate = 1 210 | 211 | net = inputs 212 | for i, conv_def in enumerate(conv_defs): 213 | end_point_base = 'Conv2d_%d' % i 214 | 215 | if output_stride is not None and current_stride == output_stride: 216 | # If we have reached the target output_stride, then we need to employ 217 | # atrous convolution with stride=1 and multiply the atrous rate by the 218 | # current unit's stride for use in subsequent layers. 
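# (Example: with output_stride=16 and the default _CONV_DEFS, current_stride
#  reaches 16 at Conv2d_6; the stride-2 definition at Conv2d_12 is then applied
#  with layer_stride=1 while `rate` grows to 2, and Conv2d_13's depthwise conv
#  runs with rate=2, so the feature map stays at 1/16 of the input resolution
#  while the receptive field matches that of the stride-32 network.)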
219 | layer_stride = 1 220 | layer_rate = rate 221 | rate *= conv_def.stride 222 | else: 223 | layer_stride = conv_def.stride 224 | layer_rate = 1 225 | current_stride *= conv_def.stride 226 | 227 | if isinstance(conv_def, Conv): 228 | end_point = end_point_base 229 | net = slim.conv2d(net, depth(conv_def.depth), conv_def.kernel, 230 | stride=conv_def.stride, 231 | normalizer_fn=slim.batch_norm, 232 | scope=end_point) 233 | end_points[end_point] = net 234 | if end_point == final_endpoint: 235 | return net, end_points 236 | 237 | elif isinstance(conv_def, DepthSepConv): 238 | end_point = end_point_base + '_depthwise' 239 | 240 | # By passing filters=None 241 | # separable_conv2d produces only a depthwise convolution layer 242 | net = slim.separable_conv2d(net, None, conv_def.kernel, 243 | depth_multiplier=1, 244 | stride=layer_stride, 245 | rate=layer_rate, 246 | normalizer_fn=slim.batch_norm, 247 | scope=end_point) 248 | 249 | end_points[end_point] = net 250 | if end_point == final_endpoint: 251 | return net, end_points 252 | 253 | end_point = end_point_base + '_pointwise' 254 | 255 | net = slim.conv2d(net, depth(conv_def.depth), [1, 1], 256 | stride=1, 257 | normalizer_fn=slim.batch_norm, 258 | scope=end_point) 259 | 260 | end_points[end_point] = net 261 | if end_point == final_endpoint: 262 | return net, end_points 263 | else: 264 | raise ValueError('Unknown convolution type %s for layer %d' 265 | % (conv_def.ltype, i)) 266 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 267 | 268 | 269 | def mobilenet_v1(inputs, 270 | num_classes=1000, 271 | dropout_keep_prob=0.999, 272 | is_training=True, 273 | min_depth=8, 274 | depth_multiplier=1.0, 275 | conv_defs=None, 276 | prediction_fn=tf.contrib.layers.softmax, 277 | spatial_squeeze=True, 278 | reuse=None, 279 | scope='MobilenetV1', 280 | global_pool=False): 281 | """Mobilenet v1 model for classification. 282 | 283 | Args: 284 | inputs: a tensor of shape [batch_size, height, width, channels]. 285 | num_classes: number of predicted classes. If 0 or None, the logits layer 286 | is omitted and the input features to the logits layer (before dropout) 287 | are returned instead. 288 | dropout_keep_prob: the percentage of activation values that are retained. 289 | is_training: whether is training or not. 290 | min_depth: Minimum depth value (number of channels) for all convolution ops. 291 | Enforced when depth_multiplier < 1, and not an active constraint when 292 | depth_multiplier >= 1. 293 | depth_multiplier: Float multiplier for the depth (number of channels) 294 | for all convolution ops. The value must be greater than zero. Typical 295 | usage will be to set this value in (0, 1) to reduce the number of 296 | parameters or computation cost of the model. 297 | conv_defs: A list of ConvDef namedtuples specifying the net architecture. 298 | prediction_fn: a function to get predictions out of logits. 299 | spatial_squeeze: if True, logits is of shape is [B, C], if false logits is 300 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 301 | reuse: whether or not the network and its variables should be reused. To be 302 | able to reuse 'scope' must be given. 303 | scope: Optional variable_scope. 304 | global_pool: Optional boolean flag to control the avgpooling before the 305 | logits layer. If false or unset, pooling is done with a fixed window 306 | that reduces default-sized inputs to 1x1, while larger inputs lead to 307 | larger outputs. If true, any input size is pooled down to 1x1. 
308 | 309 | Returns: 310 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 311 | is a non-zero integer, or the non-dropped-out input to the logits layer 312 | if num_classes is 0 or None. 313 | end_points: a dictionary from components of the network to the corresponding 314 | activation. 315 | 316 | Raises: 317 | ValueError: Input rank is invalid. 318 | """ 319 | input_shape = inputs.get_shape().as_list() 320 | if len(input_shape) != 4: 321 | raise ValueError('Invalid input tensor rank, expected 4, was: %d' % 322 | len(input_shape)) 323 | 324 | with tf.variable_scope(scope, 'MobilenetV1', [inputs], reuse=reuse) as scope: 325 | with slim.arg_scope([slim.batch_norm, slim.dropout], 326 | is_training=is_training): 327 | net, end_points = mobilenet_v1_base(inputs, scope=scope, 328 | min_depth=min_depth, 329 | depth_multiplier=depth_multiplier, 330 | conv_defs=conv_defs) 331 | end_points['FeatureMap'] = net 332 | with tf.variable_scope('Logits'): 333 | if global_pool: 334 | # Global average pooling. 335 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 336 | end_points['global_pool'] = net 337 | else: 338 | # Pooling with a fixed kernel size. 339 | kernel_size = _reduced_kernel_size_for_small_input(net, [7, 7]) 340 | net = slim.avg_pool2d(net, kernel_size, padding='VALID', 341 | scope='AvgPool_1a') 342 | end_points['AvgPool_1a'] = net 343 | # 1 x 1 x 1024 344 | net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b') 345 | end_points['PreLogits'] = net 346 | if not num_classes: 347 | return net, end_points 348 | logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 349 | normalizer_fn=None, scope='Conv2d_1c_1x1') 350 | if spatial_squeeze: 351 | logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze') 352 | end_points['Logits'] = logits 353 | if prediction_fn: 354 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 355 | return logits, end_points 356 | 357 | 358 | mobilenet_v1.default_image_size = 224 359 | 360 | 361 | def wrapped_partial(func, *args, **kwargs): 362 | partial_func = functools.partial(func, *args, **kwargs) 363 | functools.update_wrapper(partial_func, func) 364 | return partial_func 365 | 366 | 367 | mobilenet_v1_075 = wrapped_partial(mobilenet_v1, depth_multiplier=0.75) 368 | mobilenet_v1_050 = wrapped_partial(mobilenet_v1, depth_multiplier=0.50) 369 | mobilenet_v1_025 = wrapped_partial(mobilenet_v1, depth_multiplier=0.25) 370 | 371 | 372 | def _reduced_kernel_size_for_small_input(input_tensor, kernel_size): 373 | """Define kernel size which is automatically reduced for small input. 374 | 375 | If the shape of the input images is unknown at graph construction time this 376 | function assumes that the input images are large enough. 377 | 378 | Args: 379 | input_tensor: input tensor of size [batch_size, height, width, channels]. 380 | kernel_size: desired kernel size of length 2: [kernel_height, kernel_width] 381 | 382 | Returns: 383 | a tensor with the kernel size. 384 | """ 385 | shape = input_tensor.get_shape().as_list() 386 | if shape[1] is None or shape[2] is None: 387 | kernel_size_out = kernel_size 388 | else: 389 | kernel_size_out = [min(shape[1], kernel_size[0]), 390 | min(shape[2], kernel_size[1])] 391 | return kernel_size_out 392 | 393 | 394 | def mobilenet_v1_arg_scope(is_training=True, 395 | weight_decay=0.00004, 396 | stddev=0.09, 397 | regularize_depthwise=False): 398 | """Defines the default MobilenetV1 arg scope. 
399 | 400 | Args: 401 | is_training: Whether or not we're training the model. 402 | weight_decay: The weight decay to use for regularizing the model. 403 | stddev: The standard deviation of the trunctated normal weight initializer. 404 | regularize_depthwise: Whether or not apply regularization on depthwise. 405 | 406 | Returns: 407 | An `arg_scope` to use for the mobilenet v1 model. 408 | """ 409 | batch_norm_params = { 410 | 'is_training': is_training, 411 | 'center': True, 412 | 'scale': True, 413 | 'decay': 0.9997, 414 | 'epsilon': 0.001, 415 | } 416 | 417 | # Set weight_decay for weights in Conv and DepthSepConv layers. 418 | weights_init = tf.truncated_normal_initializer(stddev=stddev) 419 | regularizer = tf.contrib.layers.l2_regularizer(weight_decay) 420 | if regularize_depthwise: 421 | depthwise_regularizer = regularizer 422 | else: 423 | depthwise_regularizer = None 424 | with slim.arg_scope([slim.conv2d, slim.separable_conv2d], 425 | weights_initializer=weights_init, 426 | activation_fn=tf.nn.relu6, normalizer_fn=slim.batch_norm): 427 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 428 | with slim.arg_scope([slim.conv2d], weights_regularizer=regularizer): 429 | with slim.arg_scope([slim.separable_conv2d], 430 | weights_regularizer=depthwise_regularizer) as sc: 431 | return sc 432 | -------------------------------------------------------------------------------- /deployment/model_deploy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Deploy Slim models across multiple clones and replicas. 16 | 17 | # TODO(sguada) docstring paragraph by (a) motivating the need for the file and 18 | # (b) defining clones. 19 | 20 | # TODO(sguada) describe the high-level components of model deployment. 21 | # E.g. "each model deployment is composed of several parts: a DeploymentConfig, 22 | # which captures A, B and C, an input_fn which loads data.. etc 23 | 24 | To easily train a model on multiple GPUs or across multiple machines this 25 | module provides a set of helper functions: `create_clones`, 26 | `optimize_clones` and `deploy`. 27 | 28 | Usage: 29 | 30 | g = tf.Graph() 31 | 32 | # Set up DeploymentConfig 33 | config = model_deploy.DeploymentConfig(num_clones=2, clone_on_cpu=True) 34 | 35 | # Create the global step on the device storing the variables. 36 | with tf.device(config.variables_device()): 37 | global_step = slim.create_global_step() 38 | 39 | # Define the inputs 40 | with tf.device(config.inputs_device()): 41 | images, labels = LoadData(...) 42 | inputs_queue = slim.data.prefetch_queue((images, labels)) 43 | 44 | # Define the optimizer. 
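  # (Any tf.train optimizer works here; it is created under
  #  config.optimizer_device(), which in this configuration resolves to a CPU
  #  device, so cross-clone gradient aggregation also runs on the CPU.)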
45 | with tf.device(config.optimizer_device()): 46 | optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum) 47 | 48 | # Define the model including the loss. 49 | def model_fn(inputs_queue): 50 | images, labels = inputs_queue.dequeue() 51 | predictions = CreateNetwork(images) 52 | slim.losses.log_loss(predictions, labels) 53 | 54 | model_dp = model_deploy.deploy(config, model_fn, [inputs_queue], 55 | optimizer=optimizer) 56 | 57 | # Run training. 58 | slim.learning.train(model_dp.train_op, my_log_dir, 59 | summary_op=model_dp.summary_op) 60 | 61 | The Clone namedtuple holds together the values associated with each call to 62 | model_fn: 63 | * outputs: The return values of the calls to `model_fn()`. 64 | * scope: The scope used to create the clone. 65 | * device: The device used to create the clone. 66 | 67 | DeployedModel namedtuple, holds together the values needed to train multiple 68 | clones: 69 | * train_op: An operation that run the optimizer training op and include 70 | all the update ops created by `model_fn`. Present only if an optimizer 71 | was specified. 72 | * summary_op: An operation that run the summaries created by `model_fn` 73 | and process_gradients. 74 | * total_loss: A `Tensor` that contains the sum of all losses created by 75 | `model_fn` plus the regularization losses. 76 | * clones: List of `Clone` tuples returned by `create_clones()`. 77 | 78 | DeploymentConfig parameters: 79 | * num_clones: Number of model clones to deploy in each replica. 80 | * clone_on_cpu: True if clones should be placed on CPU. 81 | * replica_id: Integer. Index of the replica for which the model is 82 | deployed. Usually 0 for the chief replica. 83 | * num_replicas: Number of replicas to use. 84 | * num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas. 85 | * worker_job_name: A name for the worker job. 86 | * ps_job_name: A name for the parameter server job. 87 | 88 | TODO(sguada): 89 | - describe side effect to the graph. 90 | - what happens to summaries and update_ops. 91 | - which graph collections are altered. 92 | - write a tutorial on how to use this. 93 | - analyze the possibility of calling deploy more than once. 94 | 95 | 96 | """ 97 | 98 | from __future__ import absolute_import 99 | from __future__ import division 100 | from __future__ import print_function 101 | 102 | import collections 103 | 104 | import tensorflow as tf 105 | 106 | from tensorflow.python.ops import control_flow_ops 107 | 108 | slim = tf.contrib.slim 109 | 110 | __all__ = ['create_clones', 111 | 'deploy', 112 | 'optimize_clones', 113 | 'DeployedModel', 114 | 'DeploymentConfig', 115 | 'Clone', 116 | ] 117 | 118 | # Namedtuple used to represent a clone during deployment. 119 | Clone = collections.namedtuple('Clone', 120 | ['outputs', # Whatever model_fn() returned. 121 | 'scope', # The scope used to create it. 122 | 'device', # The device used to create. 123 | ]) 124 | 125 | # Namedtuple used to represent a DeployedModel, returned by deploy(). 126 | DeployedModel = collections.namedtuple('DeployedModel', 127 | ['train_op', # The `train_op` 128 | 'summary_op', # The `summary_op` 129 | 'total_loss', # The loss `Tensor` 130 | 'clones', # A list of `Clones` tuples. 
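                                        # (Gradients are not stored on the
                                        #  DeployedModel; deploy() folds them
                                        #  into train_op via
                                        #  optimizer.apply_gradients().)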
131 | ]) 132 | 133 | # Default parameters for DeploymentConfig 134 | _deployment_params = {'num_clones': 1, 135 | 'clone_on_cpu': False, 136 | 'replica_id': 0, 137 | 'num_replicas': 1, 138 | 'num_ps_tasks': 0, 139 | 'worker_job_name': 'worker', 140 | 'ps_job_name': 'ps'} 141 | 142 | 143 | def create_clones(config, model_fn, args=None, kwargs=None): 144 | """Creates multiple clones according to config using a `model_fn`. 145 | 146 | The returned values of `model_fn(*args, **kwargs)` are collected along with 147 | the scope and device used to created it in a namedtuple 148 | `Clone(outputs, scope, device)` 149 | 150 | Note: it is assumed that any loss created by `model_fn` is collected at 151 | the tf.GraphKeys.LOSSES collection. 152 | 153 | To recover the losses, summaries or update_ops created by the clone use: 154 | ```python 155 | losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope) 156 | summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, clone.scope) 157 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope) 158 | ``` 159 | 160 | The deployment options are specified by the config object and support 161 | deploying one or several clones on different GPUs and one or several replicas 162 | of such clones. 163 | 164 | The argument `model_fn` is called `config.num_clones` times to create the 165 | model clones as `model_fn(*args, **kwargs)`. 166 | 167 | If `config` specifies deployment on multiple replicas then the default 168 | tensorflow device is set appropriatly for each call to `model_fn` and for the 169 | slim variable creation functions: model and global variables will be created 170 | on the `ps` device, the clone operations will be on the `worker` device. 171 | 172 | Args: 173 | config: A DeploymentConfig object. 174 | model_fn: A callable. Called as `model_fn(*args, **kwargs)` 175 | args: Optional list of arguments to pass to `model_fn`. 176 | kwargs: Optional list of keyword arguments to pass to `model_fn`. 177 | 178 | Returns: 179 | A list of namedtuples `Clone`. 180 | """ 181 | clones = [] 182 | args = args or [] 183 | kwargs = kwargs or {} 184 | with slim.arg_scope([slim.model_variable, slim.variable], 185 | device=config.variables_device()): 186 | # Create clones. 187 | for i in range(0, config.num_clones): 188 | with tf.name_scope(config.clone_scope(i)) as clone_scope: 189 | clone_device = config.clone_device(i) 190 | with tf.device(clone_device): 191 | with tf.variable_scope(tf.get_variable_scope(), 192 | reuse=True if i > 0 else None): 193 | outputs = model_fn(*args, **kwargs) 194 | clones.append(Clone(outputs, clone_scope, clone_device)) 195 | return clones 196 | 197 | 198 | def _gather_clone_loss(clone, num_clones, regularization_losses): 199 | """Gather the loss for a single clone. 200 | 201 | Args: 202 | clone: A Clone namedtuple. 203 | num_clones: The number of clones being deployed. 204 | regularization_losses: Possibly empty list of regularization_losses 205 | to add to the clone losses. 206 | 207 | Returns: 208 | A tensor for the total loss for the clone. Can be None. 209 | """ 210 | # The return value. 211 | sum_loss = None 212 | # Individual components of the loss that will need summaries. 213 | clone_loss = None 214 | regularization_loss = None 215 | # Compute and aggregate losses on the clone device. 
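  # Each clone's losses (collected from tf.GraphKeys.LOSSES under the clone
  # scope) are summed and scaled by 1 / num_clones, so that summing the
  # per-clone gradients later in _sum_clones_gradients() is equivalent to
  # averaging the loss over all clones.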
216 | with tf.device(clone.device): 217 | all_losses = [] 218 | clone_losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope) 219 | if clone_losses: 220 | clone_loss = tf.add_n(clone_losses, name='clone_loss') 221 | if num_clones > 1: 222 | clone_loss = tf.div(clone_loss, 1.0 * num_clones, 223 | name='scaled_clone_loss') 224 | all_losses.append(clone_loss) 225 | if regularization_losses: 226 | regularization_loss = tf.add_n(regularization_losses, 227 | name='regularization_loss') 228 | all_losses.append(regularization_loss) 229 | if all_losses: 230 | sum_loss = tf.add_n(all_losses) 231 | # Add the summaries out of the clone device block. 232 | if clone_loss is not None: 233 | tf.summary.scalar(clone.scope + '/clone_loss', clone_loss) 234 | if regularization_loss is not None: 235 | tf.summary.scalar('regularization_loss', regularization_loss) 236 | return sum_loss 237 | 238 | 239 | def _optimize_clone(optimizer, clone, num_clones, regularization_losses, 240 | **kwargs): 241 | """Compute losses and gradients for a single clone. 242 | 243 | Args: 244 | optimizer: A tf.Optimizer object. 245 | clone: A Clone namedtuple. 246 | num_clones: The number of clones being deployed. 247 | regularization_losses: Possibly empty list of regularization_losses 248 | to add to the clone losses. 249 | **kwargs: Dict of kwarg to pass to compute_gradients(). 250 | 251 | Returns: 252 | A tuple (clone_loss, clone_grads_and_vars). 253 | - clone_loss: A tensor for the total loss for the clone. Can be None. 254 | - clone_grads_and_vars: List of (gradient, variable) for the clone. 255 | Can be empty. 256 | """ 257 | sum_loss = _gather_clone_loss(clone, num_clones, regularization_losses) 258 | clone_grad = None 259 | if sum_loss is not None: 260 | with tf.device(clone.device): 261 | clone_grad = optimizer.compute_gradients(sum_loss, **kwargs) 262 | return sum_loss, clone_grad 263 | 264 | 265 | def optimize_clones(clones, optimizer, 266 | regularization_losses=None, 267 | **kwargs): 268 | """Compute clone losses and gradients for the given list of `Clones`. 269 | 270 | Note: The regularization_losses are added to the first clone losses. 271 | 272 | Args: 273 | clones: List of `Clones` created by `create_clones()`. 274 | optimizer: An `Optimizer` object. 275 | regularization_losses: Optional list of regularization losses. If None it 276 | will gather them from tf.GraphKeys.REGULARIZATION_LOSSES. Pass `[]` to 277 | exclude them. 278 | **kwargs: Optional list of keyword arguments to pass to `compute_gradients`. 279 | 280 | Returns: 281 | A tuple (total_loss, grads_and_vars). 282 | - total_loss: A Tensor containing the average of the clone losses including 283 | the regularization loss. 284 | - grads_and_vars: A List of tuples (gradient, variable) containing the sum 285 | of the gradients for each variable. 286 | 287 | """ 288 | grads_and_vars = [] 289 | clones_losses = [] 290 | num_clones = len(clones) 291 | if regularization_losses is None: 292 | regularization_losses = tf.get_collection( 293 | tf.GraphKeys.REGULARIZATION_LOSSES) 294 | for clone in clones: 295 | with tf.name_scope(clone.scope): 296 | clone_loss, clone_grad = _optimize_clone( 297 | optimizer, clone, num_clones, regularization_losses, **kwargs) 298 | if clone_loss is not None: 299 | clones_losses.append(clone_loss) 300 | grads_and_vars.append(clone_grad) 301 | # Only use regularization_losses for the first clone 302 | regularization_losses = None 303 | # Compute the total_loss summing all the clones_losses. 
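  # Each clone loss was already divided by num_clones in _gather_clone_loss(),
  # so this add_n yields the average clone loss plus the regularization loss
  # (which was added to the first clone only).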
304 | total_loss = tf.add_n(clones_losses, name='total_loss') 305 | # Sum the gradients across clones. 306 | grads_and_vars = _sum_clones_gradients(grads_and_vars) 307 | return total_loss, grads_and_vars 308 | 309 | 310 | def deploy(config, 311 | model_fn, 312 | args=None, 313 | kwargs=None, 314 | optimizer=None, 315 | summarize_gradients=False): 316 | """Deploys a Slim-constructed model across multiple clones. 317 | 318 | The deployment options are specified by the config object and support 319 | deploying one or several clones on different GPUs and one or several replicas 320 | of such clones. 321 | 322 | The argument `model_fn` is called `config.num_clones` times to create the 323 | model clones as `model_fn(*args, **kwargs)`. 324 | 325 | The optional argument `optimizer` is an `Optimizer` object. If not `None`, 326 | the deployed model is configured for training with that optimizer. 327 | 328 | If `config` specifies deployment on multiple replicas then the default 329 | tensorflow device is set appropriatly for each call to `model_fn` and for the 330 | slim variable creation functions: model and global variables will be created 331 | on the `ps` device, the clone operations will be on the `worker` device. 332 | 333 | Args: 334 | config: A `DeploymentConfig` object. 335 | model_fn: A callable. Called as `model_fn(*args, **kwargs)` 336 | args: Optional list of arguments to pass to `model_fn`. 337 | kwargs: Optional list of keyword arguments to pass to `model_fn`. 338 | optimizer: Optional `Optimizer` object. If passed the model is deployed 339 | for training with that optimizer. 340 | summarize_gradients: Whether or not add summaries to the gradients. 341 | 342 | Returns: 343 | A `DeployedModel` namedtuple. 344 | 345 | """ 346 | # Gather initial summaries. 347 | summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) 348 | 349 | # Create Clones. 350 | clones = create_clones(config, model_fn, args, kwargs) 351 | first_clone = clones[0] 352 | 353 | # Gather update_ops from the first clone. These contain, for example, 354 | # the updates for the batch_norm variables created by model_fn. 355 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone.scope) 356 | 357 | train_op = None 358 | total_loss = None 359 | with tf.device(config.optimizer_device()): 360 | if optimizer: 361 | # Place the global step on the device storing the variables. 362 | with tf.device(config.variables_device()): 363 | global_step = slim.get_or_create_global_step() 364 | 365 | # Compute the gradients for the clones. 366 | total_loss, clones_gradients = optimize_clones(clones, optimizer) 367 | 368 | if clones_gradients: 369 | if summarize_gradients: 370 | # Add summaries to the gradients. 371 | summaries |= set(_add_gradients_summaries(clones_gradients)) 372 | 373 | # Create gradient updates. 
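        # apply_gradients() is appended to update_ops, which already holds
        # the batch-norm moving-average updates gathered from the first
        # clone; train_op below simply returns total_loss once all of those
        # updates have run.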
374 | grad_updates = optimizer.apply_gradients(clones_gradients, 375 | global_step=global_step) 376 | update_ops.append(grad_updates) 377 | 378 | update_op = tf.group(*update_ops) 379 | train_op = control_flow_ops.with_dependencies([update_op], total_loss, 380 | name='train_op') 381 | else: 382 | clones_losses = [] 383 | regularization_losses = tf.get_collection( 384 | tf.GraphKeys.REGULARIZATION_LOSSES) 385 | for clone in clones: 386 | with tf.name_scope(clone.scope): 387 | clone_loss = _gather_clone_loss(clone, len(clones), 388 | regularization_losses) 389 | if clone_loss is not None: 390 | clones_losses.append(clone_loss) 391 | # Only use regularization_losses for the first clone 392 | regularization_losses = None 393 | if clones_losses: 394 | total_loss = tf.add_n(clones_losses, name='total_loss') 395 | 396 | # Add the summaries from the first clone. These contain the summaries 397 | # created by model_fn and either optimize_clones() or _gather_clone_loss(). 398 | summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, 399 | first_clone.scope)) 400 | 401 | if total_loss is not None: 402 | # Add total_loss to summary. 403 | summaries.add(tf.summary.scalar('total_loss', total_loss)) 404 | 405 | if summaries: 406 | # Merge all summaries together. 407 | summary_op = tf.summary.merge(list(summaries), name='summary_op') 408 | else: 409 | summary_op = None 410 | 411 | return DeployedModel(train_op, summary_op, total_loss, clones) 412 | 413 | 414 | def _sum_clones_gradients(clone_grads): 415 | """Calculate the sum gradient for each shared variable across all clones. 416 | 417 | This function assumes that the clone_grads has been scaled appropriately by 418 | 1 / num_clones. 419 | 420 | Args: 421 | clone_grads: A List of List of tuples (gradient, variable), one list per 422 | `Clone`. 423 | 424 | Returns: 425 | List of tuples of (gradient, variable) where the gradient has been summed 426 | across all clones. 427 | """ 428 | sum_grads = [] 429 | for grad_and_vars in zip(*clone_grads): 430 | # Note that each grad_and_vars looks like the following: 431 | # ((grad_var0_clone0, var0), ... (grad_varN_cloneN, varN)) 432 | grads = [] 433 | var = grad_and_vars[0][1] 434 | for g, v in grad_and_vars: 435 | assert v == var 436 | if g is not None: 437 | grads.append(g) 438 | if grads: 439 | if len(grads) > 1: 440 | sum_grad = tf.add_n(grads, name=var.op.name + '/sum_grads') 441 | else: 442 | sum_grad = grads[0] 443 | sum_grads.append((sum_grad, var)) 444 | return sum_grads 445 | 446 | 447 | def _add_gradients_summaries(grads_and_vars): 448 | """Add histogram summaries to gradients. 449 | 450 | Note: The summaries are also added to the SUMMARIES collection. 451 | 452 | Args: 453 | grads_and_vars: A list of gradient to variable pairs (tuples). 454 | 455 | Returns: 456 | The _list_ of the added summaries for grads_and_vars. 457 | """ 458 | summaries = [] 459 | for grad, var in grads_and_vars: 460 | if grad is not None: 461 | if isinstance(grad, tf.IndexedSlices): 462 | grad_values = grad.values 463 | else: 464 | grad_values = grad 465 | summaries.append(tf.summary.histogram(var.op.name + ':gradient', 466 | grad_values)) 467 | summaries.append(tf.summary.histogram(var.op.name + ':gradient_norm', 468 | tf.global_norm([grad_values]))) 469 | else: 470 | tf.logging.info('Var %s has no gradient', var.op.name) 471 | return summaries 472 | 473 | 474 | class DeploymentConfig(object): 475 | """Configuration for deploying a model with `deploy()`. 
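  A DeploymentConfig only describes where operations should be placed; it
  does not create any ops itself. For a single machine with two GPUs, for
  example, `DeploymentConfig(num_clones=2)` is typically all that is needed
  (an illustrative default, not a value taken from this repository's
  scripts).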
476 | 477 | You can pass an instance of this class to `deploy()` to specify exactly 478 | how to deploy the model to build. If you do not pass one, an instance built 479 | from the default deployment_hparams will be used. 480 | """ 481 | 482 | def __init__(self, 483 | num_clones=1, 484 | clone_on_cpu=False, 485 | replica_id=0, 486 | num_replicas=1, 487 | num_ps_tasks=0, 488 | worker_job_name='worker', 489 | ps_job_name='ps'): 490 | """Create a DeploymentConfig. 491 | 492 | The config describes how to deploy a model across multiple clones and 493 | replicas. The model will be replicated `num_clones` times in each replica. 494 | If `clone_on_cpu` is True, each clone will placed on CPU. 495 | 496 | If `num_replicas` is 1, the model is deployed via a single process. In that 497 | case `worker_device`, `num_ps_tasks`, and `ps_device` are ignored. 498 | 499 | If `num_replicas` is greater than 1, then `worker_device` and `ps_device` 500 | must specify TensorFlow devices for the `worker` and `ps` jobs and 501 | `num_ps_tasks` must be positive. 502 | 503 | Args: 504 | num_clones: Number of model clones to deploy in each replica. 505 | clone_on_cpu: If True clones would be placed on CPU. 506 | replica_id: Integer. Index of the replica for which the model is 507 | deployed. Usually 0 for the chief replica. 508 | num_replicas: Number of replicas to use. 509 | num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas. 510 | worker_job_name: A name for the worker job. 511 | ps_job_name: A name for the parameter server job. 512 | 513 | Raises: 514 | ValueError: If the arguments are invalid. 515 | """ 516 | if num_replicas > 1: 517 | if num_ps_tasks < 1: 518 | raise ValueError('When using replicas num_ps_tasks must be positive') 519 | if num_replicas > 1 or num_ps_tasks > 0: 520 | if not worker_job_name: 521 | raise ValueError('Must specify worker_job_name when using replicas') 522 | if not ps_job_name: 523 | raise ValueError('Must specify ps_job_name when using parameter server') 524 | if replica_id >= num_replicas: 525 | raise ValueError('replica_id must be less than num_replicas') 526 | self._num_clones = num_clones 527 | self._clone_on_cpu = clone_on_cpu 528 | self._replica_id = replica_id 529 | self._num_replicas = num_replicas 530 | self._num_ps_tasks = num_ps_tasks 531 | self._ps_device = '/job:' + ps_job_name if num_ps_tasks > 0 else '' 532 | self._worker_device = '/job:' + worker_job_name if num_ps_tasks > 0 else '' 533 | 534 | @property 535 | def num_clones(self): 536 | return self._num_clones 537 | 538 | @property 539 | def clone_on_cpu(self): 540 | return self._clone_on_cpu 541 | 542 | @property 543 | def replica_id(self): 544 | return self._replica_id 545 | 546 | @property 547 | def num_replicas(self): 548 | return self._num_replicas 549 | 550 | @property 551 | def num_ps_tasks(self): 552 | return self._num_ps_tasks 553 | 554 | @property 555 | def ps_device(self): 556 | return self._ps_device 557 | 558 | @property 559 | def worker_device(self): 560 | return self._worker_device 561 | 562 | def caching_device(self): 563 | """Returns the device to use for caching variables. 564 | 565 | Variables are cached on the worker CPU when using replicas. 566 | 567 | Returns: 568 | A device string or None if the variables do not need to be cached. 569 | """ 570 | if self._num_ps_tasks > 0: 571 | return lambda op: op.device 572 | else: 573 | return None 574 | 575 | def clone_device(self, clone_index): 576 | """Device used to create the clone and all the ops inside the clone. 
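    In the default single-process configuration this is '/device:CPU:0' when
    clone_on_cpu is True, '/device:GPU:<clone_index>' when there is more than
    one clone, and the empty string otherwise (leaving placement to
    TensorFlow). When parameter servers are used, the worker device is
    prepended.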
577 | 578 | Args: 579 | clone_index: Int, representing the clone_index. 580 | 581 | Returns: 582 | A value suitable for `tf.device()`. 583 | 584 | Raises: 585 | ValueError: if `clone_index` is greater or equal to the number of clones". 586 | """ 587 | if clone_index >= self._num_clones: 588 | raise ValueError('clone_index must be less than num_clones') 589 | device = '' 590 | if self._num_ps_tasks > 0: 591 | device += self._worker_device 592 | if self._clone_on_cpu: 593 | device += '/device:CPU:0' 594 | else: 595 | if self._num_clones > 1: 596 | device += '/device:GPU:%d' % clone_index 597 | return device 598 | 599 | def clone_scope(self, clone_index): 600 | """Name scope to create the clone. 601 | 602 | Args: 603 | clone_index: Int, representing the clone_index. 604 | 605 | Returns: 606 | A name_scope suitable for `tf.name_scope()`. 607 | 608 | Raises: 609 | ValueError: if `clone_index` is greater or equal to the number of clones". 610 | """ 611 | if clone_index >= self._num_clones: 612 | raise ValueError('clone_index must be less than num_clones') 613 | scope = '' 614 | if self._num_clones > 1: 615 | scope = 'clone_%d' % clone_index 616 | return scope 617 | 618 | def optimizer_device(self): 619 | """Device to use with the optimizer. 620 | 621 | Returns: 622 | A value suitable for `tf.device()`. 623 | """ 624 | if self._num_ps_tasks > 0 or self._num_clones > 0: 625 | return self._worker_device + '/device:CPU:0' 626 | else: 627 | return '' 628 | 629 | def inputs_device(self): 630 | """Device to use to build the inputs. 631 | 632 | Returns: 633 | A value suitable for `tf.device()`. 634 | """ 635 | device = '' 636 | if self._num_ps_tasks > 0: 637 | device += self._worker_device 638 | device += '/device:CPU:0' 639 | return device 640 | 641 | def variables_device(self): 642 | """Returns the device to use for variables created inside the clone. 643 | 644 | Returns: 645 | A value suitable for `tf.device()`. 
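    When parameter servers are in use, the return value is not a plain device
    string but a callable that round-robins 'Variable' ops across the ps
    tasks (see _PSDeviceChooser below).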
646 | """ 647 | device = '' 648 | if self._num_ps_tasks > 0: 649 | device += self._ps_device 650 | device += '/device:CPU:0' 651 | 652 | class _PSDeviceChooser(object): 653 | """Slim device chooser for variables when using PS.""" 654 | 655 | def __init__(self, device, tasks): 656 | self._device = device 657 | self._tasks = tasks 658 | self._task = 0 659 | 660 | def choose(self, op): 661 | if op.device: 662 | return op.device 663 | node_def = op if isinstance(op, tf.NodeDef) else op.node_def 664 | if node_def.op == 'Variable': 665 | t = self._task 666 | self._task = (self._task + 1) % self._tasks 667 | d = '%s/task:%d' % (self._device, t) 668 | return d 669 | else: 670 | return op.device 671 | 672 | if not self._num_ps_tasks: 673 | return device 674 | else: 675 | chooser = _PSDeviceChooser(device, self._num_ps_tasks) 676 | return chooser.choose 677 | --------------------------------------------------------------------------------
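A minimal end-to-end sketch of how DeploymentConfig, deploy() and the clone
machinery above fit together on a single machine. This is illustrative only:
the two-clone (two-GPU) configuration, the dummy model_fn, the random input
batch, and the /tmp/train_logs directory are assumptions for the sketch, not
settings taken from this repository's scripts, and it assumes the repository
root is on PYTHONPATH so that `from deployment import model_deploy` resolves.

import tensorflow as tf
from deployment import model_deploy

slim = tf.contrib.slim

def model_fn(inputs_queue):
  # Build one clone of the model; losses added to tf.GraphKeys.LOSSES here
  # are later gathered per clone by _gather_clone_loss().
  images, labels = inputs_queue.dequeue()
  logits = slim.fully_connected(slim.flatten(images), 10, activation_fn=None)
  slim.losses.softmax_cross_entropy(logits, labels)

with tf.Graph().as_default():
  config = model_deploy.DeploymentConfig(num_clones=2)  # assumes two GPUs

  # Global step lives on the device that stores the variables.
  with tf.device(config.variables_device()):
    global_step = slim.create_global_step()

  with tf.device(config.inputs_device()):
    # A random batch stands in for a real input pipeline.
    images = tf.random_normal([32, 224, 224, 3])
    labels = tf.one_hot(tf.random_uniform([32], maxval=10, dtype=tf.int32), 10)
    inputs_queue = slim.prefetch_queue.prefetch_queue([images, labels])

  with tf.device(config.optimizer_device()):
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

  # deploy() builds two clones of model_fn, sums their gradients and returns
  # a train_op gated on the batch-norm update ops of the first clone.
  model_dp = model_deploy.deploy(config, model_fn, [inputs_queue],
                                 optimizer=optimizer)

  slim.learning.train(model_dp.train_op, '/tmp/train_logs',
                      summary_op=model_dp.summary_op)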