├── __init__.py
├── nets
│   ├── __init__.py
│   ├── .DS_Store
│   ├── inception.py
│   ├── nets_factory.py
│   ├── inception_utils.py
│   ├── inception_v1.py
│   └── mobilenet_v1.py
├── datasets
│   ├── __init__.py
│   ├── .DS_Store
│   ├── dataset_factory.py
│   ├── dataset_utils.py
│   ├── utils.py
│   ├── format_market_train.py
│   ├── make_filename_list.py
│   ├── convert_to_tfrecords.py
│   └── reid.py
├── deployment
│   ├── __init__.py
│   ├── .DS_Store
│   └── model_deploy.py
├── preprocessing
│   ├── __init__.py
│   ├── .DS_Store
│   ├── preprocessing_factory.py
│   ├── reid_preprocessing.py
│   └── inception_preprocessing.py
├── DML.png
├── .DS_Store
├── scripts
│   ├── .DS_Store
│   ├── format_and_convert_market.sh
│   ├── evaluate_ind_mobilenet_on_market.sh
│   ├── evaluate_dml_mobilenet_on_market.sh
│   ├── train_ind_mobilenet_on_market.sh
│   └── train_dml_mobilenet_on_market.sh
├── .idea
│   ├── markdown-navigator
│   │   └── profiles_settings.xml
│   ├── vcs.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── Deep-Mutual-Learning.iml
│   ├── inspectionProfiles
│   │   └── Project_Default.xml
│   ├── markdown-navigator.xml
│   └── workspace.xml
├── LICENSE
├── format_and_convert_data.py
├── README.md
├── eval_image_classifier.py
├── train_image_classifier.py
├── eval_models.py
└── train_models.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/deployment/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/DML.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/DML.png
--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/.DS_Store
--------------------------------------------------------------------------------
/nets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/nets/.DS_Store
--------------------------------------------------------------------------------
/scripts/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/scripts/.DS_Store
--------------------------------------------------------------------------------
/datasets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/datasets/.DS_Store
--------------------------------------------------------------------------------
/deployment/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/deployment/.DS_Store
--------------------------------------------------------------------------------
/preprocessing/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YingZhangDUT/Deep-Mutual-Learning/HEAD/preprocessing/.DS_Store
--------------------------------------------------------------------------------
/.idea/markdown-navigator/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/Deep-Mutual-Learning.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/scripts/format_and_convert_market.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # This script performs the following operations:
4 | # 1. Format the Market-1501 training images with consecutive labels.
5 | # 2. Convert the Market-1501 images into TFRecords.
6 | #
7 | # Usage:
8 | # cd Deep-Mutual-Learning
9 | # ./scripts/format_and_convert_market.sh
10 |
11 |
12 | # Where the Market-1501 images are saved to.
13 | IMAGE_DIR=/path/to/Market-1501/images
14 |
15 | # Where the TFRecord data will be saved to.
16 | TF_DIR=/path/to/market-1501/tfrecords
17 |
18 |
19 | echo "Building the TFRecords of market1501..."
20 |
21 | for split in bounding_box_train bounding_box_test gt_bbox query; do
22 | echo "Processing ${split} ..."
23 | python format_and_convert_data.py \
24 | --image_dir="$IMAGE_DIR/$split" \
25 | --output_dir=${TF_DIR} \
26 | --dataset_name="market1501" \
27 | --split_name="$split"
28 |
29 | done
30 |
31 | echo "Finished converting all the splits!"
32 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 YingZhangDUT
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/nets/inception.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Brings all inception models under one namespace."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | # pylint: disable=unused-import
22 | from nets.inception_v1 import inception_v1
23 | from nets.inception_v1 import inception_v1_arg_scope
24 | from nets.inception_v1 import inception_v1_base
25 | # pylint: enable=unused-import
26 |
--------------------------------------------------------------------------------
/scripts/evaluate_ind_mobilenet_on_market.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # This script performs the following operations:
4 | # Evaluate the MobileNet trained independently on Market-1501
5 | #
6 | # Usage:
7 | # cd Deep-Mutual-Learning
8 | # ./scripts/evaluate_ind_mobilenet_on_market.sh
9 |
10 |
11 | # Where the TFRecords are saved to.
12 | DATASET_DIR=/path/to/market-1501/tfrecords
13 |
14 | # Where the checkpoints are saved to.
15 | DATASET_NAME=market1501
16 | SAVE_NAME=market1501_ind_mobilenet
17 | CKPT_DIR=${SAVE_NAME}/checkpoint
18 |
19 | # Where the results will be saved to.
20 | RESULT_DIR=${SAVE_NAME}/results
21 |
22 | # Model setting
23 | MODEL_NAME=mobilenet_v1
24 |
25 | # Run evaluation.
26 | for split in query bounding_box_test gt_bbox; do
27 | python eval_image_classifier.py \
28 | --dataset_name=${DATASET_NAME}\
29 | --split_name="$split" \
30 | --dataset_dir=${DATASET_DIR} \
31 | --checkpoint_dir=${CKPT_DIR} \
32 | --eval_dir=${RESULT_DIR} \
33 | --model_name=${MODEL_NAME} \
34 | --preprocessing_name=reid \
35 | --num_classes=751 \
36 | --batch_size=1 \
37 | --num_networks=1
38 | done
39 |
--------------------------------------------------------------------------------
/scripts/evaluate_dml_mobilenet_on_market.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # This script performs the following operations:
4 | # Evaluate the MobileNets trained with DML on Market-1501
5 | #
6 | # Usage:
7 | # cd Deep-Mutual-Learning
8 | # ./scripts/evaluate_dml_mobilenet_on_market.sh
9 |
10 |
11 | # Where the TFRecords are saved to.
12 | DATASET_DIR=/path/to/market-1501/tfrecords
13 |
14 | # Where the checkpoints are saved to.
15 | DATASET_NAME=market1501
16 | SAVE_NAME=market1501_dml_mobilenet
17 | CKPT_DIR=${SAVE_NAME}/checkpoint
18 |
19 | # Where the results will be saved to.
20 | RESULT_DIR=${SAVE_NAME}/results
21 |
22 | # Model setting
23 | MODEL_NAME=mobilenet_v1,mobilenet_v1
24 |
25 | # Run evaluation.
26 | for split in query bounding_box_test gt_bbox; do
27 | python eval_image_classifier.py \
28 | --dataset_name=${DATASET_NAME}\
29 | --split_name="$split" \
30 | --dataset_dir=${DATASET_DIR} \
31 | --checkpoint_dir=${CKPT_DIR} \
32 | --eval_dir=${RESULT_DIR} \
33 | --model_name=${MODEL_NAME} \
34 | --preprocessing_name=reid \
35 | --num_classes=751 \
36 | --batch_size=1 \
37 | --num_networks=2
38 | done
39 |
--------------------------------------------------------------------------------
/scripts/train_ind_mobilenet_on_market.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # This script performs the following operations:
4 | # Training 1 MobileNet independently on Market-1501
5 | #
6 | # Usage:
7 | # cd Deep-Mutual-Learning
8 | # ./scripts/train_ind_mobilenet_on_market.sh
9 |
10 |
11 | # Where the TFRecords are saved to.
12 | DATASET_DIR=/path/to/market-1501/tfrecords
13 |
14 | # Where the checkpoint and logs will be saved to.
15 | DATASET_NAME=market1501
16 | SAVE_NAME=market1501_ind_mobilenet
17 | CKPT_DIR=${SAVE_NAME}/checkpoint
18 | LOG_DIR=${SAVE_NAME}/logs
19 |
20 | # Model setting
21 | MODEL_NAME=mobilenet_v1
22 | SPLIT_NAME=bounding_box_train
23 |
24 | # Run training.
25 | python train_image_classifier.py \
26 | --dataset_name=${DATASET_NAME}\
27 | --split_name=${SPLIT_NAME} \
28 | --dataset_dir=${DATASET_DIR} \
29 | --checkpoint_dir=${CKPT_DIR} \
30 | --log_dir=${LOG_DIR} \
31 | --model_name=${MODEL_NAME} \
32 | --preprocessing_name=reid \
33 | --max_number_of_steps=200000 \
34 | --ckpt_steps=5000 \
35 | --batch_size=16 \
36 | --num_classes=751 \
37 | --optimizer=adam \
38 | --learning_rate=0.0002 \
39 | --adam_beta1=0.5 \
40 | --opt_epsilon=1e-8 \
41 | --label_smoothing=0.1 \
42 | --num_networks=1
43 |
44 |
--------------------------------------------------------------------------------
/scripts/train_dml_mobilenet_on_market.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # This script performs the following operations:
4 | # Training 2 MobileNets with DML on Market-1501
5 | #
6 | # Usage:
7 | # cd Deep-Mutual-Learning
8 | # ./scripts/train_dml_mobilenet_on_market.sh
9 |
10 |
11 | # Where the TFRecords are saved to.
12 | DATASET_DIR=/path/to/market-1501/tfrecords
13 |
14 | # Where the checkpoint and logs will be saved to.
15 | DATASET_NAME=market1501
16 | SAVE_NAME=market1501_dml_mobilenet
17 | CKPT_DIR=${SAVE_NAME}/checkpoint
18 | LOG_DIR=${SAVE_NAME}/logs
19 |
20 | # Model setting
21 | MODEL_NAME=mobilenet_v1,mobilenet_v1
22 | SPLIT_NAME=bounding_box_train
23 |
24 | # Run training.
25 | python train_image_classifier.py \
26 | --dataset_name=${DATASET_NAME}\
27 | --split_name=${SPLIT_NAME} \
28 | --dataset_dir=${DATASET_DIR} \
29 | --checkpoint_dir=${CKPT_DIR} \
30 | --log_dir=${LOG_DIR} \
31 | --model_name=${MODEL_NAME} \
32 | --preprocessing_name=reid \
33 | --max_number_of_steps=200000 \
34 | --ckpt_steps=5000 \
35 | --batch_size=16 \
36 | --num_classes=751 \
37 | --optimizer=adam \
38 | --learning_rate=0.0002 \
39 | --adam_beta1=0.5 \
40 | --opt_epsilon=1e-8 \
41 | --label_smoothing=0.1 \
42 | --num_networks=2
43 |
44 |
--------------------------------------------------------------------------------
/datasets/dataset_factory.py:
--------------------------------------------------------------------------------
1 | """
2 | Provides a dataset given the dataset name and split name.
3 | """
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | from datasets import reid
10 |
11 | # provider functions might vary on different datasets
12 | datasets_map = {
13 | 'market1501': reid,
14 | }
15 |
16 |
17 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
18 | """Given a dataset name and a split_name returns a Dataset.
19 |
20 | Args:
21 | name: String, the name of the dataset.
22 | split_name: A train/test split name.
23 | dataset_dir: The directory where the dataset files are stored.
24 | file_pattern: The file pattern to use for matching the dataset source files.
25 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default
26 | reader defined by each dataset is used.
27 |
28 | Returns:
29 | A `Dataset` class.
30 |
31 | Raises:
32 | ValueError: If the dataset `name` is unknown.
33 | """
34 | if name not in datasets_map:
35 | raise ValueError('Name of dataset unknown %s' % name)
36 | return datasets_map[name].get_split(
37 | split_name,
38 | dataset_dir,
39 | file_pattern,
40 | reader)
41 |
--------------------------------------------------------------------------------
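A hedged usage sketch of the factory above. Note that `reid.get_split()` also reads the `dataset_dir` and `num_classes` flags, so in practice it is called from the training/evaluation scripts where those flags are defined; the path below is a placeholder.
```
from datasets import dataset_factory

# Illustrative only: assumes the TFRecords and matching *.txt lists exist,
# and that FLAGS.dataset_dir / FLAGS.num_classes are defined (see the scripts).
dataset = dataset_factory.get_dataset(
    'market1501', 'bounding_box_train', '/path/to/market-1501/tfrecords')
print(dataset.num_samples, dataset.num_classes)
```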
/format_and_convert_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Format Market-1501 training images and convert all the splits into TFRecords
3 | """
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import tensorflow as tf
10 | from datasets import convert_to_tfrecords
11 | from datasets import format_market_train
12 | from datasets import make_filename_list
13 | from datasets.utils import *
14 |
15 | FLAGS = tf.app.flags.FLAGS
16 | tf.app.flags.DEFINE_string('image_dir', None, 'path to the raw images')
17 | tf.app.flags.DEFINE_string('output_dir', None, 'path to the list and tfrecords ')
18 | tf.app.flags.DEFINE_string('split_name', None, 'split name')
19 |
20 |
21 | def main(_):
22 |
23 | mkdir_if_missing(FLAGS.output_dir)
24 |
25 | if FLAGS.split_name == 'bounding_box_train':
26 | format_market_train.run(image_dir=FLAGS.image_dir)
27 |
28 | make_filename_list.run(image_dir=FLAGS.image_dir,
29 | output_dir=FLAGS.output_dir,
30 | split_name=FLAGS.split_name)
31 |
32 | convert_to_tfrecords.run(image_dir=FLAGS.image_dir,
33 | output_dir=FLAGS.output_dir,
34 | split_name=FLAGS.split_name)
35 |
36 |
37 | if __name__ == '__main__':
38 | tf.app.run()
39 |
40 |
--------------------------------------------------------------------------------
/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains utilities for converting datasets.
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import tensorflow as tf
9 |
10 |
11 | def int64_feature(values):
12 | """Returns a TF-Feature of int64s.
13 |
14 | Args:
15 | values: A scalar or list of values.
16 |
17 | Returns:
18 | a TF-Feature.
19 | """
20 | if not isinstance(values, (tuple, list)):
21 | values = [values]
22 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
23 |
24 |
25 | def bytes_feature(values):
26 | """Returns a TF-Feature of bytes.
27 |
28 | Args:
29 | values: A string.
30 |
31 | Returns:
32 | a TF-Feature.
33 | """
34 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
35 |
36 |
37 | def image_to_tfexample(image_data, class_id, filename, height, width, image_format):
38 | return tf.train.Example(features=tf.train.Features(feature={
39 | 'image/encoded': bytes_feature(image_data),
40 | 'image/label': int64_feature(class_id),
41 | 'image/filename': bytes_feature(filename),
42 | 'image/height': int64_feature(height),
43 | 'image/width': int64_feature(width),
44 | 'image/format': bytes_feature(image_format),
45 | }))
46 |
--------------------------------------------------------------------------------
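For reference, a minimal sketch of how `image_to_tfexample` is used; `png_bytes` is an assumed placeholder for an already-encoded PNG string (see convert_to_tfrecords.py below for the real call site).
```
# Illustrative sketch: wrap one encoded PNG into a TF-Example and serialize it.
example = image_to_tfexample(
    image_data=png_bytes,                     # assumed: encoded PNG string
    class_id=42,                              # integer identity label
    filename='cam_0/cam_00_00042_00000.jpg',
    height=160,
    width=64,
    image_format='png')
serialized = example.SerializeToString()
```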
/README.md:
--------------------------------------------------------------------------------
1 | # Deep-Mutual-Learning
2 |
3 | TensorFlow implementation of **[Deep Mutual Learning](https://drive.google.com/file/d/1Deg9xXqPKAlxRgmWbggavftTvJPqJeyp/view)** accepted by CVPR 2018.
4 |
5 |
6 | ## Introduction
7 | Deep mutual learning provides a simple but effective way to improve the generalisation ability of a network by training collaboratively with a cohort of other networks.
8 |
9 | 
10 |
11 | ## Requirements
12 |
13 | 1. TensorFlow 1.3.1
14 | 2. CUDA 8.0 and cuDNN 6.0
15 | 3. Matlab
16 |
17 | ## Usage
18 |
19 | ### Data Preparation
20 | 1. Please download the [Market-1501 Dataset](http://www.liangzheng.com.cn/Project/project_reid.html)
21 |
22 | 2. Convert the image data into TFRecords
23 | ```
24 | sh scripts/format_and_convert_market.sh
25 | ```
26 |
27 | ### Training
28 | 1. Train MobileNets with DML
29 | ```
30 | sh scripts/train_dml_mobilenet_on_market.sh
31 | ```
32 |
33 | 2. Train MobileNet independently
34 | ```
35 | sh scripts/train_ind_mobilenet_on_market.sh
36 | ```
37 |
38 | ### Testing
39 | 1. Extract features of the test images
40 | ```
41 | sh scripts/evaluate_dml_mobilenet_on_market.sh
42 | ```
43 |
44 | 2. Evaluate the performance with matlab [code](https://github.com/zhunzhong07/person-re-ranking/tree/master/evaluation)
45 |
46 |
47 | ## Citation
48 | If you find DML useful in your research, please kindly cite our paper:
49 |
50 | ```
51 | @inproceedings{ying2018DML,
52 | author = {Ying Zhang and Tao Xiang and Timothy M. Hospedales and Huchuan Lu},
53 | title = {Deep Mutual Learning},
54 | booktitle = {CVPR},
55 | year = {2018}}
56 | ```
--------------------------------------------------------------------------------
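The mutual-learning objective summarized in the README pairs each network's supervised cross-entropy loss with a KL-divergence term that mimics its peer's predictions. A minimal illustrative sketch for a two-network cohort follows; this is not the repository's `train_models.py` implementation (not reproduced in this dump).
```
import tensorflow as tf

def dml_losses(logits_1, logits_2, labels, num_classes):
    # Supervised cross-entropy for each network.
    onehot = tf.one_hot(labels, num_classes)
    ce_1 = tf.losses.softmax_cross_entropy(onehot, logits_1)
    ce_2 = tf.losses.softmax_cross_entropy(onehot, logits_2)
    # Mutual mimicry: KL(p2 || p1) pulls network 1 toward network 2, and vice versa.
    p_1 = tf.nn.softmax(logits_1)
    p_2 = tf.nn.softmax(logits_2)
    kl_2_1 = tf.reduce_mean(tf.reduce_sum(
        p_2 * (tf.log(p_2 + 1e-8) - tf.log(p_1 + 1e-8)), axis=1))
    kl_1_2 = tf.reduce_mean(tf.reduce_sum(
        p_1 * (tf.log(p_1 + 1e-8) - tf.log(p_2 + 1e-8)), axis=1))
    return ce_1 + kl_2_1, ce_2 + kl_1_2
```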
/preprocessing/preprocessing_factory.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains a factory for selecting the image preprocessing function.
3 | """
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import tensorflow as tf
10 |
11 | from preprocessing import inception_preprocessing
12 | from preprocessing import reid_preprocessing
13 |
14 | slim = tf.contrib.slim
15 |
16 |
17 | def get_preprocessing(name, is_training=False):
18 | """Returns preprocessing_fn(image, height, width, **kwargs).
19 |
20 | Args:
21 | name: The name of the preprocessing function.
22 | is_training: `True` if the model is being used for training and `False`
23 | otherwise.
24 |
25 | Returns:
26 | preprocessing_fn: A function that preprocesses a single image (pre-batch).
27 | It has the following signature:
28 | image = preprocessing_fn(image, output_height, output_width, ...).
29 |
30 | Raises:
31 | ValueError: If Preprocessing `name` is not recognized.
32 | """
33 | preprocessing_fn_map = {
34 | 'inception_v1': inception_preprocessing,
35 | 'mobilenet_v1': inception_preprocessing,
36 | 'reid': reid_preprocessing,
37 | }
38 |
39 | if name not in preprocessing_fn_map:
40 | raise ValueError('Preprocessing name [%s] was not recognized' % name)
41 |
42 | def preprocessing_fn(image, output_height, output_width, **kwargs):
43 | return preprocessing_fn_map[name].preprocess_image(
44 | image, output_height, output_width, is_training=is_training, **kwargs)
45 |
46 | return preprocessing_fn
47 |
--------------------------------------------------------------------------------
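A hedged usage sketch of the factory above; `image` stands for an assumed decoded `uint8` image tensor, and the 160x64 size matches the re-id preprocessing below.
```
from preprocessing import preprocessing_factory

# Illustrative only.
preprocessing_fn = preprocessing_factory.get_preprocessing('reid', is_training=True)
processed_image = preprocessing_fn(image, 160, 64)
```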
/datasets/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Contains utilities for general usage.
3 | """
4 |
5 | import os
6 | import os.path as osp
7 | import json
8 | import codecs
9 | import cPickle
10 |
11 |
12 | def mkdir_if_missing(d):
13 | if not osp.isdir(d):
14 | os.makedirs(d)
15 |
16 |
17 | def pickle(data, file_path):
18 | with open(file_path, 'wb') as f:
19 | cPickle.dump(data, f, cPickle.HIGHEST_PROTOCOL)
20 |
21 |
22 | def unpickle(file_path):
23 | with open(file_path, 'rb') as f:
24 | data = cPickle.load(f)
25 | return data
26 |
27 |
28 | def read_list(file_path, coding=None):
29 | if coding is None:
30 | with open(file_path, 'r') as f:
31 | arr = [line.strip() for line in f.readlines()]
32 | else:
33 | with codecs.open(file_path, 'r', coding) as f:
34 | arr = [line.strip() for line in f.readlines()]
35 | return arr
36 |
37 |
38 | def write_list(arr, file_path, coding=None):
39 | if coding is None:
40 | arr = ['{}'.format(item) for item in arr]
41 | with open(file_path, 'w') as f:
42 | f.write('\n'.join(arr))
43 | else:
44 | with codecs.open(file_path, 'w', coding) as f:
45 | f.write(u'\n'.join(arr))
46 |
47 |
48 | def read_kv(file_path, coding=None):
49 | arr = read_list(file_path, coding)
50 | if len(arr) == 0:
51 | return [], []
52 | return zip(*map(str.split, arr))
53 |
54 |
55 | def write_kv(k, v, file_path, coding=None):
56 | arr = zip(k, v)
57 | arr = [' '.join(item) for item in arr]
58 | write_list(arr, file_path, coding)
59 |
60 |
61 | def read_json(file_path):
62 | with open(file_path, 'r') as f:
63 | obj = json.load(f)
64 | return obj
65 |
66 |
67 | def write_json(obj, file_path):
68 | with open(file_path, 'w') as f:
69 | json.dump(obj, f, indent=4, separators=(',', ': '))
70 |
--------------------------------------------------------------------------------
/datasets/format_market_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Format Market-1501 training images with consecutive labels.
3 |
4 | This code modifies the data preparation method of
5 | "Learning Deep Feature Representations with Domain Guided Dropout for Person Re-identification".
6 |
7 | """
8 |
9 | import shutil
10 | from glob import glob
11 | from datasets.utils import *
12 |
13 |
14 | def _format_train_data(in_dir, output_dir):
15 | # cam_0 to cam_5
16 | for i in xrange(6):
17 | mkdir_if_missing(osp.join(output_dir, 'cam_' + str(i)))
18 | # pdb.set_trace()
19 | images = glob(osp.join(in_dir, '*.jpg'))
20 | images.sort()
21 | identities = []
22 | prev_pid = -1
23 | for name in images:
24 | name = osp.basename(name)
25 | p_id = int(name[0:4])
26 | c_id = int(name[6]) - 1
27 | if prev_pid != p_id:
28 | identities.append([])
29 | prev_cid = -1
30 | p_images = identities[-1]
31 | if prev_cid != c_id:
32 | p_images.append([])
33 | v_images = p_images[-1]
34 | file_name = 'cam_{}/cam_{:02d}_{:05d}_{:05d}.jpg'.format(c_id, c_id, len(identities)-1, len(v_images))
35 | shutil.copy(osp.join(in_dir, name),
36 | osp.join(output_dir, file_name))
37 | v_images.append(file_name)
38 | prev_pid = p_id
39 | prev_cid = c_id
40 | # Save meta information into a json file
41 | meta = {'name': 'market1501', 'shot': 'multiple', 'num_cameras': 6}
42 | meta['identities'] = identities
43 | write_json(meta, osp.join(output_dir, 'meta.json'))
44 | num_images = len(images)
45 | num_classes = len(identities)
46 | print("Training data has %d images of %d classes" % (num_images, num_classes))
47 |
48 |
49 | def run(image_dir):
50 | """Format the datasets with consecutive labels.
51 |
52 | Args:
53 | image_dir: The dataset directory where the raw images are stored.
54 |
55 | """
56 | in_dir = image_dir + "_raw"
57 | os.rename(image_dir, in_dir)
58 | mkdir_if_missing(image_dir)
59 | _format_train_data(in_dir, image_dir)
60 |
--------------------------------------------------------------------------------
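A hedged walk-through of the raw-name parsing above, using an assumed Market-1501 filename of the form `personID_camSeq_frame_bbox.jpg`.
```
name = '0002_c1s1_000451_03.jpg'   # assumed example of a raw filename
p_id = int(name[0:4])              # 2: original person identity
c_id = int(name[6]) - 1            # 0: zero-based camera index
```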
/datasets/make_filename_list.py:
--------------------------------------------------------------------------------
1 | """
2 | Make a list of image files and labels. This provides a reference for checking the data preparation.
3 | """
4 | import numpy as np
5 | from glob import glob
6 | from datasets.utils import *
7 |
8 |
9 | def _save(file_label_list, file_path):
10 | content = ['{} {}'.format(x, y) for x, y in file_label_list]
11 | write_list(content, file_path)
12 |
13 |
14 | def _get_train_list(files):
15 | ret = []
16 | for views in files:
17 | for v in views:
18 | for f in v:
19 | # camID = int(osp.basename(f)[4:6])
20 | label = int(osp.basename(f)[7:12])
21 | ret.append((f, label))
22 | return np.asarray(ret)
23 |
24 |
25 | def _make_train_list(image_dir, output_dir, split_name):
26 | meta = read_json(osp.join(image_dir, 'meta.json'))
27 | identities = np.asarray(meta['identities'])
28 | images = _get_train_list(identities)
29 | _save(images, os.path.join(output_dir, '%s.txt' % split_name))
30 |
31 |
32 | def _get_test_list(files):
33 | ret = []
34 | for f in files:
35 | if osp.basename(f)[:2] == '-1':
36 | # camID = int(osp.basename(f)[4]) - 1
37 | label = int(osp.basename(f)[:2])
38 | else:
39 | # camID = int(osp.basename(f)[6]) - 1
40 | label = int(osp.basename(f)[:4])
41 | ret.append((osp.basename(f), label))
42 | return np.asarray(ret)
43 |
44 |
45 | def _make_test_list(image_dir, output_dir, split_name):
46 | files = sorted(glob(osp.join(image_dir, '*.jpg')))
47 | images = _get_test_list(files)
48 | _save(images, os.path.join(output_dir, '%s.txt' % split_name))
49 |
50 |
51 | def run(image_dir, output_dir, split_name):
52 | """Make list file for images.
53 |
54 | Args:
55 | image_dir: The image directory where the raw images are stored.
56 | output_dir: The directory where the lists and tfrecords are stored.
57 | split_name: The split name of dataset.
58 | """
59 | if split_name == 'bounding_box_train':
60 | _make_train_list(image_dir, output_dir, split_name)
61 | else:
62 | _make_test_list(image_dir, output_dir, split_name)
63 |
--------------------------------------------------------------------------------
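Each line of the generated list pairs an image path (relative for the formatted training set, a bare filename for the test splits) with its integer label; a hedged example of the first lines of `bounding_box_train.txt`:
```
cam_0/cam_00_00000_00000.jpg 0
cam_0/cam_00_00000_00001.jpg 0
cam_1/cam_01_00001_00000.jpg 1
```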
/eval_image_classifier.py:
--------------------------------------------------------------------------------
1 | """
2 | Generic evaluation script that evaluates a model using a given dataset.
3 | """
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import tensorflow as tf
10 | import eval_models
11 | from datasets.utils import *
12 |
13 | slim = tf.contrib.slim
14 |
15 | tf.app.flags.DEFINE_string('dataset_name', 'market1501',
16 | 'The name of the dataset to load.')
17 |
18 | tf.app.flags.DEFINE_string('split_name', 'test',
19 | 'The name of the train/test split.')
20 |
21 | tf.app.flags.DEFINE_string('dataset_dir', None,
22 | 'The directory where the dataset files are stored.')
23 |
24 | tf.app.flags.DEFINE_string('checkpoint_dir', None,
25 | 'The directory where the model was written to or an absolute path to a '
26 | 'checkpoint file.')
27 |
28 | tf.app.flags.DEFINE_string('eval_dir', 'results',
29 | 'Directory where the results are saved to.')
30 |
31 | tf.app.flags.DEFINE_string('model_name', 'mobilenet_v1',
32 | 'The name of the architecture to evaluate.')
33 |
34 | tf.app.flags.DEFINE_integer('num_networks', 2,
35 | 'Number of Networks')
36 |
37 | tf.app.flags.DEFINE_integer('num_classes', 751,
38 | 'The number of classes.')
39 |
40 | tf.app.flags.DEFINE_integer('batch_size', 1,
41 | 'The number of samples in each batch.')
42 |
43 | tf.app.flags.DEFINE_string('preprocessing_name', None,
44 | 'The name of the preprocessing to use. If left '
45 | 'as `None`, then the model_name flag is used.')
46 |
47 | tf.app.flags.DEFINE_integer('num_preprocessing_threads', 1,
48 | 'The number of threads used to create the batches.')
49 |
50 | tf.app.flags.DEFINE_float('moving_average_decay', 0.9999,
51 | 'The decay to use for the moving average. '
52 | 'If left as None, then moving averages are not used.')
53 |
54 | #########################
55 |
56 | FLAGS = tf.app.flags.FLAGS
57 |
58 |
59 | def main(_):
60 | # create folders
61 | mkdir_if_missing(FLAGS.eval_dir)
62 | # test
63 | eval_models.evaluate()
64 |
65 |
66 | if __name__ == '__main__':
67 | tf.app.run()
68 |
69 |
--------------------------------------------------------------------------------
/nets/nets_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a factory for building various models."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import functools
21 |
22 | import tensorflow as tf
23 | from nets import inception
24 | from nets import mobilenet_v1
25 |
26 | slim = tf.contrib.slim
27 |
28 | networks_map = {'mobilenet_v1': mobilenet_v1.mobilenet_v1,
29 | 'inception_v1': inception.inception_v1,
30 | }
31 |
32 | arg_scopes_map = {'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope,
33 | 'inception_v1': inception.inception_v1_arg_scope,
34 | }
35 |
36 |
37 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
38 | """Returns a network_fn such as `logits, end_points = network_fn(images)`.
39 |
40 | Args:
41 | name: The name of the network.
42 | num_classes: The number of classes to use for classification.
43 | weight_decay: The l2 coefficient for the model weights.
44 | is_training: `True` if the model is being used for training and `False`
45 | otherwise.
46 |
47 | Returns:
48 | network_fn: A function that applies the model to a batch of images. It has
49 | the following signature:
50 | logits, end_points = network_fn(images)
51 | Raises:
52 | ValueError: If network `name` is not recognized.
53 | """
54 | if name not in networks_map:
55 | raise ValueError('Name of network unknown %s' % name)
56 | func = networks_map[name]
57 |
58 | @functools.wraps(func)
59 | def network_fn(images, scope=None):
60 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
61 | with slim.arg_scope(arg_scope):
62 | return func(images, num_classes, is_training=is_training, scope=scope)
63 |
64 | if hasattr(func, 'default_image_size'):
65 | network_fn.default_image_size = func.default_image_size
66 |
67 | return network_fn
68 |
--------------------------------------------------------------------------------
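A hedged usage sketch of the factory above, building one MobileNet branch the way `eval_models.py` does; the scope name and batch size are illustrative.
```
import tensorflow as tf
from nets import nets_factory

network_fn = nets_factory.get_network_fn(
    'mobilenet_v1', num_classes=751, weight_decay=0.00004, is_training=False)
size = network_fn.default_image_size
images = tf.placeholder(tf.float32, [1, size, size, 3])
logits, end_points = network_fn(images, scope='dmlnet_0')
```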
/datasets/convert_to_tfrecords.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert Market-1501 to TFRecords of TF-Example protos.
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import tensorflow as tf
9 | import os
10 | import sys
11 | from scipy import misc
12 | from datasets.dataset_utils import *
13 |
14 | # resize all the re-id images into the same size
15 | _IMAGE_HEIGHT = 160
16 | _IMAGE_WIDTH = 64
17 | _IMAGE_CHANNELS = 3
18 |
19 |
20 | def _add_to_tfrecord(image_dir, list_filename, tfrecord_writer, split_name):
21 | """Loads images and writes files to a TFRecord.
22 |
23 | Args:
24 | image_dir: The image directory where the raw images are stored.
25 | list_filename: The list file of images.
26 | tfrecord_writer: The TFRecord writer to use for writing.
27 | """
28 | num_images = len(tf.gfile.FastGFile(list_filename, 'r').readlines())
29 |
30 | shape = (_IMAGE_HEIGHT, _IMAGE_WIDTH, _IMAGE_CHANNELS)
31 | with tf.Graph().as_default():
32 | image = tf.placeholder(dtype=tf.uint8, shape=shape)
33 | encoded_png = tf.image.encode_png(image)
34 | j = 0
35 | with tf.Session('') as sess:
36 | for line in tf.gfile.FastGFile(list_filename, 'r').readlines():
37 | sys.stdout.write('\r>> Converting %s image %d/%d' % (split_name, j + 1, num_images))
38 | sys.stdout.flush()
39 | j += 1
40 | imagename, label = line.split(' ')
41 | label = int(label)
42 | file_path = os.path.join(image_dir, imagename)
43 | image_data = misc.imread(file_path)
44 | image_data = misc.imresize(image_data, [_IMAGE_HEIGHT, _IMAGE_WIDTH])
45 | png_string = sess.run(encoded_png, feed_dict={image: image_data})
46 | example = image_to_tfexample(png_string, label, imagename, _IMAGE_HEIGHT, _IMAGE_WIDTH, 'png')
47 | tfrecord_writer.write(example.SerializeToString())
48 |
49 |
50 | def run(image_dir, output_dir, split_name):
51 | """Convert images to tfrecords.
52 | Args:
53 | image_dir: The image directory where the raw images are stored.
54 | output_dir: The directory where the lists and tfrecords are stored.
55 | split_name: The split name of dataset.
56 | """
57 | list_filename = os.path.join(output_dir, '%s.txt' % split_name)
58 | tf_filename = os.path.join(output_dir, '%s.tfrecord' % split_name)
59 |
60 | if tf.gfile.Exists(tf_filename):
61 | print('Dataset files already exist. Exiting without re-creating them.')
62 | return
63 |
64 | with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
65 | _add_to_tfrecord(image_dir, list_filename, tfrecord_writer, split_name)
66 |
67 | print(" Done! \n")
68 |
69 |
--------------------------------------------------------------------------------
/preprocessing/reid_preprocessing.py:
--------------------------------------------------------------------------------
1 | """
2 | Provides utilities to preprocess images for re-id.
3 | """
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import tensorflow as tf
10 | slim = tf.contrib.slim
11 |
12 |
13 | IMAGE_HEIGHT = 160
14 | IMAGE_WIDTH = 64
15 |
16 |
17 | def preprocess_for_train(image,
18 | output_height,
19 | output_width):
20 | """Preprocesses the given image for training.
21 |
22 | Args:
23 | image: A `Tensor` representing an image of arbitrary size.
24 | output_height: The height of the image after preprocessing.
25 | output_width: The width of the image after preprocessing.
26 |
27 | Returns:
28 | A preprocessed image.
29 | """
30 | if image.dtype != tf.float32:
31 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
32 |
33 | image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, 3])
34 | image = tf.image.resize_images(image, [output_height, output_width])
35 | image = tf.image.random_flip_left_right(image)
36 | tf.summary.image('cropped_resized_image',
37 | tf.expand_dims(image, 0))
38 | image = tf.subtract(image, 0.5)
39 | image = tf.multiply(image, 2.0)
40 | return image
41 |
42 |
43 | def preprocess_for_eval(image, output_height, output_width):
44 | """Preprocesses the given image for evaluation.
45 |
46 | Args:
47 | image: A `Tensor` representing an image of arbitrary size.
48 | output_height: The height of the image after preprocessing.
49 | output_width: The width of the image after preprocessing.
50 |
51 | Returns:
52 | A preprocessed image.
53 | """
54 | if image.dtype != tf.float32:
55 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
56 | image.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH, 3])
57 | image = tf.image.resize_images(image, [output_height, output_width])
58 | image.set_shape([output_height, output_width, 3])
59 | image = tf.subtract(image, 0.5)
60 | image = tf.multiply(image, 2.0)
61 | return image
62 |
63 |
64 | def preprocess_image(image, output_height, output_width, is_training=False):
65 | """Preprocesses the given image.
66 |
67 | Args:
68 | image: A `Tensor` representing an image of arbitrary size.
69 | output_height: The height of the image after preprocessing.
70 | output_width: The width of the image after preprocessing.
71 | is_training: `True` if we're preprocessing the image for training and
72 | `False` otherwise.
73 |
74 | Returns:
75 | A preprocessed image.
76 | """
77 | if is_training:
78 | return preprocess_for_train(image, output_height, output_width)
79 | else:
80 | return preprocess_for_eval(image, output_height, output_width)
81 |
--------------------------------------------------------------------------------
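The evaluation path above resizes the image and rescales pixel values from [0, 1] to [-1, 1]; a minimal sanity-check sketch with an assumed all-zero input.
```
import tensorflow as tf
from preprocessing import reid_preprocessing

image = tf.zeros([160, 64, 3], dtype=tf.uint8)   # assumed dummy input
output = reid_preprocessing.preprocess_image(image, 160, 64, is_training=False)
with tf.Session() as sess:
    # An all-zero uint8 image maps to -1.0 everywhere after rescaling.
    print(sess.run([tf.reduce_min(output), tf.reduce_max(output)]))
```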
/nets/inception_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains common code shared by all inception models.
16 |
17 | Usage of arg scope:
18 | with slim.arg_scope(inception_arg_scope()):
19 | logits, end_points = inception.inception_v3(images, num_classes,
20 | is_training=is_training)
21 |
22 | """
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | import tensorflow as tf
28 |
29 | slim = tf.contrib.slim
30 |
31 |
32 | def inception_arg_scope(weight_decay=0.00004,
33 | use_batch_norm=True,
34 | batch_norm_decay=0.9997,
35 | batch_norm_epsilon=0.001):
36 | """Defines the default arg scope for inception models.
37 |
38 | Args:
39 | weight_decay: The weight decay to use for regularizing the model.
40 | use_batch_norm: If `True`, batch_norm is applied after each convolution.
41 | batch_norm_decay: Decay for batch norm moving average.
42 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero
43 | in batch norm.
44 |
45 | Returns:
46 | An `arg_scope` to use for the inception models.
47 | """
48 | batch_norm_params = {
49 | # Decay for the moving averages.
50 | 'decay': batch_norm_decay,
51 | # epsilon to prevent 0s in variance.
52 | 'epsilon': batch_norm_epsilon,
53 | # collection containing update_ops.
54 | 'updates_collections': tf.GraphKeys.UPDATE_OPS,
55 | }
56 | if use_batch_norm:
57 | normalizer_fn = slim.batch_norm
58 | normalizer_params = batch_norm_params
59 | else:
60 | normalizer_fn = None
61 | normalizer_params = {}
62 | # Set weight_decay for weights in Conv and FC layers.
63 | with slim.arg_scope([slim.conv2d, slim.fully_connected],
64 | weights_regularizer=slim.l2_regularizer(weight_decay)):
65 | with slim.arg_scope(
66 | [slim.conv2d],
67 | weights_initializer=slim.variance_scaling_initializer(),
68 | activation_fn=tf.nn.relu,
69 | normalizer_fn=normalizer_fn,
70 | normalizer_params=normalizer_params) as sc:
71 | return sc
72 |
--------------------------------------------------------------------------------
/datasets/reid.py:
--------------------------------------------------------------------------------
1 | """
2 | Provides the re-id dataset given a split name.
3 | """
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 |
8 | import os
9 | import tensorflow as tf
10 |
11 | FLAGS = tf.app.flags.FLAGS
12 |
13 | slim = tf.contrib.slim
14 |
15 | _FILE_PATTERN = '%s.tfrecord'
16 |
17 | _SPLITS_NAMES = ['bounding_box_train', 'bounding_box_test', 'gt_bbox', 'query']
18 |
19 | _ITEMS_TO_DESCRIPTIONS = {
20 | 'image': 'A color image of varying height and width.',
21 | 'label': 'The label id of the image, integer between 0 and num_classes',
22 | 'filename': 'The name of an image',
23 | }
24 |
25 |
26 | def get_num_examples(split_name):
27 | list_file = os.path.join(FLAGS.dataset_dir, '%s.txt' % split_name)
28 | num_examples = len(tf.gfile.FastGFile(list_file, 'r').readlines())
29 |
30 | return num_examples
31 |
32 |
33 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
34 | """Get a dataset tuple.
35 |
36 | Args:
37 | split_name: A train/test split name.
38 | dataset_dir: The base directory of the dataset sources.
39 | file_pattern: The file pattern to use when matching the dataset sources.
40 | It is assumed that the pattern contains a '%s' string so that the split
41 | name can be inserted.
42 | reader: The TensorFlow reader type.
43 |
44 | Returns:
45 | A `Dataset` namedtuple.
46 |
47 | Raises:
48 | ValueError: if `split_name` is not a valid train/test split.
49 | """
50 | if split_name not in _SPLITS_NAMES:
51 | raise ValueError('split name %s was not recognized.' % split_name)
52 |
53 | if not file_pattern:
54 | file_pattern = _FILE_PATTERN
55 |
56 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
57 |
58 | # Allowing None in the signature so that dataset_factory can use the default.
59 | if reader is None:
60 | reader = tf.TFRecordReader
61 |
62 | keys_to_features = {
63 | 'image/encoded': tf.FixedLenFeature(
64 | (), tf.string, default_value=''),
65 | 'image/format': tf.FixedLenFeature(
66 | (), tf.string, default_value='png'),
67 | 'image/label': tf.FixedLenFeature(
68 | [], dtype=tf.int64, default_value=-1),
69 | 'image/filename': tf.FixedLenFeature(
70 | [], dtype=tf.string, default_value=''),
71 | }
72 |
73 | items_to_handlers = {
74 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
75 | 'label': slim.tfexample_decoder.Tensor('image/label'),
76 | 'filename': slim.tfexample_decoder.Tensor('image/filename'),
77 | }
78 |
79 | decoder = slim.tfexample_decoder.TFExampleDecoder(
80 | keys_to_features, items_to_handlers)
81 |
82 | num_examples = get_num_examples(split_name)
83 | num_classes = FLAGS.num_classes
84 |
85 | return slim.dataset.Dataset(
86 | data_sources=file_pattern,
87 | reader=reader,
88 | decoder=decoder,
89 | num_samples=num_examples,
90 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
91 | num_classes=num_classes)
92 |
--------------------------------------------------------------------------------
/.idea/markdown-navigator.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/train_image_classifier.py:
--------------------------------------------------------------------------------
1 | """
2 | Generic training script that trains a model using a given dataset.
3 |
4 | This code modifies the "TensorFlow-Slim image classification model library",
5 | Please visit https://github.com/tensorflow/models/tree/master/research/slim
6 | for more detailed usage.
7 |
8 | """
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import tensorflow as tf
15 | import train_models
16 | from datasets.utils import *
17 |
18 | slim = tf.contrib.slim
19 |
20 | #########################
21 | # Training Directories #
22 | #########################
23 |
24 | tf.app.flags.DEFINE_string('dataset_name', 'market1501',
25 | 'The name of the dataset to load.')
26 |
27 | tf.app.flags.DEFINE_string('split_name', 'bounding_box_train',
28 | 'The name of the data split.')
29 |
30 | tf.app.flags.DEFINE_string('dataset_dir', None,
31 | 'The directory where the dataset files are stored.')
32 |
33 | tf.app.flags.DEFINE_string('checkpoint_dir', 'checkpoint',
34 | 'Directory name to save the checkpoints [checkpoint]')
35 |
36 | tf.app.flags.DEFINE_string('log_dir', 'logs',
37 | 'Directory name to save the logs')
38 |
39 |
40 | #########################
41 | # Model Settings #
42 | #########################
43 |
44 | tf.app.flags.DEFINE_string('model_name', 'mobilenet_v1',
45 | 'The name of the architecture to train.')
46 |
47 | tf.app.flags.DEFINE_string('preprocessing_name', None,
48 | 'The name of the preprocessing to use. If left as `None`, '
49 | 'then the model_name flag is used.')
50 |
51 | tf.app.flags.DEFINE_float('weight_decay', 0.00004,
52 | 'The weight decay on the model weights.')
53 |
54 | tf.app.flags.DEFINE_float('label_smoothing', 0.0,
55 | 'The amount of label smoothing.')
56 |
57 | tf.app.flags.DEFINE_integer('batch_size', 16,
58 | 'The number of samples in each batch.')
59 |
60 | tf.app.flags.DEFINE_integer('max_number_of_steps', 200000,
61 | 'The maximum number of training steps.')
62 |
63 | tf.app.flags.DEFINE_integer('ckpt_steps', 5000,
64 | 'How many steps to save checkpoints.')
65 |
66 | tf.app.flags.DEFINE_integer('num_classes', 751,
67 | 'The number of classes.')
68 |
69 | tf.app.flags.DEFINE_integer('num_networks', 2,
70 | 'The number of networks in DML.')
71 |
72 | tf.app.flags.DEFINE_integer('num_gpus', 1,
73 | 'The number of GPUs.')
74 |
75 | #########################
76 | # Optimization Settings #
77 | #########################
78 |
79 | tf.app.flags.DEFINE_string('optimizer', 'adam',
80 | 'The name of the optimizer, one of "adadelta", "adagrad", "adam",'
81 | '"ftrl", "momentum", "sgd" or "rmsprop".')
82 |
83 | tf.app.flags.DEFINE_float('learning_rate', 0.0002,
84 | 'Initial learning rate.')
85 |
86 | tf.app.flags.DEFINE_float('adam_beta1', 0.5,
87 | 'The exponential decay rate for the 1st moment estimates.')
88 |
89 | tf.app.flags.DEFINE_float('adam_beta2', 0.999,
90 | 'The exponential decay rate for the 2nd moment estimates.')
91 |
92 | tf.app.flags.DEFINE_float('opt_epsilon', 1.0,
93 | 'Epsilon term for the optimizer.')
94 |
95 |
96 | #########################
97 | # Default Settings #
98 | #########################
99 | tf.app.flags.DEFINE_integer('num_clones', 1,
100 | 'Number of model clones to deploy.')
101 |
102 | tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
103 | 'Use CPUs to deploy clones.')
104 |
105 | tf.app.flags.DEFINE_integer('worker_replicas', 1,
106 | 'Number of worker replicas.')
107 |
108 | tf.app.flags.DEFINE_integer('num_ps_tasks', 0,
109 | 'The number of parameter servers. If the value is 0, then the parameters '
110 | 'are handled locally by the worker.')
111 |
112 | tf.app.flags.DEFINE_integer('task', 0,
113 | 'Task id of the replica running the training.')
114 |
115 | tf.app.flags.DEFINE_float('moving_average_decay', 0.9999,
116 | 'The decay to use for the moving average. '
117 | 'If left as None, then moving averages are not used.')
118 |
119 | tf.app.flags.DEFINE_integer('input_queue_memory_factor', 16,
120 | """Size of the queue of preprocessed images. """)
121 |
122 | tf.app.flags.DEFINE_integer('num_readers', 4,
123 | 'The number of parallel readers that read data from the dataset.')
124 |
125 | tf.app.flags.DEFINE_integer('num_preprocessing_threads', 4,
126 | 'The number of threads used to create the batches.')
127 |
128 | tf.app.flags.DEFINE_boolean('log_device_placement', False,
129 | """Whether to log device placement.""")
130 |
131 |
132 | FLAGS = tf.app.flags.FLAGS
133 |
134 |
135 | def main(_):
136 | # create folders
137 | mkdir_if_missing(FLAGS.checkpoint_dir)
138 | mkdir_if_missing(FLAGS.log_dir)
139 | # training
140 | train_models.train()
141 |
142 |
143 | if __name__ == '__main__':
144 | tf.app.run()
145 |
--------------------------------------------------------------------------------
/eval_models.py:
--------------------------------------------------------------------------------
1 | """
2 | Generic evaluation script that extracts features from a given dataset using trained models.
3 | """
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import tensorflow as tf
10 | from datasets import dataset_factory
11 | from nets import nets_factory
12 | from preprocessing import preprocessing_factory
13 | import math
14 | from datetime import datetime
15 | import numpy as np
16 | import os.path
17 | import sys
18 | import scipy.io as sio
19 |
20 | slim = tf.contrib.slim
21 |
22 | FLAGS = tf.app.flags.FLAGS
23 |
24 |
25 | def _extract_once(features, labels, filenames, num_examples, saver):
26 | """Extract Features.
27 | """
28 | config = tf.ConfigProto()
29 | config.gpu_options.per_process_gpu_memory_fraction = 0.2
30 | with tf.device('/cpu:0'):
31 | with tf.Session(config=config) as sess:
32 | ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
33 | if ckpt and ckpt.model_checkpoint_path:
34 | if os.path.isabs(ckpt.model_checkpoint_path):
35 | saver.restore(sess, ckpt.model_checkpoint_path)
36 | else:
37 | ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
38 | saver.restore(sess, os.path.join(FLAGS.checkpoint_dir, ckpt_name))
39 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
40 | print('Successfully loaded model from %s at step=%s.' %
41 | (ckpt.model_checkpoint_path, global_step))
42 | else:
43 | print('No checkpoint file found')
44 | return
45 |
46 | # Start the queue runners.
47 | coord = tf.train.Coordinator()
48 | try:
49 | threads = []
50 | for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
51 | threads.extend(qr.create_threads(sess, coord=coord, daemon=True, start=True))
52 | # num_examples = get_num_examples()
53 | num_iter = int(math.ceil(num_examples / FLAGS.batch_size))
54 | # Counts the number of correct predictions.
55 | step = 0
56 | all_features = []
57 | all_labels = []
58 | print("Current Path: %s" % os.getcwd())
59 | print('%s: starting extracting features on (%s).' % (datetime.now(), FLAGS.split_name))
60 | while step < num_iter and not coord.should_stop():
61 | step += 1
62 | sys.stdout.write('\r>> Extracting %s image %d/%d' % (FLAGS.split_name, step, num_examples))
63 | sys.stdout.flush()
64 | eval_features, eval_labels, eval_filenames = sess.run([features, labels, filenames])
65 | # print('Filename:%s, Camid:%d, Label:%d' % (eval_filenames, eval_camids, eval_labels))
66 | concat_features = np.concatenate(eval_features, axis=3)
67 | eval_features = np.reshape(concat_features, [concat_features.shape[0], -1])
68 | all_features.append(eval_features)
69 | all_labels.append(eval_labels)
70 |
71 | # save features and labels
72 | np_features = np.asarray(all_features)
73 | np_features = np.reshape(np_features, [len(all_features), -1])
74 | np_labels = np.asarray(all_labels)
75 | np_labels = np.reshape(np_labels, len(all_labels))
76 | feature_filename = "%s/%s_features.mat" % (FLAGS.eval_dir, FLAGS.split_name)
77 | sio.savemat(feature_filename, {'feature': np_features})
78 | label_filename = "%s/%s_labels.mat" % (FLAGS.eval_dir, FLAGS.split_name)
79 | sio.savemat(label_filename, {'label': np_labels})
80 | print("Done!\n")
81 |
82 | except Exception as e:
83 | coord.request_stop(e)
84 |
85 | coord.request_stop()
86 | coord.join(threads, stop_grace_period_secs=10)
87 |
88 |
89 | def evaluate():
90 | if not FLAGS.dataset_dir:
91 | raise ValueError('You must supply the dataset directory with --dataset_dir')
92 |
93 | tf.logging.set_verbosity(tf.logging.INFO)
94 | with tf.Graph().as_default():
95 | tf_global_step = slim.get_or_create_global_step()
96 |
97 | ######################
98 | # Select the dataset #
99 | ######################
100 | dataset = dataset_factory.get_dataset(
101 | FLAGS.dataset_name, FLAGS.split_name, FLAGS.dataset_dir)
102 |
103 | ####################
104 | # Select the model #
105 | ####################
106 | network_fn = {}
107 | model_names = [net.strip() for net in FLAGS.model_name.split(',')]
108 | for i in range(FLAGS.num_networks):
109 | network_fn["{0}".format(i)] = nets_factory.get_network_fn(
110 | model_names[i],
111 | num_classes=dataset.num_classes,
112 | is_training=False)
113 |
114 | ##############################################################
115 | # Create a dataset provider that loads data from the dataset #
116 | ##############################################################
117 | provider = slim.dataset_data_provider.DatasetDataProvider(
118 | dataset,
119 | shuffle=False,
120 | common_queue_capacity=2 * FLAGS.batch_size,
121 | common_queue_min=FLAGS.batch_size)
122 | [image, label, filename] = provider.get(['image', 'label', 'filename'])
123 |
124 | #####################################
125 | # Select the preprocessing function #
126 | #####################################
127 | preprocessing_name = FLAGS.preprocessing_name
128 | image_preprocessing_fn = preprocessing_factory.get_preprocessing(
129 | preprocessing_name,
130 | is_training=False)
131 |
132 | eval_image_size = network_fn['0'].default_image_size
133 |
134 | image = image_preprocessing_fn(image, eval_image_size, eval_image_size)
135 |
136 | images, labels, filenames = tf.train.batch(
137 | [image, label, filename],
138 | batch_size=FLAGS.batch_size,
139 | num_threads=FLAGS.num_preprocessing_threads,
140 | capacity=5 * FLAGS.batch_size)
141 |
142 | ####################
143 | # Define the model #
144 | ####################
145 | net_endpoints, net_features = {}, {}
146 | all_features = []
147 | for i in range(FLAGS.num_networks):
148 | _, net_endpoints["{0}".format(i)] = network_fn["{0}".format(i)](images, scope=('dmlnet_%d' % i))
149 | net_features["{0}".format(i)] = net_endpoints["{0}".format(i)]['PreLogits']
150 | all_features.append(net_features["{0}".format(i)])
151 |
152 | if FLAGS.moving_average_decay:
153 | variable_averages = tf.train.ExponentialMovingAverage(
154 | FLAGS.moving_average_decay, tf_global_step)
155 | variables_to_restore = variable_averages.variables_to_restore(
156 | slim.get_model_variables())
157 | variables_to_restore[tf_global_step.op.name] = tf_global_step
158 | else:
159 | variables_to_restore = slim.get_variables_to_restore()
160 |
161 | saver = tf.train.Saver(variables_to_restore)
162 | _extract_once(all_features, labels, filenames, dataset.num_samples, saver)
163 |
--------------------------------------------------------------------------------
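The extracted features and labels are written as MATLAB `.mat` files named `<split>_features.mat` and `<split>_labels.mat` under the evaluation directory. A hedged sketch of loading them back in Python; the directory below follows the evaluation scripts above.
```
import scipy.io as sio

features = sio.loadmat('market1501_dml_mobilenet/results/query_features.mat')['feature']
labels = sio.loadmat('market1501_dml_mobilenet/results/query_labels.mat')['label']
print(features.shape, labels.shape)
```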
/preprocessing/inception_preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides utilities to preprocess images for the Inception networks."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from tensorflow.python.ops import control_flow_ops
24 |
25 |
26 | def apply_with_random_selector(x, func, num_cases):
27 | """Computes func(x, sel), with sel sampled from [0...num_cases-1].
28 |
29 | Args:
30 | x: input Tensor.
31 | func: Python function to apply.
32 | num_cases: Python int32, number of cases to sample sel from.
33 |
34 | Returns:
35 | The result of func(x, sel), where func receives the value of the
36 | selector as a python integer, but sel is sampled dynamically.
37 | """
38 | sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
39 | # Pass the real x only to one of the func calls.
40 | return control_flow_ops.merge([
41 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
42 | for case in range(num_cases)])[0]
43 |
44 |
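# A usage sketch for apply_with_random_selector (hypothetical values, relying on the
# tensorflow import above): pick one of two brightness shifts at random per example.
#
#   image = tf.zeros([224, 224, 3], dtype=tf.float32)
#   shifted = apply_with_random_selector(
#       image,
#       lambda x, case: tf.image.adjust_brightness(x, delta=0.1 * case),
#       num_cases=2)
#
# Every branch is added to the graph, but the switch/merge pair routes the real
# tensor only through the branch whose index matches the sampled `sel`.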
45 | def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
46 | """Distort the color of a Tensor image.
47 |
48 | Each color distortion is non-commutative and thus ordering of the color ops
49 | matters. Ideally we would randomly permute the ordering of the color ops.
50 | Rather than adding that level of complication, we select a distinct ordering
51 | of color ops for each preprocessing thread.
52 |
53 | Args:
54 | image: 3-D Tensor containing single image in [0, 1].
55 | color_ordering: Python int, a type of distortion (valid values: 0-3).
56 | fast_mode: Avoids slower ops (random_hue and random_contrast)
57 | scope: Optional scope for name_scope.
58 | Returns:
59 | 3-D Tensor color-distorted image on range [0, 1]
60 | Raises:
61 | ValueError: if color_ordering not in [0, 3]
62 | """
63 | with tf.name_scope(scope, 'distort_color', [image]):
64 | if fast_mode:
65 | if color_ordering == 0:
66 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
67 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
68 | else:
69 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
70 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
71 | else:
72 | if color_ordering == 0:
73 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
74 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
75 | image = tf.image.random_hue(image, max_delta=0.2)
76 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
77 | elif color_ordering == 1:
78 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
79 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
80 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
81 | image = tf.image.random_hue(image, max_delta=0.2)
82 | elif color_ordering == 2:
83 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
84 | image = tf.image.random_hue(image, max_delta=0.2)
85 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
86 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
87 | elif color_ordering == 3:
88 | image = tf.image.random_hue(image, max_delta=0.2)
89 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
90 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
91 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
92 | else:
93 | raise ValueError('color_ordering must be in [0, 3]')
94 |
95 | # The random_* ops do not necessarily clamp.
96 | return tf.clip_by_value(image, 0.0, 1.0)
97 |
98 |
99 | def distorted_bounding_box_crop(image,
100 | bbox,
101 | min_object_covered=0.1,
102 | aspect_ratio_range=(0.75, 1.33),
103 | area_range=(0.05, 1.0),
104 | max_attempts=100,
105 | scope=None):
106 | """Generates cropped_image using a one of the bboxes randomly distorted.
107 |
108 | See `tf.image.sample_distorted_bounding_box` for more documentation.
109 |
110 | Args:
111 | image: 3-D Tensor of image (it will be converted to floats in [0, 1]).
112 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
113 | where each coordinate is [0, 1) and the coordinates are arranged
114 | as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the whole
115 | image.
116 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
117 | area of the image must contain at least this fraction of any bounding box
118 | supplied.
119 | aspect_ratio_range: An optional list of `floats`. The cropped area of the
120 | image must have an aspect ratio = width / height within this range.
121 | area_range: An optional list of `floats`. The cropped area of the image
122 | must contain a fraction of the supplied image within this range.
123 | max_attempts: An optional `int`. Number of attempts at generating a cropped
124 | region of the image of the specified constraints. After `max_attempts`
125 | failures, return the entire image.
126 | scope: Optional scope for name_scope.
127 | Returns:
128 | A tuple, a 3-D Tensor cropped_image and the distorted bbox
129 | """
130 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
131 | # Each bounding box has shape [1, num_boxes, box coords] and
132 | # the coordinates are ordered [ymin, xmin, ymax, xmax].
133 |
134 | # A large fraction of image datasets contain a human-annotated bounding
135 | # box delineating the region of the image containing the object of interest.
136 | # We choose to create a new bounding box for the object which is a randomly
137 | # distorted version of the human-annotated bounding box that obeys an
138 | # allowed range of aspect ratios, sizes and overlap with the human-annotated
139 | # bounding box. If no box is supplied, then we assume the bounding box is
140 | # the entire image.
141 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
142 | tf.shape(image),
143 | bounding_boxes=bbox,
144 | min_object_covered=min_object_covered,
145 | aspect_ratio_range=aspect_ratio_range,
146 | area_range=area_range,
147 | max_attempts=max_attempts,
148 | use_image_if_no_bounding_boxes=True)
149 | bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
150 |
151 | # Crop the image to the specified bounding box.
152 | cropped_image = tf.slice(image, bbox_begin, bbox_size)
153 | return cropped_image, distort_bbox
154 |
155 |
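# A standalone sketch of the crop helper above (hypothetical image and a whole-image
# bounding box, relying on the tensorflow import above):
#
#   image = tf.random_uniform([256, 128, 3], minval=0.0, maxval=1.0)
#   whole_image_bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
#                                  dtype=tf.float32, shape=[1, 1, 4])
#   cropped, distorted_bbox = distorted_bounding_box_crop(
#       image, whole_image_bbox, min_object_covered=0.5, area_range=(0.5, 1.0))
#   cropped.set_shape([None, None, 3])  # the dynamic slice loses the static shape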
156 | def preprocess_for_train(image, height, width, bbox,
157 | fast_mode=True,
158 | scope=None):
159 | """Distort one image for training a network.
160 |
161 | Distorting images provides a useful technique for augmenting the data
162 | set during training in order to make the network invariant to aspects
163 | of the image that do not affect the label.
164 |
165 | Additionally, it creates image summaries to display the different
166 | transformations applied to the image.
167 |
168 | Args:
169 | image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
170 | [0, 1], otherwise it is converted to tf.float32 assuming that the range
171 | is [0, MAX], where MAX is the largest positive representable number for
172 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
173 | height: integer
174 | width: integer
175 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
176 | where each coordinate is [0, 1) and the coordinates are arranged
177 | as [ymin, xmin, ymax, xmax].
178 | fast_mode: Optional boolean, if True avoids slower transformations (i.e.
179 | bi-cubic resizing, random_hue or random_contrast).
180 | scope: Optional scope for name_scope.
181 | Returns:
182 | 3-D float Tensor of distorted image used for training with range [-1, 1].
183 | """
184 | with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
185 | if bbox is None:
186 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
187 | dtype=tf.float32,
188 | shape=[1, 1, 4])
189 | if image.dtype != tf.float32:
190 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
191 | # Each bounding box has shape [1, num_boxes, box coords] and
192 | # the coordinates are ordered [ymin, xmin, ymax, xmax].
193 | image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
194 | bbox)
195 | tf.summary.image('image_with_bounding_boxes', image_with_box)
196 |
197 | distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
198 | # Restore the shape since the dynamic slice based upon the bbox_size loses
199 | # the third dimension.
200 | distorted_image.set_shape([None, None, 3])
201 | image_with_distorted_box = tf.image.draw_bounding_boxes(
202 | tf.expand_dims(image, 0), distorted_bbox)
203 | tf.summary.image('images_with_distorted_bounding_box',
204 | image_with_distorted_box)
205 |
206 | # This resizing operation may distort the images because the aspect
207 | # ratio is not respected. We select a resize method in a round robin
208 | # fashion based on the thread number.
209 | # Note that ResizeMethod contains 4 enumerated resizing methods.
210 |
211 | # We select only 1 case for fast_mode bilinear.
212 | num_resize_cases = 1 if fast_mode else 4
213 | distorted_image = apply_with_random_selector(
214 | distorted_image,
215 | lambda x, method: tf.image.resize_images(x, [height, width], method=method),
216 | num_cases=num_resize_cases)
217 |
218 | tf.summary.image('cropped_resized_image',
219 | tf.expand_dims(distorted_image, 0))
220 |
221 | # Randomly flip the image horizontally.
222 | distorted_image = tf.image.random_flip_left_right(distorted_image)
223 |
224 | # Randomly distort the colors. There are 4 ways to do it.
225 | distorted_image = apply_with_random_selector(
226 | distorted_image,
227 | lambda x, ordering: distort_color(x, ordering, fast_mode),
228 | num_cases=4)
229 |
230 | tf.summary.image('final_distorted_image',
231 | tf.expand_dims(distorted_image, 0))
232 | distorted_image = tf.subtract(distorted_image, 0.5)
233 | distorted_image = tf.multiply(distorted_image, 2.0)
234 | return distorted_image
235 |
236 |
237 | def preprocess_for_eval(image, height, width,
238 | central_fraction=0.875, scope=None):
239 | """Prepare one image for evaluation.
240 |
241 | If height and width are specified, the image is resized to that size by
242 | applying resize_bilinear.
243 |
244 | If central_fraction is specified, the central fraction of the input image is
245 | cropped first.
246 |
247 | Args:
248 | image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
249 | [0, 1], otherwise it is converted to tf.float32 assuming that the range
250 | is [0, MAX], where MAX is the largest positive representable number for
251 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details)
252 | height: integer
253 | width: integer
254 | central_fraction: Optional Float, fraction of the image to crop.
255 | scope: Optional scope for name_scope.
256 | Returns:
257 | 3-D float Tensor of prepared image.
258 | """
259 | with tf.name_scope(scope, 'eval_image', [image, height, width]):
260 | if image.dtype != tf.float32:
261 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
262 | # Crop the central region of the image with an area containing 87.5% of
263 | # the original image.
264 | if central_fraction:
265 | image = tf.image.central_crop(image, central_fraction=central_fraction)
266 |
267 | if height and width:
268 | # Resize the image to the specified height and width.
269 | image = tf.expand_dims(image, 0)
270 | image = tf.image.resize_bilinear(image, [height, width],
271 | align_corners=False)
272 | image = tf.squeeze(image, [0])
273 | image = tf.subtract(image, 0.5)
274 | image = tf.multiply(image, 2.0)
275 | return image
276 |
277 |
278 | def preprocess_image(image, height, width,
279 | is_training=False,
280 | bbox=None,
281 | fast_mode=True):
282 | """Pre-process one image for training or evaluation.
283 |
284 | Args:
285 | image: 3-D Tensor [height, width, channels] with the image.
286 | height: integer, image expected height.
287 | width: integer, image expected width.
288 | is_training: Boolean. If true, the image is transformed for training;
289 | otherwise it is transformed for evaluation.
290 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
291 | where each coordinate is [0, 1) and the coordinates are arranged as
292 | [ymin, xmin, ymax, xmax].
293 | fast_mode: Optional boolean, if True avoids slower transformations.
294 |
295 | Returns:
296 | 3-D float Tensor containing an appropriately scaled image
297 |
298 | Raises:
299 | ValueError: if user does not provide bounding box
300 | """
301 | if is_training:
302 | return preprocess_for_train(image, height, width, bbox, fast_mode)
303 | else:
304 | return preprocess_for_eval(image, height, width)
305 |
--------------------------------------------------------------------------------
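A brief usage sketch for the preprocessing module above (hypothetical tensors; a decoded uint8 image is assumed): the single `preprocess_image` entry point covers both the training-time distortions and the eval-time central crop, and both paths return float images scaled to [-1, 1].

import tensorflow as tf
from preprocessing import inception_preprocessing

raw_image = tf.zeros([300, 200, 3], dtype=tf.uint8)  # stand-in for a decoded JPEG
train_image = inception_preprocessing.preprocess_image(
    raw_image, height=224, width=224, is_training=True)   # random crop, flip, color
eval_image = inception_preprocessing.preprocess_image(
    raw_image, height=224, width=224, is_training=False)  # central crop and resize
# Both results are float32 tensors of shape [224, 224, 3] with values in [-1, 1].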
/train_models.py:
--------------------------------------------------------------------------------
1 | """
2 | Generic training script that trains a model using a given dataset.
3 | """
4 |
5 | from __future__ import absolute_import
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import tensorflow as tf
10 | from datasets import dataset_factory
11 | from deployment import model_deploy
12 | from nets import nets_factory
13 | from preprocessing import preprocessing_factory
14 | from datasets.utils import *
15 | import numpy as np
16 |
17 | slim = tf.contrib.slim
18 |
19 | FLAGS = tf.app.flags.FLAGS
20 |
21 |
22 | def _average_gradients(tower_grads, catname=None):
23 | """Calculate the average gradient for each shared variable across all towers.
24 |
25 | Note that this function provides a synchronization point across all towers.
26 |
27 | Args:
28 | tower_grads: List of lists of (gradient, variable) tuples. The outer list
29 | is over individual gradients. The inner list is over the gradient
30 | calculation for each tower.
31 | Returns:
32 | List of pairs of (gradient, variable) where the gradient has been averaged
33 | across all towers.
34 | """
35 | average_grads = []
36 | for grad_and_vars in zip(*tower_grads):
37 | # Note that each grad_and_vars looks like the following:
38 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
39 | grads = []
40 | for g, _ in grad_and_vars:
41 | # Add 0 dimension to the gradients to represent the tower.
42 | expanded_g = tf.expand_dims(input=g, axis=0)
43 | # print(g)
44 | # Append on a 'tower' dimension which we will average over below.
45 | grads.append(expanded_g)
46 |
47 | # Average over the 'tower' dimension.
48 | grad = tf.concat(axis=0, values=grads, name=catname)
49 | grad = tf.reduce_mean(grad, 0)
50 |
51 | # Keep in mind that the Variables are redundant because they are shared
52 | # across towers. So .. we will just return the first tower's pointer to
53 | # the Variable.
54 | v = grad_and_vars[0][1]
55 | grad_and_var = (grad, v)
56 | average_grads.append(grad_and_var)
57 | return average_grads
58 |
59 |
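# A small sketch of what _average_gradients consumes and returns: two hypothetical
# towers sharing a single variable (relies on the tensorflow import above).
#
#   w = tf.get_variable('w_example', shape=[2], initializer=tf.zeros_initializer())
#   tower_grads = [
#       [(tf.constant([1.0, 2.0]), w)],   # (gradient, variable) pairs from tower 0
#       [(tf.constant([3.0, 4.0]), w)],   # (gradient, variable) pairs from tower 1
#   ]
#   averaged = _average_gradients(tower_grads, catname='example_cat')
#
# `averaged` holds a single (gradient, variable) pair whose gradient evaluates to
# [2.0, 3.0], the element-wise mean across the two towers.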
60 | def kl_loss_compute(logits1, logits2):
61 | """ KL loss
62 | """
63 | pred1 = tf.nn.softmax(logits1)
64 | pred2 = tf.nn.softmax(logits2)
65 | loss = tf.reduce_mean(tf.reduce_sum(pred2 * tf.log(1e-8 + pred2 / (pred1 + 1e-8)), 1))
66 |
67 | return loss
68 |
69 |
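# A quick sanity sketch for kl_loss_compute (hypothetical logits, relying on the
# tensorflow import above): the loss is the batch-mean KL divergence of the second
# network's softmax from the first network's softmax, so identical logits give a
# value that is approximately zero (up to the 1e-8 smoothing terms).
#
#   logits_a = tf.constant([[2.0, 0.5, -1.0]])
#   logits_b = tf.constant([[1.0, 1.0, 0.0]])
#   kl_b_vs_a = kl_loss_compute(logits_a, logits_b)  # ~ KL(softmax(b) || softmax(a))
#   kl_same = kl_loss_compute(logits_a, logits_a)    # ~ 0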
70 | def _tower_loss(network_fn, images, labels):
71 | """Calculate the total loss on a single tower running the reid model."""
72 | # Build inference Graph.
73 | net_logits, net_endpoints, net_raw_loss, net_pred, net_features = {}, {}, {}, {}, {}
74 | for i in range(FLAGS.num_networks):
75 | net_logits["{0}".format(i)], net_endpoints["{0}".format(i)] = \
76 | network_fn["{0}".format(i)](images, scope=('dmlnet_%d' % i))
77 | net_raw_loss["{0}".format(i)] = tf.losses.softmax_cross_entropy(
78 | logits=net_logits["{0}".format(i)], onehot_labels=labels,
79 | label_smoothing=FLAGS.label_smoothing, weights=1.0)
80 | net_pred["{0}".format(i)] = net_endpoints["{0}".format(i)]['Predictions']
81 |
82 | if 'AuxLogits' in net_endpoints["{0}".format(i)]:
83 | net_raw_loss["{0}".format(i)] += tf.losses.softmax_cross_entropy(
84 | logits=net_endpoints["{0}".format(i)]['AuxLogits'], onehot_labels=labels,
85 | label_smoothing=FLAGS.label_smoothing, weights=0.4, scope='aux_loss')
86 |
87 | # Add KL loss if there is more than one network
88 | net_loss, kl_loss, net_reg_loss, net_total_loss, net_loss_averages, net_loss_averages_op = {}, {}, {}, {}, {}, {}
89 |
90 | for i in range(FLAGS.num_networks):
91 | net_loss["{0}".format(i)] = net_raw_loss["{0}".format(i)]
92 | for j in range(FLAGS.num_networks):
93 | if i != j:
94 | kl_loss["{0}{1}".format(i, j)] = kl_loss_compute(net_logits["{0}".format(i)], net_logits["{0}".format(j)])
95 | net_loss["{0}".format(i)] += kl_loss["{0}{1}".format(i, j)]
96 | tf.summary.scalar('kl_loss_%d%d' % (i, j), kl_loss["{0}{1}".format(i, j)])
97 |
98 | net_reg_loss["{0}".format(i)] = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, scope=('dmlnet_%d' % i))
99 | net_total_loss["{0}".format(i)] = tf.add_n([net_loss["{0}".format(i)]] +
100 | net_reg_loss["{0}".format(i)],
101 | name=('net%d_total_loss' % i))
102 |
103 | net_loss_averages["{0}".format(i)] = tf.train.ExponentialMovingAverage(0.9, name='net%d_avg' % i)
104 | net_loss_averages_op["{0}".format(i)] = net_loss_averages["{0}".format(i)].apply(
105 | [net_loss["{0}".format(i)]] + [net_total_loss["{0}".format(i)]])
106 |
107 | tf.summary.scalar('net%d_loss_raw' % i, net_raw_loss["{0}".format(i)])
108 | tf.summary.scalar('net%d_loss_sum' % i, net_loss["{0}".format(i)])
109 | tf.summary.scalar('net%d_loss_avg' % i, net_loss_averages["{0}".format(i)].average(net_loss["{0}".format(i)]))
110 |
111 | with tf.control_dependencies([net_loss_averages_op["{0}".format(i)]]):
112 | net_total_loss["{0}".format(i)] = tf.identity(net_total_loss["{0}".format(i)])
113 |
114 | return net_total_loss, net_pred
115 |
116 |
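# In equation form, the objective assembled by _tower_loss for network i (a sketch
# matching the deep mutual learning formulation, with p_i = softmax(logits_i) and
# L_Ci the label-smoothed cross entropy, averaged over the batch):
#
#   L_i = L_Ci + sum_{j != i} KL(p_j || p_i),  where
#   KL(p_j || p_i) = sum_m p_j[m] * log(p_j[m] / p_i[m]).
#
# The regularization losses collected for scope 'dmlnet_i' are added afterwards to
# form net_total_loss.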
117 | def train():
118 | if not FLAGS.dataset_dir:
119 | raise ValueError('You must supply the dataset directory with --dataset_dir')
120 |
121 | tf.logging.set_verbosity(tf.logging.INFO)
122 | with tf.Graph().as_default():
123 | #######################
124 | # Config model_deploy #
125 | #######################
126 | deploy_config = model_deploy.DeploymentConfig(
127 | num_clones=FLAGS.num_clones,
128 | clone_on_cpu=FLAGS.clone_on_cpu,
129 | replica_id=FLAGS.task,
130 | num_replicas=FLAGS.worker_replicas,
131 | num_ps_tasks=FLAGS.num_ps_tasks)
132 |
133 | # Create global_step
134 | with tf.device(deploy_config.variables_device()):
135 | global_step = slim.create_global_step()
136 |
137 | ######################
138 | # Select the dataset #
139 | ######################
140 | dataset = dataset_factory.get_dataset(
141 | FLAGS.dataset_name, FLAGS.split_name, FLAGS.dataset_dir)
142 |
143 | ######################
144 | # Select the network #
145 | ######################
146 | network_fn = {}
147 | model_names = [net.strip() for net in FLAGS.model_name.split(',')]
148 | for i in range(FLAGS.num_networks):
149 | network_fn["{0}".format(i)] = nets_factory.get_network_fn(
150 | model_names[i],
151 | num_classes=dataset.num_classes,
152 | weight_decay=FLAGS.weight_decay,
153 | is_training=True)
154 |
155 | #########################################
156 | # Configure the optimization procedure. #
157 | #########################################
158 | with tf.device(deploy_config.optimizer_device()):
159 | net_opt = {}
160 | for i in range(FLAGS.num_networks):
161 | net_opt["{0}".format(i)] = tf.train.AdamOptimizer(FLAGS.learning_rate,
162 | beta1=FLAGS.adam_beta1,
163 | beta2=FLAGS.adam_beta2,
164 | epsilon=FLAGS.opt_epsilon)
165 |
166 | #####################################
167 | # Select the preprocessing function #
168 | #####################################
169 | preprocessing_name = FLAGS.preprocessing_name # or FLAGS.model_name
170 | image_preprocessing_fn = preprocessing_factory.get_preprocessing(
171 | preprocessing_name,
172 | is_training=True)
173 |
174 | ##############################################################
175 | # Create a dataset provider that loads data from the dataset #
176 | ##############################################################
177 | with tf.device(deploy_config.inputs_device()):
178 | examples_per_shard = 1024
179 | min_queue_examples = examples_per_shard * FLAGS.input_queue_memory_factor
180 | provider = slim.dataset_data_provider.DatasetDataProvider(
181 | dataset,
182 | num_readers=FLAGS.num_readers,
183 | common_queue_capacity=min_queue_examples + 3 * FLAGS.batch_size,
184 | common_queue_min=min_queue_examples)
185 | [image, label] = provider.get(['image', 'label'])
186 |
187 | train_image_size = network_fn["{0}".format(0)].default_image_size
188 |
189 | image = image_preprocessing_fn(image, train_image_size, train_image_size)
190 |
191 | images, labels = tf.train.batch(
192 | [image, label],
193 | batch_size=FLAGS.batch_size,
194 | num_threads=FLAGS.num_preprocessing_threads,
195 | capacity=2 * FLAGS.num_preprocessing_threads * FLAGS.batch_size)
196 | labels = slim.one_hot_encoding(labels, dataset.num_classes)
197 | batch_queue = slim.prefetch_queue.prefetch_queue(
198 | [images, labels], capacity=16 * deploy_config.num_clones,
199 | num_threads=FLAGS.num_preprocessing_threads)
200 |
201 | images, labels = batch_queue.dequeue()
202 |
203 | images_splits = tf.split(axis=0, num_or_size_splits=FLAGS.num_gpus, value=images)
204 | labels_splits = tf.split(axis=0, num_or_size_splits=FLAGS.num_gpus, value=labels)
205 |
206 | precision, net_tower_grads, net_update_ops, net_var_list, net_grads = {}, {}, {}, {}, {}
207 |
208 | for i in range(FLAGS.num_networks):
209 | net_tower_grads["{0}".format(i)] = []
210 |
211 | for k in xrange(FLAGS.num_gpus):
212 | with tf.device('/gpu:%d' % k):
213 | with tf.name_scope('tower_%d' % k) as scope:
214 | with tf.variable_scope(tf.get_variable_scope()):
215 |
216 | net_loss, net_pred = _tower_loss(network_fn, images_splits[k], labels_splits[k])
217 |
218 | truth = tf.argmax(labels_splits[k], axis=1)
219 |
220 | # Reuse variables for the next tower.
221 | tf.get_variable_scope().reuse_variables()
222 |
223 | # Retain the summaries from the final tower.
224 | summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
225 | var_list = tf.trainable_variables()
226 |
227 | for i in range(FLAGS.num_networks):
228 | predictions = tf.argmax(net_pred["{0}".format(i)], axis=1)
229 | precision["{0}".format(i)] = tf.reduce_mean(tf.to_float(tf.equal(predictions, truth)))
230 |
231 | # Add a summary to track the training precision.
232 | summaries.append(tf.summary.scalar('precision_%d' % i, precision["{0}".format(i)]))
233 |
234 | net_update_ops["{0}".format(i)] = \
235 | tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=('%sdmlnet_%d' % (scope, i)))
236 |
237 | net_var_list["{0}".format(i)] = \
238 | [var for var in var_list if 'dmlnet_%d' % i in var.name]
239 |
240 | net_grads["{0}".format(i)] = net_opt["{0}".format(i)].compute_gradients(
241 | net_loss["{0}".format(i)], var_list=net_var_list["{0}".format(i)])
242 |
243 | net_tower_grads["{0}".format(i)].append(net_grads["{0}".format(i)])
244 |
245 | # We must calculate the mean of each gradient. Note that this is the
246 | # synchronization point across all towers.
247 | for i in range(FLAGS.num_networks):
248 | net_grads["{0}".format(i)] = _average_gradients(net_tower_grads["{0}".format(i)],
249 | catname=('dmlnet_%d_cat' % i))
250 |
251 | # Add histograms for gradients and trainable variables.
252 | for i in range(FLAGS.num_networks):
253 | for grad, var in net_grads["{0}".format(i)]:
254 | if grad is not None:
255 | summaries.append(tf.summary.histogram(var.op.name + '/gradients', grad))
256 |
257 | for var in tf.trainable_variables():
258 | summaries.append(tf.summary.histogram(var.op.name, var))
259 |
260 | #################################
261 | # Configure the moving averages #
262 | #################################
263 |
264 | if FLAGS.moving_average_decay:
265 | moving_average_variables = {}
266 | all_moving_average_variables = slim.get_model_variables()
267 | variable_averages = tf.train.ExponentialMovingAverage(
268 | FLAGS.moving_average_decay, global_step)
269 | for i in range(FLAGS.num_networks):
270 | moving_average_variables["{0}".format(i)] = \
271 | [var for var in all_moving_average_variables if 'dmlnet_%d' % i in var.name]
272 | net_update_ops["{0}".format(i)].append(
273 | variable_averages.apply(moving_average_variables["{0}".format(i)]))
274 |
275 | # Apply the gradients to adjust the shared variables.
276 | net_grad_updates, net_train_op = {}, {}
277 | for i in range(FLAGS.num_networks):
278 | net_grad_updates["{0}".format(i)] = net_opt["{0}".format(i)].apply_gradients(
279 | net_grads["{0}".format(i)], global_step=global_step)
280 | net_update_ops["{0}".format(i)].append(net_grad_updates["{0}".format(i)])
281 | # Group all updates into a single train op.
282 | net_train_op["{0}".format(i)] = tf.group(*net_update_ops["{0}".format(i)])
283 |
284 | # Create a saver.
285 | saver = tf.train.Saver(tf.global_variables())
286 |
287 | # Build the summary operation from the last tower summaries.
288 | summary_op = tf.summary.merge(summaries)
289 |
290 | # Build an initialization operation to run below.
291 | init = tf.global_variables_initializer()
292 |
293 | # Start running operations on the Graph. allow_soft_placement must be set to
294 | # True to build towers on GPU, as some of the ops do not have GPU
295 | # implementations.
296 | sess = tf.Session(config=tf.ConfigProto(
297 | allow_soft_placement=True,
298 | log_device_placement=FLAGS.log_device_placement))
299 | sess.run(init)
300 |
301 | # Start the queue runners.
302 | tf.train.start_queue_runners(sess=sess)
303 |
304 | summary_writer = tf.summary.FileWriter(
305 | os.path.join(FLAGS.log_dir),
306 | graph=sess.graph)
307 |
308 | net_loss_value, precision_value = {}, {}
309 |
310 | for step in xrange(FLAGS.max_number_of_steps):
311 |
312 | for i in range(FLAGS.num_networks):
313 | _, net_loss_value["{0}".format(i)], precision_value["{0}".format(i)] = \
314 | sess.run([net_train_op["{0}".format(i)], net_loss["{0}".format(i)],
315 | precision["{0}".format(i)]])
316 | assert not np.isnan(net_loss_value["{0}".format(i)]), 'Model diverged with loss = NaN'
317 |
318 | if step % 10 == 0:
319 | format_str = '%s: step %d, net0_loss = %.2f, net0_acc = %.4f'
320 | print(format_str % (FLAGS.dataset_name, step, net_loss_value["{0}".format(0)],
321 | precision_value["{0}".format(0)]))
322 |
323 | if step % 100 == 0:
324 | summary_str = sess.run(summary_op)
325 | summary_writer.add_summary(summary_str, step)
326 |
327 | # Save the model checkpoint periodically.
328 | if step % FLAGS.ckpt_steps == 0 or (step + 1) == FLAGS.max_number_of_steps:
329 | checkpoint_path = os.path.join(FLAGS.checkpoint_dir, 'model.ckpt')
330 | saver.save(sess, checkpoint_path, global_step=step)
331 |
332 |
--------------------------------------------------------------------------------
/nets/inception_v1.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains the definition for inception v1 classification network."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from nets import inception_utils
24 |
25 | slim = tf.contrib.slim
26 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
27 |
28 |
29 | def inception_v1_base(inputs,
30 | final_endpoint='Mixed_5c',
31 | scope='InceptionV1'):
32 | """Defines the Inception V1 base architecture.
33 |
34 | This architecture is defined in:
35 | Going deeper with convolutions
36 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
37 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
38 | http://arxiv.org/pdf/1409.4842v1.pdf.
39 |
40 | Args:
41 | inputs: a tensor of size [batch_size, height, width, channels].
42 | final_endpoint: specifies the endpoint to construct the network up to. It
43 | can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
44 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
45 | 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e',
46 | 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c']
47 | scope: Optional variable_scope.
48 |
49 | Returns:
50 | A dictionary from components of the network to the corresponding activation.
51 |
52 | Raises:
53 | ValueError: if final_endpoint is not set to one of the predefined values.
54 | """
55 | end_points = {}
56 | with tf.variable_scope(scope, 'InceptionV1', [inputs]):
57 | with slim.arg_scope(
58 | [slim.conv2d, slim.fully_connected],
59 | weights_initializer=trunc_normal(0.01)):
60 | with slim.arg_scope([slim.conv2d, slim.max_pool2d],
61 | stride=1, padding='SAME'):
62 | end_point = 'Conv2d_1a_7x7'
63 | net = slim.conv2d(inputs, 64, [7, 7], stride=2, scope=end_point)
64 | end_points[end_point] = net
65 | if final_endpoint == end_point: return net, end_points
66 | end_point = 'MaxPool_2a_3x3'
67 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
68 | end_points[end_point] = net
69 | if final_endpoint == end_point: return net, end_points
70 | end_point = 'Conv2d_2b_1x1'
71 | net = slim.conv2d(net, 64, [1, 1], scope=end_point)
72 | end_points[end_point] = net
73 | if final_endpoint == end_point: return net, end_points
74 | end_point = 'Conv2d_2c_3x3'
75 | net = slim.conv2d(net, 192, [3, 3], scope=end_point)
76 | end_points[end_point] = net
77 | if final_endpoint == end_point: return net, end_points
78 | end_point = 'MaxPool_3a_3x3'
79 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
80 | end_points[end_point] = net
81 | if final_endpoint == end_point: return net, end_points
82 |
83 | end_point = 'Mixed_3b'
84 | with tf.variable_scope(end_point):
85 | with tf.variable_scope('Branch_0'):
86 | branch_0 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1')
87 | with tf.variable_scope('Branch_1'):
88 | branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1')
89 | branch_1 = slim.conv2d(branch_1, 128, [3, 3], scope='Conv2d_0b_3x3')
90 | with tf.variable_scope('Branch_2'):
91 | branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1')
92 | branch_2 = slim.conv2d(branch_2, 32, [3, 3], scope='Conv2d_0b_3x3')
93 | with tf.variable_scope('Branch_3'):
94 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
95 | branch_3 = slim.conv2d(branch_3, 32, [1, 1], scope='Conv2d_0b_1x1')
96 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
97 | end_points[end_point] = net
98 | if final_endpoint == end_point: return net, end_points
99 |
100 | end_point = 'Mixed_3c'
101 | with tf.variable_scope(end_point):
102 | with tf.variable_scope('Branch_0'):
103 | branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
104 | with tf.variable_scope('Branch_1'):
105 | branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
106 | branch_1 = slim.conv2d(branch_1, 192, [3, 3], scope='Conv2d_0b_3x3')
107 | with tf.variable_scope('Branch_2'):
108 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
109 | branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0b_3x3')
110 | with tf.variable_scope('Branch_3'):
111 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
112 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
113 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
114 | end_points[end_point] = net
115 | if final_endpoint == end_point: return net, end_points
116 |
117 | end_point = 'MaxPool_4a_3x3'
118 | net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
119 | end_points[end_point] = net
120 | if final_endpoint == end_point: return net, end_points
121 |
122 | end_point = 'Mixed_4b'
123 | with tf.variable_scope(end_point):
124 | with tf.variable_scope('Branch_0'):
125 | branch_0 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1')
126 | with tf.variable_scope('Branch_1'):
127 | branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1')
128 | branch_1 = slim.conv2d(branch_1, 208, [3, 3], scope='Conv2d_0b_3x3')
129 | with tf.variable_scope('Branch_2'):
130 | branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1')
131 | branch_2 = slim.conv2d(branch_2, 48, [3, 3], scope='Conv2d_0b_3x3')
132 | with tf.variable_scope('Branch_3'):
133 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
134 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
135 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
136 | end_points[end_point] = net
137 | if final_endpoint == end_point: return net, end_points
138 |
139 | end_point = 'Mixed_4c'
140 | with tf.variable_scope(end_point):
141 | with tf.variable_scope('Branch_0'):
142 | branch_0 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
143 | with tf.variable_scope('Branch_1'):
144 | branch_1 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1')
145 | branch_1 = slim.conv2d(branch_1, 224, [3, 3], scope='Conv2d_0b_3x3')
146 | with tf.variable_scope('Branch_2'):
147 | branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1')
148 | branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
149 | with tf.variable_scope('Branch_3'):
150 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
151 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
152 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
153 | end_points[end_point] = net
154 | if final_endpoint == end_point: return net, end_points
155 |
156 | end_point = 'Mixed_4d'
157 | with tf.variable_scope(end_point):
158 | with tf.variable_scope('Branch_0'):
159 | branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
160 | with tf.variable_scope('Branch_1'):
161 | branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
162 | branch_1 = slim.conv2d(branch_1, 256, [3, 3], scope='Conv2d_0b_3x3')
163 | with tf.variable_scope('Branch_2'):
164 | branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1')
165 | branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
166 | with tf.variable_scope('Branch_3'):
167 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
168 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
169 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
170 | end_points[end_point] = net
171 | if final_endpoint == end_point: return net, end_points
172 |
173 | end_point = 'Mixed_4e'
174 | with tf.variable_scope(end_point):
175 | with tf.variable_scope('Branch_0'):
176 | branch_0 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1')
177 | with tf.variable_scope('Branch_1'):
178 | branch_1 = slim.conv2d(net, 144, [1, 1], scope='Conv2d_0a_1x1')
179 | branch_1 = slim.conv2d(branch_1, 288, [3, 3], scope='Conv2d_0b_3x3')
180 | with tf.variable_scope('Branch_2'):
181 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
182 | branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
183 | with tf.variable_scope('Branch_3'):
184 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
185 | branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
186 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
187 | end_points[end_point] = net
188 | if final_endpoint == end_point: return net, end_points
189 |
190 | end_point = 'Mixed_4f'
191 | with tf.variable_scope(end_point):
192 | with tf.variable_scope('Branch_0'):
193 | branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1')
194 | with tf.variable_scope('Branch_1'):
195 | branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
196 | branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3')
197 | with tf.variable_scope('Branch_2'):
198 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
199 | branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3')
200 | with tf.variable_scope('Branch_3'):
201 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
202 | branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
203 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
204 | end_points[end_point] = net
205 | if final_endpoint == end_point: return net, end_points
206 |
207 | end_point = 'MaxPool_5a_2x2'
208 | net = slim.max_pool2d(net, [2, 2], stride=2, scope=end_point)
209 | end_points[end_point] = net
210 | if final_endpoint == end_point: return net, end_points
211 |
212 | end_point = 'Mixed_5b'
213 | with tf.variable_scope(end_point):
214 | with tf.variable_scope('Branch_0'):
215 | branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1')
216 | with tf.variable_scope('Branch_1'):
217 | branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
218 | branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3')
219 | with tf.variable_scope('Branch_2'):
220 | branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
221 | branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0a_3x3')
222 | with tf.variable_scope('Branch_3'):
223 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
224 | branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
225 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
226 | end_points[end_point] = net
227 | if final_endpoint == end_point: return net, end_points
228 |
229 | end_point = 'Mixed_5c'
230 | with tf.variable_scope(end_point):
231 | with tf.variable_scope('Branch_0'):
232 | branch_0 = slim.conv2d(net, 384, [1, 1], scope='Conv2d_0a_1x1')
233 | with tf.variable_scope('Branch_1'):
234 | branch_1 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1')
235 | branch_1 = slim.conv2d(branch_1, 384, [3, 3], scope='Conv2d_0b_3x3')
236 | with tf.variable_scope('Branch_2'):
237 | branch_2 = slim.conv2d(net, 48, [1, 1], scope='Conv2d_0a_1x1')
238 | branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3')
239 | with tf.variable_scope('Branch_3'):
240 | branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
241 | branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
242 | net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
243 | end_points[end_point] = net
244 | if final_endpoint == end_point: return net, end_points
245 | raise ValueError('Unknown final endpoint %s' % final_endpoint)
246 |
247 |
248 | def inception_v1(inputs,
249 | num_classes=1000,
250 | is_training=True,
251 | dropout_keep_prob=0.8,
252 | prediction_fn=slim.softmax,
253 | spatial_squeeze=True,
254 | reuse=None,
255 | scope='InceptionV1'):
256 | """Defines the Inception V1 architecture.
257 |
258 | This architecture is defined in:
259 |
260 | Going deeper with convolutions
261 | Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
262 | Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
263 | http://arxiv.org/pdf/1409.4842v1.pdf.
264 |
265 | The default image size used to train this network is 224x224.
266 |
267 | Args:
268 | inputs: a tensor of size [batch_size, height, width, channels].
269 | num_classes: number of predicted classes.
270 | is_training: whether is training or not.
271 | dropout_keep_prob: the percentage of activation values that are retained.
272 | prediction_fn: a function to get predictions out of logits.
273 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is
274 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
275 | reuse: whether or not the network and its variables should be reused. To be
276 | able to reuse 'scope' must be given.
277 | scope: Optional variable_scope.
278 |
279 | Returns:
280 | logits: the pre-softmax activations, a tensor of size
281 | [batch_size, num_classes]
282 | end_points: a dictionary from components of the network to the corresponding
283 | activation.
284 | """
285 | # Final pooling and prediction
286 | with tf.variable_scope(scope, 'InceptionV1', [inputs, num_classes],
287 | reuse=reuse) as scope:
288 | with slim.arg_scope([slim.batch_norm, slim.dropout],
289 | is_training=is_training):
290 | net, end_points = inception_v1_base(inputs, scope=scope)
291 | with tf.variable_scope('Logits'):
292 | net = slim.avg_pool2d(net, [7, 7], stride=1, scope='AvgPool_0a_7x7')
293 | net = slim.dropout(net, dropout_keep_prob, scope='Dropout_0b')
294 | end_points['PreLogits'] = net
295 | if not num_classes:
296 | return net, end_points
297 |
298 | logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
299 | normalizer_fn=None, scope='Conv2d_0c_1x1')
300 | if spatial_squeeze:
301 | logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
302 |
303 | end_points['Logits'] = logits
304 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
305 | return logits, end_points
306 |
307 |
308 | inception_v1.default_image_size = 224
309 |
310 | inception_v1_arg_scope = inception_utils.inception_arg_scope
311 |
--------------------------------------------------------------------------------
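A minimal construction sketch for the network defined above (hypothetical placeholder shapes; 751 is only illustrative, standing in for the number of Market-1501 training identities used elsewhere in this repo):

import tensorflow as tf
from nets import inception_v1

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(inception_v1.inception_v1_arg_scope()):
    logits, end_points = inception_v1.inception_v1(
        images, num_classes=751, is_training=False)
# end_points['PreLogits'] holds the pooled pre-classifier feature exposed by this
# modified copy of the network.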
/nets/mobilenet_v1.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # =============================================================================
15 | """MobileNet v1.
16 |
17 | MobileNet is a general architecture and can be used for multiple use cases.
18 | Depending on the use case, it can use different input layer sizes and different
19 | heads (for example: embeddings, localization and classification).
20 |
21 | As described in https://arxiv.org/abs/1704.04861.
22 |
23 | MobileNets: Efficient Convolutional Neural Networks for
24 | Mobile Vision Applications
25 | Andrew G. Howard, Menglong Zhu, Bo Chen, Dmitry Kalenichenko, Weijun Wang,
26 | Tobias Weyand, Marco Andreetto, Hartwig Adam
27 |
28 | 100% Mobilenet V1 (base) with input size 224x224:
29 |
30 | See mobilenet_v1()
31 |
32 | Layer params macs
33 | --------------------------------------------------------------------------------
34 | MobilenetV1/Conv2d_0/Conv2D: 864 10,838,016
35 | MobilenetV1/Conv2d_1_depthwise/depthwise: 288 3,612,672
36 | MobilenetV1/Conv2d_1_pointwise/Conv2D: 2,048 25,690,112
37 | MobilenetV1/Conv2d_2_depthwise/depthwise: 576 1,806,336
38 | MobilenetV1/Conv2d_2_pointwise/Conv2D: 8,192 25,690,112
39 | MobilenetV1/Conv2d_3_depthwise/depthwise: 1,152 3,612,672
40 | MobilenetV1/Conv2d_3_pointwise/Conv2D: 16,384 51,380,224
41 | MobilenetV1/Conv2d_4_depthwise/depthwise: 1,152 903,168
42 | MobilenetV1/Conv2d_4_pointwise/Conv2D: 32,768 25,690,112
43 | MobilenetV1/Conv2d_5_depthwise/depthwise: 2,304 1,806,336
44 | MobilenetV1/Conv2d_5_pointwise/Conv2D: 65,536 51,380,224
45 | MobilenetV1/Conv2d_6_depthwise/depthwise: 2,304 451,584
46 | MobilenetV1/Conv2d_6_pointwise/Conv2D: 131,072 25,690,112
47 | MobilenetV1/Conv2d_7_depthwise/depthwise: 4,608 903,168
48 | MobilenetV1/Conv2d_7_pointwise/Conv2D: 262,144 51,380,224
49 | MobilenetV1/Conv2d_8_depthwise/depthwise: 4,608 903,168
50 | MobilenetV1/Conv2d_8_pointwise/Conv2D: 262,144 51,380,224
51 | MobilenetV1/Conv2d_9_depthwise/depthwise: 4,608 903,168
52 | MobilenetV1/Conv2d_9_pointwise/Conv2D: 262,144 51,380,224
53 | MobilenetV1/Conv2d_10_depthwise/depthwise: 4,608 903,168
54 | MobilenetV1/Conv2d_10_pointwise/Conv2D: 262,144 51,380,224
55 | MobilenetV1/Conv2d_11_depthwise/depthwise: 4,608 903,168
56 | MobilenetV1/Conv2d_11_pointwise/Conv2D: 262,144 51,380,224
57 | MobilenetV1/Conv2d_12_depthwise/depthwise: 4,608 225,792
58 | MobilenetV1/Conv2d_12_pointwise/Conv2D: 524,288 25,690,112
59 | MobilenetV1/Conv2d_13_depthwise/depthwise: 9,216 451,584
60 | MobilenetV1/Conv2d_13_pointwise/Conv2D: 1,048,576 51,380,224
61 | --------------------------------------------------------------------------------
62 | Total: 3,185,088 567,716,352
63 |
64 |
65 | 75% Mobilenet V1 (base) with input size 128x128:
66 |
67 | See mobilenet_v1_075()
68 |
69 | Layer params macs
70 | --------------------------------------------------------------------------------
71 | MobilenetV1/Conv2d_0/Conv2D: 648 2,654,208
72 | MobilenetV1/Conv2d_1_depthwise/depthwise: 216 884,736
73 | MobilenetV1/Conv2d_1_pointwise/Conv2D: 1,152 4,718,592
74 | MobilenetV1/Conv2d_2_depthwise/depthwise: 432 442,368
75 | MobilenetV1/Conv2d_2_pointwise/Conv2D: 4,608 4,718,592
76 | MobilenetV1/Conv2d_3_depthwise/depthwise: 864 884,736
77 | MobilenetV1/Conv2d_3_pointwise/Conv2D: 9,216 9,437,184
78 | MobilenetV1/Conv2d_4_depthwise/depthwise: 864 221,184
79 | MobilenetV1/Conv2d_4_pointwise/Conv2D: 18,432 4,718,592
80 | MobilenetV1/Conv2d_5_depthwise/depthwise: 1,728 442,368
81 | MobilenetV1/Conv2d_5_pointwise/Conv2D: 36,864 9,437,184
82 | MobilenetV1/Conv2d_6_depthwise/depthwise: 1,728 110,592
83 | MobilenetV1/Conv2d_6_pointwise/Conv2D: 73,728 4,718,592
84 | MobilenetV1/Conv2d_7_depthwise/depthwise: 3,456 221,184
85 | MobilenetV1/Conv2d_7_pointwise/Conv2D: 147,456 9,437,184
86 | MobilenetV1/Conv2d_8_depthwise/depthwise: 3,456 221,184
87 | MobilenetV1/Conv2d_8_pointwise/Conv2D: 147,456 9,437,184
88 | MobilenetV1/Conv2d_9_depthwise/depthwise: 3,456 221,184
89 | MobilenetV1/Conv2d_9_pointwise/Conv2D: 147,456 9,437,184
90 | MobilenetV1/Conv2d_10_depthwise/depthwise: 3,456 221,184
91 | MobilenetV1/Conv2d_10_pointwise/Conv2D: 147,456 9,437,184
92 | MobilenetV1/Conv2d_11_depthwise/depthwise: 3,456 221,184
93 | MobilenetV1/Conv2d_11_pointwise/Conv2D: 147,456 9,437,184
94 | MobilenetV1/Conv2d_12_depthwise/depthwise: 3,456 55,296
95 | MobilenetV1/Conv2d_12_pointwise/Conv2D: 294,912 4,718,592
96 | MobilenetV1/Conv2d_13_depthwise/depthwise: 6,912 110,592
97 | MobilenetV1/Conv2d_13_pointwise/Conv2D: 589,824 9,437,184
98 | --------------------------------------------------------------------------------
99 | Total: 1,800,144 106,002,432
100 |
101 | """
102 |
103 | # Tensorflow mandates these.
104 | from __future__ import absolute_import
105 | from __future__ import division
106 | from __future__ import print_function
107 |
108 | from collections import namedtuple
109 | import functools
110 |
111 | import tensorflow as tf
112 |
113 | slim = tf.contrib.slim
114 |
115 | # Conv and DepthSepConv namedtuples define layers of the MobileNet architecture
116 | # Conv defines 3x3 convolution layers
117 | # DepthSepConv defines 3x3 depthwise convolution followed by 1x1 convolution.
118 | # stride is the stride of the convolution
119 | # depth is the number of channels or filters in a layer
120 | Conv = namedtuple('Conv', ['kernel', 'stride', 'depth'])
121 | DepthSepConv = namedtuple('DepthSepConv', ['kernel', 'stride', 'depth'])
122 |
123 | # _CONV_DEFS specifies the MobileNet body
124 | _CONV_DEFS = [
125 | Conv(kernel=[3, 3], stride=2, depth=32),
126 | DepthSepConv(kernel=[3, 3], stride=1, depth=64),
127 | DepthSepConv(kernel=[3, 3], stride=2, depth=128),
128 | DepthSepConv(kernel=[3, 3], stride=1, depth=128),
129 | DepthSepConv(kernel=[3, 3], stride=2, depth=256),
130 | DepthSepConv(kernel=[3, 3], stride=1, depth=256),
131 | DepthSepConv(kernel=[3, 3], stride=2, depth=512),
132 | DepthSepConv(kernel=[3, 3], stride=1, depth=512),
133 | DepthSepConv(kernel=[3, 3], stride=1, depth=512),
134 | DepthSepConv(kernel=[3, 3], stride=1, depth=512),
135 | DepthSepConv(kernel=[3, 3], stride=1, depth=512),
136 | DepthSepConv(kernel=[3, 3], stride=1, depth=512),
137 | DepthSepConv(kernel=[3, 3], stride=2, depth=1024),
138 | DepthSepConv(kernel=[3, 3], stride=1, depth=1024)
139 | ]
140 |
141 |
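# A quick sketch of how the depth multiplier thins each layer in _CONV_DEFS. It
# mirrors the `depth = lambda d: max(int(d * depth_multiplier), min_depth)` rule
# used inside mobilenet_v1_base below.
#
#   def thinned_depths(conv_defs, depth_multiplier=0.75, min_depth=8):
#       return [max(int(c.depth * depth_multiplier), min_depth) for c in conv_defs]
#
#   thinned_depths(_CONV_DEFS)
#   # -> [24, 48, 96, 96, 192, 192, 384, 384, 384, 384, 384, 384, 768, 768]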
142 | def mobilenet_v1_base(inputs,
143 | final_endpoint='Conv2d_13_pointwise',
144 | min_depth=8,
145 | depth_multiplier=1.0,
146 | conv_defs=None,
147 | output_stride=None,
148 | scope=None):
149 | """Mobilenet v1.
150 |
151 | Constructs a Mobilenet v1 network from inputs to the given final endpoint.
152 |
153 | Args:
154 | inputs: a tensor of shape [batch_size, height, width, channels].
155 | final_endpoint: specifies the endpoint to construct the network up to. It
156 | can be one of ['Conv2d_0', 'Conv2d_1_pointwise', 'Conv2d_2_pointwise',
157 | 'Conv2d_3_pointwise', 'Conv2d_4_pointwise', 'Conv2d_5_pointwise',
158 | 'Conv2d_6_pointwise', 'Conv2d_7_pointwise', 'Conv2d_8_pointwise',
159 | 'Conv2d_9_pointwise', 'Conv2d_10_pointwise', 'Conv2d_11_pointwise',
160 | 'Conv2d_12_pointwise', 'Conv2d_13_pointwise'].
161 | min_depth: Minimum depth value (number of channels) for all convolution ops.
162 | Enforced when depth_multiplier < 1, and not an active constraint when
163 | depth_multiplier >= 1.
164 | depth_multiplier: Float multiplier for the depth (number of channels)
165 | for all convolution ops. The value must be greater than zero. Typical
166 | usage will be to set this value in (0, 1) to reduce the number of
167 | parameters or computation cost of the model.
168 | conv_defs: A list of ConvDef namedtuples specifying the net architecture.
169 | output_stride: An integer that specifies the requested ratio of input to
170 | output spatial resolution. If not None, then we invoke atrous convolution
171 | if necessary to prevent the network from reducing the spatial resolution
172 | of the activation maps. Allowed values are 8 (accurate fully convolutional
173 | mode), 16 (fast fully convolutional mode), 32 (classification mode).
174 | scope: Optional variable_scope.
175 |
176 | Returns:
177 | tensor_out: output tensor corresponding to the final_endpoint.
178 | end_points: a set of activations for external use, for example summaries or
179 | losses.
180 |
181 | Raises:
182 | ValueError: if final_endpoint is not set to one of the predefined values,
183 | or depth_multiplier <= 0, or the target output_stride is not
184 | allowed.
185 | """
186 | depth = lambda d: max(int(d * depth_multiplier), min_depth)
187 | end_points = {}
188 |
189 | # Used to find thinned depths for each layer.
190 | if depth_multiplier <= 0:
191 | raise ValueError('depth_multiplier is not greater than zero.')
192 |
193 | if conv_defs is None:
194 | conv_defs = _CONV_DEFS
195 |
196 | if output_stride is not None and output_stride not in [8, 16, 32]:
197 | raise ValueError('Only allowed output_stride values are 8, 16, 32.')
198 |
199 | with tf.variable_scope(scope, 'MobilenetV1', [inputs]):
200 | with slim.arg_scope([slim.conv2d, slim.separable_conv2d], padding='SAME'):
201 | # The current_stride variable keeps track of the output stride of the
202 | # activations, i.e., the running product of convolution strides up to the
203 | # current network layer. This allows us to invoke atrous convolution
204 | # whenever applying the next convolution would result in the activations
205 | # having output stride larger than the target output_stride.
206 | current_stride = 1
207 |
208 | # The atrous convolution rate parameter.
209 | rate = 1
210 |
211 | net = inputs
212 | for i, conv_def in enumerate(conv_defs):
213 | end_point_base = 'Conv2d_%d' % i
214 |
215 | if output_stride is not None and current_stride == output_stride:
216 | # If we have reached the target output_stride, then we need to employ
217 | # atrous convolution with stride=1 and multiply the atrous rate by the
218 | # current unit's stride for use in subsequent layers.
219 | layer_stride = 1
220 | layer_rate = rate
221 | rate *= conv_def.stride
222 | else:
223 | layer_stride = conv_def.stride
224 | layer_rate = 1
225 | current_stride *= conv_def.stride
226 |
227 | if isinstance(conv_def, Conv):
228 | end_point = end_point_base
229 | net = slim.conv2d(net, depth(conv_def.depth), conv_def.kernel,
230 | stride=conv_def.stride,
231 | normalizer_fn=slim.batch_norm,
232 | scope=end_point)
233 | end_points[end_point] = net
234 | if end_point == final_endpoint:
235 | return net, end_points
236 |
237 | elif isinstance(conv_def, DepthSepConv):
238 | end_point = end_point_base + '_depthwise'
239 |
240 | # By passing filters=None
241 | # separable_conv2d produces only a depthwise convolution layer
242 | net = slim.separable_conv2d(net, None, conv_def.kernel,
243 | depth_multiplier=1,
244 | stride=layer_stride,
245 | rate=layer_rate,
246 | normalizer_fn=slim.batch_norm,
247 | scope=end_point)
248 |
249 | end_points[end_point] = net
250 | if end_point == final_endpoint:
251 | return net, end_points
252 |
253 | end_point = end_point_base + '_pointwise'
254 |
255 | net = slim.conv2d(net, depth(conv_def.depth), [1, 1],
256 | stride=1,
257 | normalizer_fn=slim.batch_norm,
258 | scope=end_point)
259 |
260 | end_points[end_point] = net
261 | if end_point == final_endpoint:
262 | return net, end_points
263 | else:
264 | raise ValueError('Unknown convolution type %s for layer %d'
265 | % (type(conv_def).__name__, i))
266 | raise ValueError('Unknown final endpoint %s' % final_endpoint)
267 |
268 |
269 | def mobilenet_v1(inputs,
270 | num_classes=1000,
271 | dropout_keep_prob=0.999,
272 | is_training=True,
273 | min_depth=8,
274 | depth_multiplier=1.0,
275 | conv_defs=None,
276 | prediction_fn=tf.contrib.layers.softmax,
277 | spatial_squeeze=True,
278 | reuse=None,
279 | scope='MobilenetV1',
280 | global_pool=False):
281 | """Mobilenet v1 model for classification.
282 |
283 | Args:
284 | inputs: a tensor of shape [batch_size, height, width, channels].
285 | num_classes: number of predicted classes. If 0 or None, the logits layer
286 | is omitted and the input features to the logits layer (before dropout)
287 | are returned instead.
288 | dropout_keep_prob: the percentage of activation values that are retained.
289 | is_training: whether is training or not.
290 | min_depth: Minimum depth value (number of channels) for all convolution ops.
291 | Enforced when depth_multiplier < 1, and not an active constraint when
292 | depth_multiplier >= 1.
293 | depth_multiplier: Float multiplier for the depth (number of channels)
294 | for all convolution ops. The value must be greater than zero. Typical
295 | usage will be to set this value in (0, 1) to reduce the number of
296 | parameters or computation cost of the model.
297 | conv_defs: A list of ConvDef namedtuples specifying the net architecture.
298 | prediction_fn: a function to get predictions out of logits.
299 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is
300 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
301 | reuse: whether or not the network and its variables should be reused. To be
302 | able to reuse 'scope' must be given.
303 | scope: Optional variable_scope.
304 | global_pool: Optional boolean flag to control the avgpooling before the
305 | logits layer. If false or unset, pooling is done with a fixed window
306 | that reduces default-sized inputs to 1x1, while larger inputs lead to
307 | larger outputs. If true, any input size is pooled down to 1x1.
308 |
309 | Returns:
310 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes
311 | is a non-zero integer, or the non-dropped-out input to the logits layer
312 | if num_classes is 0 or None.
313 | end_points: a dictionary from components of the network to the corresponding
314 | activation.
315 |
316 | Raises:
317 | ValueError: Input rank is invalid.
318 | """
319 | input_shape = inputs.get_shape().as_list()
320 | if len(input_shape) != 4:
321 | raise ValueError('Invalid input tensor rank, expected 4, was: %d' %
322 | len(input_shape))
323 |
324 | with tf.variable_scope(scope, 'MobilenetV1', [inputs], reuse=reuse) as scope:
325 | with slim.arg_scope([slim.batch_norm, slim.dropout],
326 | is_training=is_training):
327 | net, end_points = mobilenet_v1_base(inputs, scope=scope,
328 | min_depth=min_depth,
329 | depth_multiplier=depth_multiplier,
330 | conv_defs=conv_defs)
331 | end_points['FeatureMap'] = net
332 | with tf.variable_scope('Logits'):
333 | if global_pool:
334 | # Global average pooling.
335 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
336 | end_points['global_pool'] = net
337 | else:
338 | # Pooling with a fixed kernel size.
339 | kernel_size = _reduced_kernel_size_for_small_input(net, [7, 7])
340 | net = slim.avg_pool2d(net, kernel_size, padding='VALID',
341 | scope='AvgPool_1a')
342 | end_points['AvgPool_1a'] = net
343 | # 1 x 1 x 1024
344 | net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
345 | end_points['PreLogits'] = net
346 | if not num_classes:
347 | return net, end_points
348 | logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
349 | normalizer_fn=None, scope='Conv2d_1c_1x1')
350 | if spatial_squeeze:
351 | logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
352 | end_points['Logits'] = logits
353 | if prediction_fn:
354 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
355 | return logits, end_points
356 |
357 |
358 | mobilenet_v1.default_image_size = 224
359 |
360 |
361 | def wrapped_partial(func, *args, **kwargs):
362 | partial_func = functools.partial(func, *args, **kwargs)
363 | functools.update_wrapper(partial_func, func)
364 | return partial_func
365 |
366 |
367 | mobilenet_v1_075 = wrapped_partial(mobilenet_v1, depth_multiplier=0.75)
368 | mobilenet_v1_050 = wrapped_partial(mobilenet_v1, depth_multiplier=0.50)
369 | mobilenet_v1_025 = wrapped_partial(mobilenet_v1, depth_multiplier=0.25)
370 |
371 |
372 | def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
373 | """Define kernel size which is automatically reduced for small input.
374 |
375 | If the shape of the input images is unknown at graph construction time this
376 | function assumes that the input images are large enough.
377 |
378 | Args:
379 | input_tensor: input tensor of size [batch_size, height, width, channels].
380 | kernel_size: desired kernel size of length 2: [kernel_height, kernel_width]
381 |
382 | Returns:
383 | a tensor with the kernel size.
384 | """
385 | shape = input_tensor.get_shape().as_list()
386 | if shape[1] is None or shape[2] is None:
387 | kernel_size_out = kernel_size
388 | else:
389 | kernel_size_out = [min(shape[1], kernel_size[0]),
390 | min(shape[2], kernel_size[1])]
391 | return kernel_size_out
392 |
393 |
394 | def mobilenet_v1_arg_scope(is_training=True,
395 | weight_decay=0.00004,
396 | stddev=0.09,
397 | regularize_depthwise=False):
398 | """Defines the default MobilenetV1 arg scope.
399 |
400 | Args:
401 | is_training: Whether or not we're training the model.
402 | weight_decay: The weight decay to use for regularizing the model.
403 | stddev: The standard deviation of the truncated normal weight initializer.
404 | regularize_depthwise: Whether or not to apply regularization on depthwise layers.
405 |
406 | Returns:
407 | An `arg_scope` to use for the mobilenet v1 model.
408 | """
409 | batch_norm_params = {
410 | 'is_training': is_training,
411 | 'center': True,
412 | 'scale': True,
413 | 'decay': 0.9997,
414 | 'epsilon': 0.001,
415 | }
416 |
417 | # Set weight_decay for weights in Conv and DepthSepConv layers.
418 | weights_init = tf.truncated_normal_initializer(stddev=stddev)
419 | regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
420 | if regularize_depthwise:
421 | depthwise_regularizer = regularizer
422 | else:
423 | depthwise_regularizer = None
424 | with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
425 | weights_initializer=weights_init,
426 | activation_fn=tf.nn.relu6, normalizer_fn=slim.batch_norm):
427 | with slim.arg_scope([slim.batch_norm], **batch_norm_params):
428 | with slim.arg_scope([slim.conv2d], weights_regularizer=regularizer):
429 | with slim.arg_scope([slim.separable_conv2d],
430 | weights_regularizer=depthwise_regularizer) as sc:
431 | return sc
432 |
--------------------------------------------------------------------------------
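A minimal usage sketch for the file above may help when skimming this dump. It is only a sketch: it assumes the TF 1.x `tf.contrib.slim` API and that the module is importable as `nets.mobilenet_v1` (as suggested by the repository layout); the input shape and class count are illustrative.

```python
import tensorflow as tf

from nets import mobilenet_v1  # assumed import path within this repository

slim = tf.contrib.slim

# Build a MobileNet-V1 classifier under its default arg scope (TF 1.x style).
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(mobilenet_v1.mobilenet_v1_arg_scope(is_training=False)):
    logits, end_points = mobilenet_v1.mobilenet_v1(
        images,
        num_classes=1001,   # illustrative; set to your dataset's class count
        is_training=False,
        depth_multiplier=1.0,
        global_pool=True)

# end_points maps layer names such as 'FeatureMap', 'PreLogits' and
# 'Predictions' to their activations; the width-multiplier variants are
# exposed as partials, e.g. mobilenet_v1.mobilenet_v1_075.
```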
/deployment/model_deploy.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Deploy Slim models across multiple clones and replicas.
16 |
17 | # TODO(sguada) docstring paragraph by (a) motivating the need for the file and
18 | # (b) defining clones.
19 |
20 | # TODO(sguada) describe the high-level components of model deployment.
21 | # E.g. "each model deployment is composed of several parts: a DeploymentConfig,
22 | # which captures A, B and C, an input_fn which loads data.. etc
23 |
24 | To easily train a model on multiple GPUs or across multiple machines this
25 | module provides a set of helper functions: `create_clones`,
26 | `optimize_clones` and `deploy`.
27 |
28 | Usage:
29 |
30 | g = tf.Graph()
31 |
32 | # Set up DeploymentConfig
33 | config = model_deploy.DeploymentConfig(num_clones=2, clone_on_cpu=True)
34 |
35 | # Create the global step on the device storing the variables.
36 | with tf.device(config.variables_device()):
37 | global_step = slim.create_global_step()
38 |
39 | # Define the inputs
40 | with tf.device(config.inputs_device()):
41 | images, labels = LoadData(...)
42 | inputs_queue = slim.data.prefetch_queue((images, labels))
43 |
44 | # Define the optimizer.
45 | with tf.device(config.optimizer_device()):
46 | optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)
47 |
48 | # Define the model including the loss.
49 | def model_fn(inputs_queue):
50 | images, labels = inputs_queue.dequeue()
51 | predictions = CreateNetwork(images)
52 | slim.losses.log_loss(predictions, labels)
53 |
54 | model_dp = model_deploy.deploy(config, model_fn, [inputs_queue],
55 | optimizer=optimizer)
56 |
57 | # Run training.
58 | slim.learning.train(model_dp.train_op, my_log_dir,
59 | summary_op=model_dp.summary_op)
60 |
61 | The Clone namedtuple holds together the values associated with each call to
62 | model_fn:
63 | * outputs: The return values of the calls to `model_fn()`.
64 | * scope: The scope used to create the clone.
65 | * device: The device used to create the clone.
66 |
67 | The DeployedModel namedtuple holds together the values needed to train multiple
68 | clones:
69 | * train_op: An operation that runs the optimizer training op and includes
70 | all the update ops created by `model_fn`. Present only if an optimizer
71 | was specified.
72 | * summary_op: An operation that runs the summaries created by `model_fn`
73 | and process_gradients.
74 | * total_loss: A `Tensor` that contains the sum of all losses created by
75 | `model_fn` plus the regularization losses.
76 | * clones: List of `Clone` tuples returned by `create_clones()`.
77 |
78 | DeploymentConfig parameters:
79 | * num_clones: Number of model clones to deploy in each replica.
80 | * clone_on_cpu: True if clones should be placed on CPU.
81 | * replica_id: Integer. Index of the replica for which the model is
82 | deployed. Usually 0 for the chief replica.
83 | * num_replicas: Number of replicas to use.
84 | * num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas.
85 | * worker_job_name: A name for the worker job.
86 | * ps_job_name: A name for the parameter server job.
87 |
88 | TODO(sguada):
89 | - describe side effect to the graph.
90 | - what happens to summaries and update_ops.
91 | - which graph collections are altered.
92 | - write a tutorial on how to use this.
93 | - analyze the possibility of calling deploy more than once.
94 |
95 |
96 | """
97 |
98 | from __future__ import absolute_import
99 | from __future__ import division
100 | from __future__ import print_function
101 |
102 | import collections
103 |
104 | import tensorflow as tf
105 |
106 | from tensorflow.python.ops import control_flow_ops
107 |
108 | slim = tf.contrib.slim
109 |
110 | __all__ = ['create_clones',
111 | 'deploy',
112 | 'optimize_clones',
113 | 'DeployedModel',
114 | 'DeploymentConfig',
115 | 'Clone',
116 | ]
117 |
118 | # Namedtuple used to represent a clone during deployment.
119 | Clone = collections.namedtuple('Clone',
120 | ['outputs', # Whatever model_fn() returned.
121 | 'scope', # The scope used to create it.
122 | 'device', # The device used to create.
123 | ])
124 |
125 | # Namedtuple used to represent a DeployedModel, returned by deploy().
126 | DeployedModel = collections.namedtuple('DeployedModel',
127 | ['train_op', # The `train_op`
128 | 'summary_op', # The `summary_op`
129 | 'total_loss', # The loss `Tensor`
130 | 'clones', # A list of `Clones` tuples.
131 | ])
132 |
133 | # Default parameters for DeploymentConfig
134 | _deployment_params = {'num_clones': 1,
135 | 'clone_on_cpu': False,
136 | 'replica_id': 0,
137 | 'num_replicas': 1,
138 | 'num_ps_tasks': 0,
139 | 'worker_job_name': 'worker',
140 | 'ps_job_name': 'ps'}
141 |
142 |
143 | def create_clones(config, model_fn, args=None, kwargs=None):
144 | """Creates multiple clones according to config using a `model_fn`.
145 |
146 | The returned values of `model_fn(*args, **kwargs)` are collected along with
147 | the scope and device used to create them in a namedtuple
148 | `Clone(outputs, scope, device)`
149 |
150 | Note: it is assumed that any loss created by `model_fn` is collected at
151 | the tf.GraphKeys.LOSSES collection.
152 |
153 | To recover the losses, summaries or update_ops created by the clone use:
154 | ```python
155 | losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
156 | summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, clone.scope)
157 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
158 | ```
159 |
160 | The deployment options are specified by the config object and support
161 | deploying one or several clones on different GPUs and one or several replicas
162 | of such clones.
163 |
164 | The argument `model_fn` is called `config.num_clones` times to create the
165 | model clones as `model_fn(*args, **kwargs)`.
166 |
167 | If `config` specifies deployment on multiple replicas then the default
168 | tensorflow device is set appropriately for each call to `model_fn` and for the
169 | slim variable creation functions: model and global variables will be created
170 | on the `ps` device, the clone operations will be on the `worker` device.
171 |
172 | Args:
173 | config: A DeploymentConfig object.
174 | model_fn: A callable. Called as `model_fn(*args, **kwargs)`
175 | args: Optional list of arguments to pass to `model_fn`.
176 | kwargs: Optional dict of keyword arguments to pass to `model_fn`.
177 |
178 | Returns:
179 | A list of namedtuples `Clone`.
180 | """
181 | clones = []
182 | args = args or []
183 | kwargs = kwargs or {}
184 | with slim.arg_scope([slim.model_variable, slim.variable],
185 | device=config.variables_device()):
186 | # Create clones.
187 | for i in range(0, config.num_clones):
188 | with tf.name_scope(config.clone_scope(i)) as clone_scope:
189 | clone_device = config.clone_device(i)
190 | with tf.device(clone_device):
191 | with tf.variable_scope(tf.get_variable_scope(),
192 | reuse=True if i > 0 else None):
193 | outputs = model_fn(*args, **kwargs)
194 | clones.append(Clone(outputs, clone_scope, clone_device))
195 | return clones
196 |
197 |
198 | def _gather_clone_loss(clone, num_clones, regularization_losses):
199 | """Gather the loss for a single clone.
200 |
201 | Args:
202 | clone: A Clone namedtuple.
203 | num_clones: The number of clones being deployed.
204 | regularization_losses: Possibly empty list of regularization_losses
205 | to add to the clone losses.
206 |
207 | Returns:
208 | A tensor for the total loss for the clone. Can be None.
209 | """
210 | # The return value.
211 | sum_loss = None
212 | # Individual components of the loss that will need summaries.
213 | clone_loss = None
214 | regularization_loss = None
215 | # Compute and aggregate losses on the clone device.
216 | with tf.device(clone.device):
217 | all_losses = []
218 | clone_losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
219 | if clone_losses:
220 | clone_loss = tf.add_n(clone_losses, name='clone_loss')
221 | if num_clones > 1:
222 | clone_loss = tf.div(clone_loss, 1.0 * num_clones,
223 | name='scaled_clone_loss')
224 | all_losses.append(clone_loss)
225 | if regularization_losses:
226 | regularization_loss = tf.add_n(regularization_losses,
227 | name='regularization_loss')
228 | all_losses.append(regularization_loss)
229 | if all_losses:
230 | sum_loss = tf.add_n(all_losses)
231 | # Add the summaries out of the clone device block.
232 | if clone_loss is not None:
233 | tf.summary.scalar(clone.scope + '/clone_loss', clone_loss)
234 | if regularization_loss is not None:
235 | tf.summary.scalar('regularization_loss', regularization_loss)
236 | return sum_loss
237 |
238 |
239 | def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
240 | **kwargs):
241 | """Compute losses and gradients for a single clone.
242 |
243 | Args:
244 | optimizer: A tf.Optimizer object.
245 | clone: A Clone namedtuple.
246 | num_clones: The number of clones being deployed.
247 | regularization_losses: Possibly empty list of regularization_losses
248 | to add to the clone losses.
249 | **kwargs: Dict of keyword arguments to pass to compute_gradients().
250 |
251 | Returns:
252 | A tuple (clone_loss, clone_grads_and_vars).
253 | - clone_loss: A tensor for the total loss for the clone. Can be None.
254 | - clone_grads_and_vars: List of (gradient, variable) for the clone.
255 | Can be empty.
256 | """
257 | sum_loss = _gather_clone_loss(clone, num_clones, regularization_losses)
258 | clone_grad = None
259 | if sum_loss is not None:
260 | with tf.device(clone.device):
261 | clone_grad = optimizer.compute_gradients(sum_loss, **kwargs)
262 | return sum_loss, clone_grad
263 |
264 |
265 | def optimize_clones(clones, optimizer,
266 | regularization_losses=None,
267 | **kwargs):
268 | """Compute clone losses and gradients for the given list of `Clones`.
269 |
270 | Note: The regularization_losses are added to the first clone losses.
271 |
272 | Args:
273 | clones: List of `Clones` created by `create_clones()`.
274 | optimizer: An `Optimizer` object.
275 | regularization_losses: Optional list of regularization losses. If None it
276 | will gather them from tf.GraphKeys.REGULARIZATION_LOSSES. Pass `[]` to
277 | exclude them.
278 | **kwargs: Optional dict of keyword arguments to pass to `compute_gradients`.
279 |
280 | Returns:
281 | A tuple (total_loss, grads_and_vars).
282 | - total_loss: A Tensor containing the average of the clone losses including
283 | the regularization loss.
284 | - grads_and_vars: A List of tuples (gradient, variable) containing the sum
285 | of the gradients for each variable.
286 |
287 | """
288 | grads_and_vars = []
289 | clones_losses = []
290 | num_clones = len(clones)
291 | if regularization_losses is None:
292 | regularization_losses = tf.get_collection(
293 | tf.GraphKeys.REGULARIZATION_LOSSES)
294 | for clone in clones:
295 | with tf.name_scope(clone.scope):
296 | clone_loss, clone_grad = _optimize_clone(
297 | optimizer, clone, num_clones, regularization_losses, **kwargs)
298 | if clone_loss is not None:
299 | clones_losses.append(clone_loss)
300 | grads_and_vars.append(clone_grad)
301 | # Only use regularization_losses for the first clone
302 | regularization_losses = None
303 | # Compute the total_loss summing all the clones_losses.
304 | total_loss = tf.add_n(clones_losses, name='total_loss')
305 | # Sum the gradients across clones.
306 | grads_and_vars = _sum_clones_gradients(grads_and_vars)
307 | return total_loss, grads_and_vars
308 |
309 |
310 | def deploy(config,
311 | model_fn,
312 | args=None,
313 | kwargs=None,
314 | optimizer=None,
315 | summarize_gradients=False):
316 | """Deploys a Slim-constructed model across multiple clones.
317 |
318 | The deployment options are specified by the config object and support
319 | deploying one or several clones on different GPUs and one or several replicas
320 | of such clones.
321 |
322 | The argument `model_fn` is called `config.num_clones` times to create the
323 | model clones as `model_fn(*args, **kwargs)`.
324 |
325 | The optional argument `optimizer` is an `Optimizer` object. If not `None`,
326 | the deployed model is configured for training with that optimizer.
327 |
328 | If `config` specifies deployment on multiple replicas then the default
329 | tensorflow device is set appropriately for each call to `model_fn` and for the
330 | slim variable creation functions: model and global variables will be created
331 | on the `ps` device, the clone operations will be on the `worker` device.
332 |
333 | Args:
334 | config: A `DeploymentConfig` object.
335 | model_fn: A callable. Called as `model_fn(*args, **kwargs)`
336 | args: Optional list of arguments to pass to `model_fn`.
337 | kwargs: Optional dict of keyword arguments to pass to `model_fn`.
338 | optimizer: Optional `Optimizer` object. If passed the model is deployed
339 | for training with that optimizer.
340 | summarize_gradients: Whether or not to add summaries to the gradients.
341 |
342 | Returns:
343 | A `DeployedModel` namedtuple.
344 |
345 | """
346 | # Gather initial summaries.
347 | summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
348 |
349 | # Create Clones.
350 | clones = create_clones(config, model_fn, args, kwargs)
351 | first_clone = clones[0]
352 |
353 | # Gather update_ops from the first clone. These contain, for example,
354 | # the updates for the batch_norm variables created by model_fn.
355 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone.scope)
356 |
357 | train_op = None
358 | total_loss = None
359 | with tf.device(config.optimizer_device()):
360 | if optimizer:
361 | # Place the global step on the device storing the variables.
362 | with tf.device(config.variables_device()):
363 | global_step = slim.get_or_create_global_step()
364 |
365 | # Compute the gradients for the clones.
366 | total_loss, clones_gradients = optimize_clones(clones, optimizer)
367 |
368 | if clones_gradients:
369 | if summarize_gradients:
370 | # Add summaries to the gradients.
371 | summaries |= set(_add_gradients_summaries(clones_gradients))
372 |
373 | # Create gradient updates.
374 | grad_updates = optimizer.apply_gradients(clones_gradients,
375 | global_step=global_step)
376 | update_ops.append(grad_updates)
377 |
378 | update_op = tf.group(*update_ops)
379 | train_op = control_flow_ops.with_dependencies([update_op], total_loss,
380 | name='train_op')
381 | else:
382 | clones_losses = []
383 | regularization_losses = tf.get_collection(
384 | tf.GraphKeys.REGULARIZATION_LOSSES)
385 | for clone in clones:
386 | with tf.name_scope(clone.scope):
387 | clone_loss = _gather_clone_loss(clone, len(clones),
388 | regularization_losses)
389 | if clone_loss is not None:
390 | clones_losses.append(clone_loss)
391 | # Only use regularization_losses for the first clone
392 | regularization_losses = None
393 | if clones_losses:
394 | total_loss = tf.add_n(clones_losses, name='total_loss')
395 |
396 | # Add the summaries from the first clone. These contain the summaries
397 | # created by model_fn and either optimize_clones() or _gather_clone_loss().
398 | summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
399 | first_clone.scope))
400 |
401 | if total_loss is not None:
402 | # Add total_loss to summary.
403 | summaries.add(tf.summary.scalar('total_loss', total_loss))
404 |
405 | if summaries:
406 | # Merge all summaries together.
407 | summary_op = tf.summary.merge(list(summaries), name='summary_op')
408 | else:
409 | summary_op = None
410 |
411 | return DeployedModel(train_op, summary_op, total_loss, clones)
412 |
413 |
414 | def _sum_clones_gradients(clone_grads):
415 | """Calculate the sum gradient for each shared variable across all clones.
416 |
417 | This function assumes that `clone_grads` has been scaled appropriately by
418 | 1 / num_clones.
419 |
420 | Args:
421 | clone_grads: A List of List of tuples (gradient, variable), one list per
422 | `Clone`.
423 |
424 | Returns:
425 | List of tuples of (gradient, variable) where the gradient has been summed
426 | across all clones.
427 | """
428 | sum_grads = []
429 | for grad_and_vars in zip(*clone_grads):
430 | # Note that each grad_and_vars looks like the following:
431 | # ((grad_var0_clone0, var0), ... (grad_varN_cloneN, varN))
432 | grads = []
433 | var = grad_and_vars[0][1]
434 | for g, v in grad_and_vars:
435 | assert v == var
436 | if g is not None:
437 | grads.append(g)
438 | if grads:
439 | if len(grads) > 1:
440 | sum_grad = tf.add_n(grads, name=var.op.name + '/sum_grads')
441 | else:
442 | sum_grad = grads[0]
443 | sum_grads.append((sum_grad, var))
444 | return sum_grads
445 |
446 |
447 | def _add_gradients_summaries(grads_and_vars):
448 | """Add histogram summaries to gradients.
449 |
450 | Note: The summaries are also added to the SUMMARIES collection.
451 |
452 | Args:
453 | grads_and_vars: A list of gradient to variable pairs (tuples).
454 |
455 | Returns:
456 | The _list_ of the added summaries for grads_and_vars.
457 | """
458 | summaries = []
459 | for grad, var in grads_and_vars:
460 | if grad is not None:
461 | if isinstance(grad, tf.IndexedSlices):
462 | grad_values = grad.values
463 | else:
464 | grad_values = grad
465 | summaries.append(tf.summary.histogram(var.op.name + ':gradient',
466 | grad_values))
467 | summaries.append(tf.summary.histogram(var.op.name + ':gradient_norm',
468 | tf.global_norm([grad_values])))
469 | else:
470 | tf.logging.info('Var %s has no gradient', var.op.name)
471 | return summaries
472 |
473 |
474 | class DeploymentConfig(object):
475 | """Configuration for deploying a model with `deploy()`.
476 |
477 | You can pass an instance of this class to `deploy()` to specify exactly
478 | how to deploy the model you are building. If you do not pass one, an instance built
479 | from the default deployment_hparams will be used.
480 | """
481 |
482 | def __init__(self,
483 | num_clones=1,
484 | clone_on_cpu=False,
485 | replica_id=0,
486 | num_replicas=1,
487 | num_ps_tasks=0,
488 | worker_job_name='worker',
489 | ps_job_name='ps'):
490 | """Create a DeploymentConfig.
491 |
492 | The config describes how to deploy a model across multiple clones and
493 | replicas. The model will be replicated `num_clones` times in each replica.
494 | If `clone_on_cpu` is True, each clone will be placed on CPU.
495 |
496 | If `num_replicas` is 1, the model is deployed via a single process. In that
497 | case `worker_device`, `num_ps_tasks`, and `ps_device` are ignored.
498 |
499 | If `num_replicas` is greater than 1, then `worker_device` and `ps_device`
500 | must specify TensorFlow devices for the `worker` and `ps` jobs and
501 | `num_ps_tasks` must be positive.
502 |
503 | Args:
504 | num_clones: Number of model clones to deploy in each replica.
505 | clone_on_cpu: If True clones would be placed on CPU.
506 | replica_id: Integer. Index of the replica for which the model is
507 | deployed. Usually 0 for the chief replica.
508 | num_replicas: Number of replicas to use.
509 | num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas.
510 | worker_job_name: A name for the worker job.
511 | ps_job_name: A name for the parameter server job.
512 |
513 | Raises:
514 | ValueError: If the arguments are invalid.
515 | """
516 | if num_replicas > 1:
517 | if num_ps_tasks < 1:
518 | raise ValueError('When using replicas num_ps_tasks must be positive')
519 | if num_replicas > 1 or num_ps_tasks > 0:
520 | if not worker_job_name:
521 | raise ValueError('Must specify worker_job_name when using replicas')
522 | if not ps_job_name:
523 | raise ValueError('Must specify ps_job_name when using parameter server')
524 | if replica_id >= num_replicas:
525 | raise ValueError('replica_id must be less than num_replicas')
526 | self._num_clones = num_clones
527 | self._clone_on_cpu = clone_on_cpu
528 | self._replica_id = replica_id
529 | self._num_replicas = num_replicas
530 | self._num_ps_tasks = num_ps_tasks
531 | self._ps_device = '/job:' + ps_job_name if num_ps_tasks > 0 else ''
532 | self._worker_device = '/job:' + worker_job_name if num_ps_tasks > 0 else ''
533 |
534 | @property
535 | def num_clones(self):
536 | return self._num_clones
537 |
538 | @property
539 | def clone_on_cpu(self):
540 | return self._clone_on_cpu
541 |
542 | @property
543 | def replica_id(self):
544 | return self._replica_id
545 |
546 | @property
547 | def num_replicas(self):
548 | return self._num_replicas
549 |
550 | @property
551 | def num_ps_tasks(self):
552 | return self._num_ps_tasks
553 |
554 | @property
555 | def ps_device(self):
556 | return self._ps_device
557 |
558 | @property
559 | def worker_device(self):
560 | return self._worker_device
561 |
562 | def caching_device(self):
563 | """Returns the device to use for caching variables.
564 |
565 | Variables are cached on the worker CPU when using replicas.
566 |
567 | Returns:
568 | A device string or None if the variables do not need to be cached.
569 | """
570 | if self._num_ps_tasks > 0:
571 | return lambda op: op.device
572 | else:
573 | return None
574 |
575 | def clone_device(self, clone_index):
576 | """Device used to create the clone and all the ops inside the clone.
577 |
578 | Args:
579 | clone_index: Int, representing the clone_index.
580 |
581 | Returns:
582 | A value suitable for `tf.device()`.
583 |
584 | Raises:
585 | ValueError: if `clone_index` is greater than or equal to the number of clones.
586 | """
587 | if clone_index >= self._num_clones:
588 | raise ValueError('clone_index must be less than num_clones')
589 | device = ''
590 | if self._num_ps_tasks > 0:
591 | device += self._worker_device
592 | if self._clone_on_cpu:
593 | device += '/device:CPU:0'
594 | else:
595 | if self._num_clones > 1:
596 | device += '/device:GPU:%d' % clone_index
597 | return device
598 |
599 | def clone_scope(self, clone_index):
600 | """Name scope to create the clone.
601 |
602 | Args:
603 | clone_index: Int, representing the clone_index.
604 |
605 | Returns:
606 | A name_scope suitable for `tf.name_scope()`.
607 |
608 | Raises:
609 | ValueError: if `clone_index` is greater than or equal to the number of clones.
610 | """
611 | if clone_index >= self._num_clones:
612 | raise ValueError('clone_index must be less than num_clones')
613 | scope = ''
614 | if self._num_clones > 1:
615 | scope = 'clone_%d' % clone_index
616 | return scope
617 |
618 | def optimizer_device(self):
619 | """Device to use with the optimizer.
620 |
621 | Returns:
622 | A value suitable for `tf.device()`.
623 | """
624 | if self._num_ps_tasks > 0 or self._num_clones > 0:
625 | return self._worker_device + '/device:CPU:0'
626 | else:
627 | return ''
628 |
629 | def inputs_device(self):
630 | """Device to use to build the inputs.
631 |
632 | Returns:
633 | A value suitable for `tf.device()`.
634 | """
635 | device = ''
636 | if self._num_ps_tasks > 0:
637 | device += self._worker_device
638 | device += '/device:CPU:0'
639 | return device
640 |
641 | def variables_device(self):
642 | """Returns the device to use for variables created inside the clone.
643 |
644 | Returns:
645 | A value suitable for `tf.device()`.
646 | """
647 | device = ''
648 | if self._num_ps_tasks > 0:
649 | device += self._ps_device
650 | device += '/device:CPU:0'
651 |
652 | class _PSDeviceChooser(object):
653 | """Slim device chooser for variables when using PS."""
654 |
655 | def __init__(self, device, tasks):
656 | self._device = device
657 | self._tasks = tasks
658 | self._task = 0
659 |
660 | def choose(self, op):
661 | if op.device:
662 | return op.device
663 | node_def = op if isinstance(op, tf.NodeDef) else op.node_def
664 | if node_def.op == 'Variable':
665 | t = self._task
666 | self._task = (self._task + 1) % self._tasks
667 | d = '%s/task:%d' % (self._device, t)
668 | return d
669 | else:
670 | return op.device
671 |
672 | if not self._num_ps_tasks:
673 | return device
674 | else:
675 | chooser = _PSDeviceChooser(device, self._num_ps_tasks)
676 | return chooser.choose
677 |
--------------------------------------------------------------------------------
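For reference, a minimal single-machine sketch of how the helpers above compose when `deploy()` is not called directly: `create_clones()` builds the clones and `optimize_clones()` gathers their losses and gradients. This is a sketch only; `my_model_fn` and the toy shapes are placeholders, and the TF 1.x `tf.contrib.slim` API plus the `deployment` package path are assumed.

```python
import tensorflow as tf

from deployment import model_deploy  # assumed import path within this repository

slim = tf.contrib.slim


def my_model_fn(images, labels):
    # Placeholder model: one linear layer; the cross-entropy loss is added to
    # the tf.GraphKeys.LOSSES collection, which _gather_clone_loss() reads.
    logits = slim.fully_connected(images, 10, activation_fn=None)
    slim.losses.softmax_cross_entropy(logits, labels)
    return logits


config = model_deploy.DeploymentConfig(num_clones=2, clone_on_cpu=True)

with tf.device(config.inputs_device()):
    images = tf.placeholder(tf.float32, [32, 784])
    labels = tf.placeholder(tf.float32, [32, 10])

# Each clone runs my_model_fn under its own name scope and device; variables
# are shared (reuse=True for every clone after the first).
clones = model_deploy.create_clones(config, my_model_fn, [images, labels])

with tf.device(config.optimizer_device()):
    optimizer = tf.train.GradientDescentOptimizer(0.01)
    # total_loss averages the clone losses (plus regularization); gradients
    # for each shared variable are summed across clones.
    total_loss, grads_and_vars = model_deploy.optimize_clones(clones, optimizer)
    train_op = optimizer.apply_gradients(grads_and_vars)
```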
/.idea/workspace.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------