├── src
│   ├── __init__.py
│   ├── nets
│   │   ├── __init__.py
│   │   ├── nets_factory.py
│   │   └── yolo_v2.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── draw_boxes.py
│   │   ├── train_utils.py
│   │   └── tf_utils.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── voc_2012.py
│   │   ├── voc_2007.py
│   │   ├── dataset_factory.py
│   │   ├── voc_common.py
│   │   ├── decorations.py
│   │   ├── dataset_utils.py
│   │   ├── decorations_to_tfrecords.py
│   │   ├── imagenet_1000.py
│   │   └── voc_to_tfrecords.py
│   ├── deployment
│   │   ├── __init__.py
│   │   ├── model_deploy_test.py
│   │   └── model_deploy.py
│   ├── preprocessing
│   │   ├── __init__.py
│   │   ├── preprocessing_factory.py
│   │   └── yolo_v2_preprocessing.py
│   └── train.py
├── README.md
├── .gitignore
└── LICENSE

/src/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/nets/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/datasets/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/deployment/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/src/preprocessing/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# YOLO2TensorFlow

YOLOv2 implemented in TensorFlow

TODO:

1. Data augmentation
2. Loss function optimization
3. Multi-GPU training deployment
4. Performance evaluation script
5. detector.py
6. Optimize preprocess_bboxes
7. Draw boxes

# References:

[darknet](https://github.com/pjreddie/darknet)
[darkflow](https://github.com/thtrieu/darkflow)
[YAD2K](https://github.com/allanzelener/YAD2K)
[YOLOv1 paper](https://arxiv.org/abs/1506.02640)
[YOLOv2 paper](https://arxiv.org/abs/1612.08242)
--------------------------------------------------------------------------------
/src/preprocessing/preprocessing_factory.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from preprocessing import yolo_v2_preprocessing

slim = tf.contrib.slim


def get_preprocessing(name, is_training=False):
  """Returns preprocessing_fn(image, labels, bboxes, out_shape, **kwargs).

  Args:
    name: The name of the preprocessing function.
    is_training: `True` if the model is being used for training and `False`
      otherwise.

  Returns:
    preprocessing_fn: A function that preprocesses a single image and its
      ground-truth boxes (pre-batching). It has the following signature:
        image, gt = preprocessing_fn(image, labels, bboxes, out_shape, ...).

  Raises:
    ValueError: If preprocessing `name` is not recognized.
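
  Example (an illustrative sketch, not from the original source; uses the
  'yolo_v2' entry registered below, with hypothetical inputs):
    fn = get_preprocessing('yolo_v2', is_training=True)
    image, gt = fn(image, labels, bboxes, out_shape=[416, 416], box_num=56)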
  """
  preprocessing_fn_map = {
      'yolo_v2': yolo_v2_preprocessing,
  }

  if name not in preprocessing_fn_map:
    raise ValueError('Preprocessing name [%s] was not recognized' % name)

  def preprocessing_fn(image, labels, bboxes, out_shape, **kwargs):
    # Note: yolo_v2_preprocessing exposes `preprocess_data` (the original call
    # here targeted a non-existent `preprocess_image` and did not resolve).
    return preprocessing_fn_map[name].preprocess_data(
        image, labels, bboxes, out_shape, is_training=is_training, **kwargs)

  return preprocessing_fn
--------------------------------------------------------------------------------
/src/datasets/voc_2012.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from datasets import voc_common

slim = tf.contrib.slim

FILE_PATTERN = '*.tfrecords'
ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying height and width.',
    'shape': 'Shape of the image',
    'object/bbox': 'A list of bounding boxes, one per each object.',
    'object/label': 'A list of labels, one per each object.',
}

SPLITS_TO_SIZES = {
    'train': 5717,
    'val': 5823,
}

MAX_BOX_NUM_PER_IMAGE = {
    'train': 56,
    'val': 42,
}

NUM_CLASSES = 20


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading Pascal VOC 2012.

  Args:
    split_name: A train/test split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      The pattern is matched under `dataset_dir/split_name/`.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
  """
  if not file_pattern:
    file_pattern = FILE_PATTERN
  return voc_common.get_split(split_name, dataset_dir,
                              file_pattern, reader,
                              SPLITS_TO_SIZES,
                              ITEMS_TO_DESCRIPTIONS,
                              NUM_CLASSES)
--------------------------------------------------------------------------------
/src/datasets/voc_2007.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from datasets import voc_common

slim = tf.contrib.slim

FILE_PATTERN = '*.tfrecords'
ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying height and width.',
    'shape': 'Shape of the image',
    'object/bbox': 'A list of bounding boxes, one per each object.',
    'object/label': 'A list of labels, one per each object.',
}

SPLITS_TO_SIZES = {
    'train': 2501,
    'val': 2510,
    'test': 4952,
}

MAX_BOX_NUM_PER_IMAGE = {
    'train': 37,
    'val': 42,
    'test': 41,
}

NUM_CLASSES = 20


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading Pascal VOC 2007.

  Args:
    split_name: A train/test split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      The pattern is matched under `dataset_dir/split_name/`.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
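
  Example (illustrative sketch; the dataset path is hypothetical):
    dataset = get_split('train', '/path/to/VOC/TFRecords/2007')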
  """
  if not file_pattern:
    file_pattern = FILE_PATTERN
  return voc_common.get_split(split_name, dataset_dir,
                              file_pattern, reader,
                              SPLITS_TO_SIZES,
                              ITEMS_TO_DESCRIPTIONS,
                              NUM_CLASSES)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/src/datasets/dataset_factory.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A factory-pattern class which returns classification image/label pairs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datasets import voc_2007, voc_2012, imagenet_1000, decorations

datasets_map = {
    'voc_2007': voc_2007,
    'voc_2012': voc_2012,
    'imagenet_1000': imagenet_1000,
    'decorations': decorations,
}


def get_box_num_per_image(name, split_name):
  return datasets_map[name].MAX_BOX_NUM_PER_IMAGE.get(split_name)


def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
  """Given a dataset name and a split_name returns a Dataset.

  Args:
    name: String, the name of the dataset.
    split_name: A train/test split name.
    dataset_dir: The directory where the dataset files are stored.
    file_pattern: The file pattern to use for matching the dataset source files.
    reader: The subclass of tf.ReaderBase. If left as `None`, then the default
      reader defined by each dataset is used.

  Returns:
    A `Dataset` class.

  Raises:
    ValueError: If the dataset `name` is unknown.
  """
  if name not in datasets_map:
    raise ValueError('Name of dataset unknown %s' % name)
  return datasets_map[name].get_split(
      split_name,
      dataset_dir,
      file_pattern,
      reader)
--------------------------------------------------------------------------------
/src/nets/nets_factory.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains a factory for building various models."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools

import tensorflow as tf

from nets import yolo_v2


slim = tf.contrib.slim

networks_map = {'yolo_v2': yolo_v2.yolo_v2,
               }

arg_scopes_map = {'yolo_v2': yolo_v2.yolo_v2_arg_scope,
                 }


def get_network_fn(name, num_classes, is_training=False, **kwargs):
  """Returns a network_fn such as `logits, end_points = network_fn(images)`.

  Args:
    name: The name of the network.
    num_classes: The number of classes to use for classification.
    is_training: `True` if the model is being used for training and `False`
      otherwise.
    weight_decay: The l2 coefficient for the model weights (forwarded to the
      network's arg scope via **kwargs).
  Returns:
    network_fn: A function that applies the model to a batch of images. It has
      the following signature: logits, end_points = network_fn(images)
  Raises:
    ValueError: If network `name` is not recognized.
  """
  if name not in networks_map:
    raise ValueError('Name of network unknown %s' % name)
  func = networks_map[name]
  @functools.wraps(func)
  def network_fn(images, **kwargs):
    arg_scope = arg_scopes_map[name](**kwargs)
    with slim.arg_scope(arg_scope):
      return func(images, num_classes, is_training=is_training, **kwargs)
  if hasattr(func, 'default_image_size'):
    network_fn.default_image_size = func.default_image_size

  return network_fn
--------------------------------------------------------------------------------
/src/utils/draw_boxes.py:
--------------------------------------------------------------------------------
"""Draw predicted or ground truth boxes on input image."""

import colorsys
import random

import numpy as np
from PIL import Image, ImageDraw, ImageFont


def get_colors_for_classes(num_classes):
  """Return list of random colors for number of classes given."""
  # Use previously generated colors if num_classes is the same.
  if (hasattr(get_colors_for_classes, "colors") and
      len(get_colors_for_classes.colors) == num_classes):
    return get_colors_for_classes.colors

  hsv_tuples = [(x / num_classes, 1., 1.) for x in range(num_classes)]
  colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
  colors = list(
      map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)),
          colors))
  random.seed(10101)  # Fixed seed for consistent colors across runs.
  random.shuffle(colors)  # Shuffle colors to decorrelate adjacent classes.
  random.seed(None)  # Reset seed to default.
  get_colors_for_classes.colors = colors  # Save colors for future calls.
  return colors


def draw_boxes(image, boxes, box_classes, class_names, scores=None):
  """Draw bounding boxes on image.

  Draw bounding boxes with class name and optional box score on image.

  Args:
    image: An `array` of shape (height, width, 3) with values in [0, 1].
    boxes: An `array` of shape (num_boxes, 4) containing box corners as
      (y_min, x_min, y_max, x_max).
    box_classes: A `list` of indices into `class_names`.
    class_names: A `list` of `string` class names.
    scores: An optional `array` of scores, one per box.

  Returns:
    A copy of `image` modified with given bounding boxes.
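
  Example (illustrative sketch; the inputs here are hypothetical):
    annotated = draw_boxes(image, boxes, box_classes=[0, 14],
                           class_names=class_names, scores=np.array([0.9, 0.8]))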
  """
  image = Image.fromarray(np.floor(image * 255 + 0.5).astype('uint8'))

  font = ImageFont.truetype(
      font='font/FiraMono-Medium.otf',
      size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
  thickness = (image.size[0] + image.size[1]) // 300

  colors = get_colors_for_classes(len(class_names))

  for i, c in list(enumerate(box_classes)):
    box_class = class_names[c]
    box = boxes[i]
    if isinstance(scores, np.ndarray):
      score = scores[i]
      label = '{} {:.2f}'.format(box_class, score)
    else:
      label = '{}'.format(box_class)

    draw = ImageDraw.Draw(image)
    label_size = draw.textsize(label, font)

    top, left, bottom, right = box
    top = max(0, np.floor(top + 0.5).astype('int32'))
    left = max(0, np.floor(left + 0.5).astype('int32'))
    bottom = min(image.size[1], np.floor(bottom + 0.5).astype('int32'))
    right = min(image.size[0], np.floor(right + 0.5).astype('int32'))
    print(label, (left, top), (right, bottom))

    if top - label_size[1] >= 0:
      text_origin = np.array([left, top - label_size[1]])
    else:
      text_origin = np.array([left, top + 1])

    # My kingdom for a good redistributable image drawing library.
    for t in range(thickness):  # renamed from `i` to avoid shadowing the outer index
      draw.rectangle(
          [left + t, top + t, right - t, bottom - t], outline=colors[c])
    draw.rectangle(
        [tuple(text_origin), tuple(text_origin + label_size)],
        fill=colors[c])
    draw.text(text_origin, label, fill=(0, 0, 0), font=font)
    del draw

  return np.array(image)
--------------------------------------------------------------------------------
/src/datasets/voc_common.py:
--------------------------------------------------------------------------------
import os
import tensorflow as tf

slim = tf.contrib.slim

classes = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"]


def get_split(split_name, dataset_dir, file_pattern, reader,
              split_to_sizes, items_to_descriptions, num_classes):
  """Gets a dataset tuple with instructions for reading a Pascal VOC dataset.

  Args:
    split_name: A train/test split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      The pattern is matched under `dataset_dir/split_name/`.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
  """
  if split_name not in split_to_sizes:
    raise ValueError('split name %s was not recognized.' % split_name)
  file_pattern = os.path.join(dataset_dir, split_name, file_pattern)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader
  # Features in Pascal VOC TFRecords.
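  # Per-image metadata (shape, box count) is stored as fixed-length features;
  # per-object coordinates and labels are variable-length lists, so images
  # with different numbers of objects share one schema.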
  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
      'image/height': tf.FixedLenFeature([1], tf.int64),
      'image/width': tf.FixedLenFeature([1], tf.int64),
      'image/channels': tf.FixedLenFeature([1], tf.int64),
      'image/shape': tf.FixedLenFeature([3], tf.int64),
      'image/box_num': tf.FixedLenFeature([1], tf.int64),
      'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
      'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
      'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
  }
  items_to_handlers = {
      'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
      'shape': slim.tfexample_decoder.Tensor('image/shape'),
      'box_num': slim.tfexample_decoder.Tensor('image/box_num', shape=[1]),
      'object/bbox': slim.tfexample_decoder.BoundingBox(
          ['xmin', 'ymin', 'xmax', 'ymax'], 'image/object/bbox/'),
      'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
      'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
      'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
  }
  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = {}
  for index, label in enumerate(classes):
    labels_to_names[index] = label

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=split_to_sizes[split_name],
      items_to_descriptions=items_to_descriptions,
      num_classes=num_classes,
      labels_to_names=labels_to_names)
--------------------------------------------------------------------------------
/src/datasets/decorations.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import os

'''
python src/train.py --train_dir=/tmp/tfmodel \
    --dataset_name=decorations \
    --dataset_dir=/home/paul/Data/decorations/TFRecords/2017 \
    --num_classes=4 \
    --max_number_of_steps=1000 \
    --batch_size=2 \
'''

slim = tf.contrib.slim

classes = ["glasses", "hat", "package", "tie"]
FILE_PATTERN = '*.tfrecords'
ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying height and width.',
    'shape': 'Shape of the image',
    'object/bbox': 'A list of bounding boxes, one per each object.',
    'object/label': 'A list of labels, one per each object.',
}

SPLITS_TO_SIZES = {
    'train': 762,
}

MAX_BOX_NUM_PER_IMAGE = {
    'train': 9,
}

NUM_CLASSES = 4


def get_split(split_name, dataset_dir, file_pattern=None, reader=None,
              split_to_sizes=SPLITS_TO_SIZES, items_to_descriptions=ITEMS_TO_DESCRIPTIONS, num_classes=NUM_CLASSES):
  """Gets a dataset tuple with instructions for reading a Pascal VOC style dataset.

  Args:
    split_name: A train/test split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      The pattern is matched under `dataset_dir/split_name/`.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
  """
  if split_name not in split_to_sizes:
    raise ValueError('split name %s was not recognized.' % split_name)
  if file_pattern is None:
    file_pattern = FILE_PATTERN

  file_pattern = os.path.join(dataset_dir, split_name, file_pattern)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader
  # Features in Pascal VOC type TFRecords.
  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='jpeg'),
      'image/height': tf.FixedLenFeature([1], tf.int64),
      'image/width': tf.FixedLenFeature([1], tf.int64),
      'image/channels': tf.FixedLenFeature([1], tf.int64),
      'image/shape': tf.FixedLenFeature([3], tf.int64),
      'image/box_num': tf.FixedLenFeature([1], tf.int64),
      'image/object/bbox/xmin': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/ymin': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/xmax': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/ymax': tf.VarLenFeature(dtype=tf.float32),
      'image/object/bbox/label': tf.VarLenFeature(dtype=tf.int64),
      'image/object/bbox/difficult': tf.VarLenFeature(dtype=tf.int64),
      'image/object/bbox/truncated': tf.VarLenFeature(dtype=tf.int64),
  }
  items_to_handlers = {
      'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
      'shape': slim.tfexample_decoder.Tensor('image/shape'),
      'box_num': slim.tfexample_decoder.Tensor('image/box_num', shape=[1]),
      'object/bbox': slim.tfexample_decoder.BoundingBox(
          ['xmin', 'ymin', 'xmax', 'ymax'], 'image/object/bbox/'),
      'object/label': slim.tfexample_decoder.Tensor('image/object/bbox/label'),
      'object/difficult': slim.tfexample_decoder.Tensor('image/object/bbox/difficult'),
      'object/truncated': slim.tfexample_decoder.Tensor('image/object/bbox/truncated'),
  }
  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = {}
  for index, label in enumerate(classes):
    labels_to_names[index] = label

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=split_to_sizes[split_name],
      items_to_descriptions=items_to_descriptions,
      num_classes=num_classes,
      labels_to_names=labels_to_names)
--------------------------------------------------------------------------------
/src/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains utilities for downloading and converting datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import tarfile

from six.moves import urllib
import tensorflow as tf


LABELS_FILENAME = 'labels.txt'


def int64_feature(value):
  """Wrapper for inserting int64 features into an Example proto."""
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))


def float_feature(value):
  """Wrapper for inserting float features into an Example proto."""
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(float_list=tf.train.FloatList(value=value))


def bytes_feature(value):
  """Wrapper for inserting bytes features into an Example proto."""
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def image_to_tfexample(image_data, image_format, height, width, class_id):
  return tf.train.Example(features=tf.train.Features(feature={
      'image/encoded': bytes_feature(image_data),
      'image/format': bytes_feature(image_format),
      'image/class/label': int64_feature(class_id),
      'image/height': int64_feature(height),
      'image/width': int64_feature(width),
  }))


def download_and_uncompress_tarball(tarball_url, dataset_dir):
  """Downloads the `tarball_url` and uncompresses it locally.

  Args:
    tarball_url: The URL of a tarball file.
    dataset_dir: The directory where the temporary files are stored.
  """
  filename = tarball_url.split('/')[-1]
  filepath = os.path.join(dataset_dir, filename)

  def _progress(count, block_size, total_size):
    sys.stdout.write('\r>> Downloading %s %.1f%%' % (
        filename, float(count * block_size) / float(total_size) * 100.0))
    sys.stdout.flush()

  filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
  print()
  statinfo = os.stat(filepath)
  print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dataset_dir)


def write_label_file(labels_to_class_names, dataset_dir,
                     filename=LABELS_FILENAME):
  """Writes a file with the list of class names.

  Args:
    labels_to_class_names: A map of (integer) labels to class names.
    dataset_dir: The directory in which the labels file should be written.
    filename: The filename where the class names are written.
  """
  labels_filename = os.path.join(dataset_dir, filename)
  with tf.gfile.Open(labels_filename, 'w') as f:
    for label in labels_to_class_names:
      class_name = labels_to_class_names[label]
      f.write('%d:%s\n' % (label, class_name))


def has_labels(dataset_dir, filename=LABELS_FILENAME):
  """Specifies whether or not the dataset directory contains a label map file.

  Args:
    dataset_dir: The directory in which the labels file is found.
    filename: The filename where the class names are written.

  Returns:
    `True` if the labels file exists and `False` otherwise.
  """
  return tf.gfile.Exists(os.path.join(dataset_dir, filename))


def read_label_file(dataset_dir, filename=LABELS_FILENAME):
  """Reads the labels file and returns a mapping from ID to class name.

  Args:
    dataset_dir: The directory in which the labels file is found.
    filename: The filename where the class names are written.

  Returns:
    A map from a label (integer) to class name.
  """
  labels_filename = os.path.join(dataset_dir, filename)
  # Read in binary mode so decode() is valid under both Python 2 and 3.
  with tf.gfile.Open(labels_filename, 'rb') as f:
    lines = f.read().decode()
  lines = lines.split('\n')
  lines = filter(None, lines)

  labels_to_class_names = {}
  for line in lines:
    index = line.index(':')
    labels_to_class_names[int(line[:index])] = line[index + 1:]
  return labels_to_class_names
--------------------------------------------------------------------------------
/src/utils/train_utils.py:
--------------------------------------------------------------------------------
import tensorflow as tf
slim = tf.contrib.slim


def configure_learning_rate(flags, num_samples_per_epoch, global_step):
  """Configures the learning rate.

  Args:
    flags: The flags object holding the training configuration.
    num_samples_per_epoch: The number of samples in each epoch of training.
    global_step: The global_step tensor.

  Returns:
    A `Tensor` representing the learning rate.

  Raises:
    ValueError: if `flags.learning_rate_decay_type` is not recognized.
  """
  decay_steps = int(num_samples_per_epoch / flags.batch_size *
                    flags.num_epochs_per_decay)
  if flags.sync_replicas:
    decay_steps /= flags.replicas_to_aggregate

  if flags.learning_rate_decay_type == 'exponential':
    return tf.train.exponential_decay(flags.learning_rate,
                                      global_step,
                                      decay_steps,
                                      flags.learning_rate_decay_factor,
                                      staircase=True,
                                      name='exponential_decay_learning_rate')
  elif flags.learning_rate_decay_type == 'fixed':
    return tf.constant(flags.learning_rate, name='fixed_learning_rate')
  elif flags.learning_rate_decay_type == 'polynomial':
    return tf.train.polynomial_decay(flags.learning_rate,
                                     global_step,
                                     decay_steps,
                                     flags.end_learning_rate,
                                     power=1.0,
                                     cycle=False,
                                     name='polynomial_decay_learning_rate')
  else:
    raise ValueError('learning_rate_decay_type [%s] was not recognized' %
                     flags.learning_rate_decay_type)


def configure_optimizer(flags, learning_rate):
  """Configures the optimizer used for training.

  Args:
    flags: The flags object holding the training configuration.
    learning_rate: A scalar or `Tensor` learning rate.

  Returns:
    An instance of an optimizer.

  Raises:
    ValueError: if flags.optimizer is not recognized.
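
  Example (illustrative sketch; `flags` is whatever object carries the
  attributes read below, e.g. tf.app.flags.FLAGS):
    learning_rate = configure_learning_rate(flags, dataset.num_samples, global_step)
    optimizer = configure_optimizer(flags, learning_rate)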
  """
  if flags.optimizer == 'adadelta':
    optimizer = tf.train.AdadeltaOptimizer(
        learning_rate,
        rho=flags.adadelta_rho,
        epsilon=flags.opt_epsilon)
  elif flags.optimizer == 'adagrad':
    optimizer = tf.train.AdagradOptimizer(
        learning_rate,
        initial_accumulator_value=flags.adagrad_initial_accumulator_value)
  elif flags.optimizer == 'adam':
    optimizer = tf.train.AdamOptimizer(
        learning_rate,
        beta1=flags.adam_beta1,
        beta2=flags.adam_beta2,
        epsilon=flags.opt_epsilon)
  elif flags.optimizer == 'ftrl':
    optimizer = tf.train.FtrlOptimizer(
        learning_rate,
        learning_rate_power=flags.ftrl_learning_rate_power,
        initial_accumulator_value=flags.ftrl_initial_accumulator_value,
        l1_regularization_strength=flags.ftrl_l1,
        l2_regularization_strength=flags.ftrl_l2)
  elif flags.optimizer == 'momentum':
    optimizer = tf.train.MomentumOptimizer(
        learning_rate,
        momentum=flags.momentum,
        name='Momentum')
  elif flags.optimizer == 'rmsprop':
    optimizer = tf.train.RMSPropOptimizer(
        learning_rate,
        decay=flags.rmsprop_decay,
        momentum=flags.momentum,
        epsilon=flags.opt_epsilon)
  elif flags.optimizer == 'sgd':
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  else:
    raise ValueError('Optimizer [%s] was not recognized' % flags.optimizer)
  return optimizer


def get_init_fn(flags):
  """Returns a function run by the chief worker to warm-start the training.

  Note that the init_fn is only run when initializing the model during the very
  first global step.

  Returns:
    An init function run by the supervisor.
  """
  if flags.checkpoint_path is None:
    return None

  # Warn the user if a checkpoint exists in the train_dir. Then we'll be
  # ignoring the checkpoint anyway.
  if tf.train.latest_checkpoint(flags.train_dir):
    tf.logging.info(
        'Ignoring --checkpoint_path because a checkpoint already exists in %s'
        % flags.train_dir)
    return None

  exclusions = []
  if flags.checkpoint_exclude_scopes:
    exclusions = [scope.strip()
                  for scope in flags.checkpoint_exclude_scopes.split(',')]

  # TODO(sguada) variables.filter_variables()
  variables_to_restore = []
  for var in slim.get_model_variables():
    excluded = False
    for exclusion in exclusions:
      if var.op.name.startswith(exclusion):
        excluded = True
        break
    if not excluded:
      variables_to_restore.append(var)

  if tf.gfile.IsDirectory(flags.checkpoint_path):
    checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path)
  else:
    checkpoint_path = flags.checkpoint_path

  tf.logging.info('Fine-tuning from %s' % checkpoint_path)

  return slim.assign_from_checkpoint_fn(
      checkpoint_path,
      variables_to_restore,
      ignore_missing_vars=flags.ignore_missing_vars)


def get_variables_to_train(flags):
  """Returns a list of variables to train.

  Returns:
    A list of variables to train by the optimizer.
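
  Example (illustrative sketch; the scope name is hypothetical):
    # With flags.trainable_scopes = 'yolo_v2/detection', only variables under
    # that scope are returned; with None, all trainable variables are.
    variables_to_train = get_variables_to_train(flags)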
  """
  if flags.trainable_scopes is None:
    return tf.trainable_variables()
  else:
    scopes = [scope.strip() for scope in flags.trainable_scopes.split(',')]

  variables_to_train = []
  for scope in scopes:
    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
    variables_to_train.extend(variables)
  return variables_to_train
--------------------------------------------------------------------------------
/src/datasets/decorations_to_tfrecords.py:
--------------------------------------------------------------------------------
from dataset_utils import int64_feature, float_feature, bytes_feature
import decorations
import math
import os
import xml.etree.ElementTree as ElementTree
import shutil
import tensorflow as tf

tf.app.flags.DEFINE_string(
    'path_to_decorations', '/home/paul/Data/decorations', 'path to decorations dataset')
tf.app.flags.DEFINE_string(
    'image_list', '/home/paul/Data/decorations/2017.txt', 'image_list')
tf.app.flags.DEFINE_string(
    'year', '2017', '2017')
tf.app.flags.DEFINE_string(
    'type', 'train', 'all or train or val or test')

FLAGS = tf.app.flags.FLAGS

# Small graph for image decoding
decoder_sess = tf.Session()
image_placeholder = tf.placeholder(dtype=tf.string)
decoded_jpeg = tf.image.decode_jpeg(image_placeholder, channels=3)


def process_image(image_path, anno_path):
  # Read as binary so the raw JPEG bytes survive under Python 3.
  image_data = tf.gfile.FastGFile(image_path, 'rb').read()

  with open(anno_path) as f:
    xml_tree = ElementTree.parse(f)
  root = xml_tree.getroot()

  # Image shape.
  size = root.find('size')
  shape = [int(size.find('height').text),
           int(size.find('width').text),
           int(size.find('depth').text)]
  # Find annotations.
  bboxes = []
  labels = []
  labels_text = []
  difficult = []
  truncated = []
  for obj in root.iter('object'):

    label = obj.find('name').text
    if label not in decorations.classes:  # exclude difficult or unlisted classes
      continue

    labels.append(decorations.classes.index(label))
    labels_text.append(label)

    if obj.find('difficult'):
      difficult.append(int(obj.find('difficult').text))
    else:
      difficult.append(0)
    if obj.find('truncated'):
      truncated.append(int(obj.find('truncated').text))
    else:
      truncated.append(0)

    bbox = obj.find('bndbox')
    bboxes.append((float(bbox.find('ymin').text) / shape[0],
                   float(bbox.find('xmin').text) / shape[1],
                   float(bbox.find('ymax').text) / shape[0],
                   float(bbox.find('xmax').text) / shape[1]
                   ))

  return image_data, shape, bboxes, labels, labels_text, difficult, truncated


def convert_to_example(image_data, labels, labels_text, bboxes, shape,
                       difficult, truncated):
  """Build an Example proto for an image example.

  Args:
    image_data: string, JPEG encoding of RGB image;
    labels: list of integers, one class id per box;
    labels_text: list of strings, human-readable labels;
    bboxes: list of bounding boxes; each box is a tuple of floats in [0, 1]
      specifying (ymin, xmin, ymax, xmax), normalized by the image size.
    shape: 3 integers, image shape in pixels.
  Returns:
    Example proto
  """
  xmin = []
  ymin = []
  xmax = []
  ymax = []
  for b in bboxes:
    assert len(b) == 4
    # pylint: disable=expression-not-assigned
    [l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)]
    # pylint: enable=expression-not-assigned

  image_format = b'JPEG'
  example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': int64_feature(shape[0]),
      'image/width': int64_feature(shape[1]),
      'image/channels': int64_feature(shape[2]),
      'image/shape': int64_feature(shape),
      'image/box_num': int64_feature(len(labels)),
      'image/object/bbox/xmin': float_feature(xmin),
      'image/object/bbox/xmax': float_feature(xmax),
      'image/object/bbox/ymin': float_feature(ymin),
      'image/object/bbox/ymax': float_feature(ymax),
      'image/object/bbox/label': int64_feature(labels),
      'image/object/bbox/label_text': bytes_feature(labels_text),
      'image/object/bbox/difficult': int64_feature(difficult),
      'image/object/bbox/truncated': int64_feature(truncated),
      'image/format': bytes_feature(image_format),
      'image/encoded': bytes_feature(image_data)}))
  return example


def process_dataset(name, image_paths, anno_paths, result_path, example_num_per_file):
  """Process selected Pascal VOC type dataset to generate TFRecords files.

  Parameters
  ----------
  name : string
      Name of resulting dataset 'train' or 'test'.
  image_paths : list
      List of paths to images to include in dataset.
  anno_paths : list
      List of paths to corresponding image annotations.
  result_path : string
      Path to put resulting TFRecord files.
  example_num_per_file : int
      How many examples one TFRecord file has.
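
  Examples
  --------
  A sketch of a typical call (the output path is hypothetical):
  >>> process_dataset('train', image_paths, anno_paths, '/tmp/tfrecords', 100)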
  """

  total_example_num = len(image_paths)
  # float() guards against Python 2 integer division flooring before ceil.
  total_files_num = int(math.ceil(total_example_num / float(example_num_per_file)))
  tfrecords_path_list_f = open(result_path + '/tfrecords_list.txt', 'w')
  writer = None
  max_box_num = 0
  for i in range(0, total_example_num):
    if i % example_num_per_file == 0:
      if i != 0:
        writer.close()
      tfrecords_name = '{}-{:05d}-of-{:05d}.tfrecords'.format(name, i // example_num_per_file,
                                                              total_files_num)
      tfrecords_name = os.path.join(result_path, tfrecords_name)
      print(tfrecords_name + '\n')
      tfrecords_path_list_f.write(tfrecords_name + '\n')
      writer = tf.python_io.TFRecordWriter(tfrecords_name)

    image_file = image_paths[i]
    anno_file = anno_paths[i]

    image_data, shape, bboxes, labels, labels_text, difficult, truncated = process_image(image_file, anno_file)
    if len(labels) > max_box_num:
      max_box_num = len(labels)
      print(max_box_num)
    example = convert_to_example(image_data, labels, labels_text, bboxes, shape, difficult, truncated)

    # write to writer
    writer.write(example.SerializeToString())
  writer.close()
  tfrecords_path_list_f.close()


if __name__ == '__main__':

  """Locate files for data sets and then generate TFRecords."""
  path = FLAGS.path_to_decorations
  path = os.path.expanduser(path)

  image_paths = []
  anno_paths = []

  f = open(FLAGS.image_list)
  for line in f:
    temp = line.rstrip()
    image_paths.append(temp)
    temp = temp.replace('images', 'annotations')
    temp = temp.replace('jpg', 'xml')
    anno_paths.append(temp)
  f.close()
  save_path = os.path.join(path, 'TFRecords', FLAGS.year, FLAGS.type)
  if not os.path.exists(save_path):
    os.makedirs(save_path)
  else:
    shutil.rmtree(save_path)
    os.makedirs(save_path)

  process_dataset(FLAGS.type, image_paths, anno_paths, save_path, 100)
--------------------------------------------------------------------------------
/src/preprocessing/yolo_v2_preprocessing.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from utils import tf_utils

slim = tf.contrib.slim


def convert_box(bboxes):
  x = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
  y = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
  w = bboxes[:, 2] - bboxes[:, 0]
  h = bboxes[:, 3] - bboxes[:, 1]
  x = tf.reshape(x, (tf.size(x), 1))
  y = tf.reshape(y, (tf.size(y), 1))
  w = tf.reshape(w, (tf.size(w), 1))
  h = tf.reshape(h, (tf.size(h), 1))
  return x, y, w, h


def get_index(index_0, index_1, index_2, index_3):
  tf_index = tf.concat(
      [tf.reshape(tf.cast(index_0, tf.int64), [1]),
       tf.reshape(tf.cast(index_1, tf.int64), [1]),
       tf.reshape(tf.cast(index_2, tf.int64), [1]),
       tf.reshape(tf.constant(index_3, tf.int64), [1])], 0)
  tf_index = tf.reshape(tf_index, [1, 4])
  return tf_index


def get_index_and_value_x(index_0, index_1, index_2, box, is_training):
  tf_index = get_index(index_0, index_1, index_2, 0)
  if is_training:
    tf_value = tf.reshape(box[0] - index_0, [1])
  else:
    tf_value = tf.reshape(box[0], [1])
  return tf_index, tf_value


def get_index_and_value_y(index_0, index_1, index_2, box, is_training):
  tf_index = get_index(index_0, index_1, index_2, 1)
  if is_training:
    tf_value = tf.reshape(box[1] - index_1, [1])
  else:
    tf_value = tf.reshape(box[1], [1])
  return tf_index, tf_value


def get_index_and_value_w(index_0, index_1, index_2, box, anchor):
  tf_index = get_index(index_0, index_1, index_2, 2)
  tf_value = tf.reshape(box[2], [1])
  # tf_value = tf.reshape(tf.log(box[2] / anchor[0]), [1])
  return tf_index, tf_value


def get_index_and_value_h(index_0, index_1, index_2, box, anchor):
  tf_index = get_index(index_0, index_1, index_2, 3)
  tf_value = tf.reshape(box[3], [1])
  # tf_value = tf.reshape(tf.log(box[3] / anchor[1]), [1])
  return tf_index, tf_value


def get_index_and_value_c(index_0, index_1, index_2, box):
  tf_index = get_index(index_0, index_1, index_2, 4)
  tf_value = tf.reshape(box[4], [1])
  return tf_index, tf_value


def process_gbboxes_with_anchors(gbboxes, image_size, anchors, box_num, is_training):
  gbboxes_coor = gbboxes[:, 0:4] * image_size[0] / 32
  gbboxes = tf.concat([gbboxes_coor, tf.expand_dims(gbboxes[:, 4], 1)], 1)

  indices = []
  values = []

  index = tf.floor(gbboxes[:, 0:2])

  for i in range(box_num):
    box = gbboxes[i]
    max_iou = tf.constant(0, tf.float32)
    index_2 = 0
    anchor_wh = tf.constant([0, 0], tf.float32)
    for j, anchor in enumerate(anchors):
      iou = tf_utils.tf_anchor_iou(box, anchor)
      max_iou, index_2, anchor_wh = tf.cond(iou > max_iou, lambda: (iou, j, tf.constant(anchor, tf.float32)),
                                            lambda: (max_iou, index_2, anchor_wh))

    index_0 = index[i, 0]
    index_1 = index[i, 1]

    tf_index, tf_value = get_index_and_value_x(index_0, index_1, index_2, box, is_training)
    indices.append(tf_index)
    values.append(tf_value)

    tf_index, tf_value = get_index_and_value_y(index_0, index_1, index_2, box, is_training)
    indices.append(tf_index)
    values.append(tf_value)

    tf_index, tf_value = get_index_and_value_w(index_0, index_1, index_2, box, anchor_wh)
    indices.append(tf_index)
    values.append(tf_value)

    tf_index, tf_value = get_index_and_value_h(index_0, index_1, index_2, box, anchor_wh)
    indices.append(tf_index)
    values.append(tf_value)

    tf_index, tf_value = get_index_and_value_c(index_0, index_1, index_2, box)
    indices.append(tf_index)
    values.append(tf_value)

  # A single concat over the collected [1, 4] index rows and [1] values
  # replaces the original incremental concat loops; the result is identical.
  tf_indices = tf.concat(indices, 0)
  tf_values = tf.concat(values, 0)

  boxes = tf.SparseTensor(tf_indices, tf_values, [image_size[0] // 32, image_size[1] // 32, len(anchors), 5])

  boxes = tf.sparse_tensor_to_dense(boxes, validate_indices=False)
  # return boxes, tf_indices, tf_values
  return boxes


def preprocess_bboxes(labels, bboxes, box_num, is_training):
  x, y, w, h = convert_box(bboxes)
  bboxes = tf.concat([x, y, w, h, tf.cast(tf.reshape(labels, (tf.size(labels), 1)), tf.float32)], 1)
  # Pad with all-zero rows so every image carries exactly box_num boxes.
  padding = tf.zeros((box_num - tf.shape(bboxes)[0], 5), tf.float32)
  boxes = tf.concat([bboxes, padding], 0)
  boxes = tf.reshape(boxes, (box_num, 5))

  boxes = process_gbboxes_with_anchors(boxes, [416, 416], [[1, 2], [1, 3], [2, 1], [3, 1], [1, 1]], box_num, is_training)

  return boxes


def preprocess_for_train(image, labels, bboxes, out_size, box_num, angle, saturation, exposure, hue, jitter):
  resized_image = tf.image.resize_images(image, out_size)
  bboxes = preprocess_bboxes(labels, bboxes, box_num, True)

  return resized_image, bboxes


def preprocess_for_eval(image, labels, bboxes, out_size, box_num):
  resized_image = tf.image.resize_images(image, out_size)

  bboxes = preprocess_bboxes(labels, bboxes, box_num, False)

  return resized_image, bboxes


def preprocess_data(image, labels, bboxes, out_size, box_num, is_training=True, angle=0,
                    saturation=0, exposure=0, hue=0,
                    jitter=0):
  """Preprocesses the given image and its ground-truth boxes.

  Args:
    image: A `Tensor` representing an image of arbitrary size.
    labels: A 1-D `Tensor` of object class labels, one per box.
    bboxes: A 2-D `Tensor` of boxes as (xmin, ymin, xmax, ymax), normalized
      to [0, 1].
    out_size: The (height, width) the image is resized to.
    box_num: The fixed number of boxes each example is padded to.
    is_training: `True` if we're preprocessing the image for training and
      `False` otherwise.
    angle, saturation, exposure, hue, jitter: Data-augmentation parameters
      (accepted but not yet applied by the resize-only pipeline above).

  Returns:
    A preprocessed image and the corresponding ground-truth grid tensor.
  """
  if is_training:
    return preprocess_for_train(image, labels, bboxes, out_size, box_num, angle, saturation,
                                exposure, hue, jitter)
  else:
    return preprocess_for_eval(image, labels, bboxes, out_size, box_num)
--------------------------------------------------------------------------------
/src/datasets/imagenet_1000.py:
--------------------------------------------------------------------------------
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the ImageNet ILSVRC 2012 Dataset plus some bounding boxes.

Some images have one or more bounding boxes associated with the label of the
image. See details here: http://image-net.org/download-bboxes

ImageNet is based upon WordNet 3.0. To uniquely identify a synset, we use
"WordNet ID" (wnid), which is a concatenation of POS (i.e. part of speech)
and SYNSET OFFSET of WordNet. For more information, please refer to the
WordNet documentation[http://wordnet.princeton.edu/wordnet/documentation/].

"There are bounding boxes for over 3000 popular synsets available.
For each synset, there are on average 150 images with bounding boxes."

WARNING: Don't use for object detection; in this case all the bounding boxes
of the image belong to just one class.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from six.moves import urllib
import tensorflow as tf

from datasets import dataset_utils


slim = tf.contrib.slim

# TODO(nsilberman): Add tfrecord file type once the script is updated.
_FILE_PATTERN = '%s-*'

_SPLITS_TO_SIZES = {
    'train': 1281167,
    'validation': 50000,
}

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying height and width.',
    'label': 'The label id of the image, integer between 0 and 999',
    'label_text': 'The text of the label.',
    'object/bbox': 'A list of bounding boxes.',
    'object/label': 'A list of labels, one per each object.',
}

_NUM_CLASSES = 1001


def create_readable_names_for_imagenet_labels():
  """Create a dict mapping label id to human readable string.

  Returns:
    labels_to_names: dictionary where keys are integers from 0 to 1000
    and values are human-readable names.

  We retrieve a synset file, which contains a list of valid synset labels used
  by the ILSVRC competition. There is one synset per line, e.g.:
    # n01440764
    # n01443537
  We also retrieve a synset_to_human_file, which contains a mapping from synsets
  to human-readable names for every synset in Imagenet. These are stored in a
  tsv format, as follows:
    # n02119247 black fox
    # n02119359 silver fox
  We assign each synset (in alphabetical order) an integer, starting from 1
  (since 0 is reserved for the background class).

  Code is based on
  https://github.com/tensorflow/models/blob/master/inception/inception/data/build_imagenet_data.py#L463
  """

  # pylint: disable=g-line-too-long
  # Trailing slash removed from base_url so the format() below does not
  # produce a double slash in the URLs.
  base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/inception/inception/data'
  synset_url = '{}/imagenet_lsvrc_2015_synsets.txt'.format(base_url)
  synset_to_human_url = '{}/imagenet_metadata.txt'.format(base_url)

  filename, _ = urllib.request.urlretrieve(synset_url)
  synset_list = [s.strip() for s in open(filename).readlines()]
  num_synsets_in_ilsvrc = len(synset_list)
  assert num_synsets_in_ilsvrc == 1000

  filename, _ = urllib.request.urlretrieve(synset_to_human_url)
  synset_to_human_list = open(filename).readlines()
  num_synsets_in_all_imagenet = len(synset_to_human_list)
  assert num_synsets_in_all_imagenet == 21842

  synset_to_human = {}
  for s in synset_to_human_list:
    parts = s.strip().split('\t')
    assert len(parts) == 2
    synset = parts[0]
    human = parts[1]
    synset_to_human[synset] = human

  label_index = 1
  labels_to_names = {0: 'background'}
  for synset in synset_list:
    name = synset_to_human[synset]
    labels_to_names[label_index] = name
    label_index += 1

  return labels_to_names


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading ImageNet.

  Args:
    split_name: A train/test split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/test split.
  """
  if split_name not in _SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
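  # The keys below match the schema written by the build_imagenet_data.py
  # script referenced above: per-image class data is fixed-length, per-object
  # box coordinates are variable-length.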
144 | if reader is None: 145 | reader = tf.TFRecordReader 146 | 147 | keys_to_features = { 148 | 'image/encoded': tf.FixedLenFeature( 149 | (), tf.string, default_value=''), 150 | 'image/format': tf.FixedLenFeature( 151 | (), tf.string, default_value='jpeg'), 152 | 'image/class/label': tf.FixedLenFeature( 153 | [], dtype=tf.int64, default_value=-1), 154 | 'image/class/text': tf.FixedLenFeature( 155 | [], dtype=tf.string, default_value=''), 156 | 'image/object/bbox/xmin': tf.VarLenFeature( 157 | dtype=tf.float32), 158 | 'image/object/bbox/ymin': tf.VarLenFeature( 159 | dtype=tf.float32), 160 | 'image/object/bbox/xmax': tf.VarLenFeature( 161 | dtype=tf.float32), 162 | 'image/object/bbox/ymax': tf.VarLenFeature( 163 | dtype=tf.float32), 164 | 'image/object/class/label': tf.VarLenFeature( 165 | dtype=tf.int64), 166 | } 167 | 168 | items_to_handlers = { 169 | 'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'), 170 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 171 | 'label_text': slim.tfexample_decoder.Tensor('image/class/text'), 172 | 'object/bbox': slim.tfexample_decoder.BoundingBox( 173 | ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'), 174 | 'object/label': slim.tfexample_decoder.Tensor('image/object/class/label'), 175 | } 176 | 177 | decoder = slim.tfexample_decoder.TFExampleDecoder( 178 | keys_to_features, items_to_handlers) 179 | 180 | labels_to_names = None 181 | if dataset_utils.has_labels(dataset_dir): 182 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 183 | else: 184 | labels_to_names = create_readable_names_for_imagenet_labels() 185 | dataset_utils.write_label_file(labels_to_names, dataset_dir) 186 | 187 | return slim.dataset.Dataset( 188 | data_sources=file_pattern, 189 | reader=reader, 190 | decoder=decoder, 191 | num_samples=_SPLITS_TO_SIZES[split_name], 192 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 193 | num_classes=_NUM_CLASSES, 194 | labels_to_names=labels_to_names) 195 | -------------------------------------------------------------------------------- /src/datasets/voc_to_tfrecords.py: -------------------------------------------------------------------------------- 1 | from dataset_utils import int64_feature, float_feature, bytes_feature 2 | import voc_common 3 | import math 4 | import os 5 | import xml.etree.ElementTree as ElementTree 6 | import shutil 7 | import tensorflow as tf 8 | 9 | tf.app.flags.DEFINE_string( 10 | 'path_to_voc', '~/Data/VOC/VOCdevkit/', 'path to Pascal VOC dataset') 11 | tf.app.flags.DEFINE_string( 12 | 'year', 'all', '2007 or 2012 or all') 13 | tf.app.flags.DEFINE_string( 14 | 'type', 'all', 'train or val or test') 15 | 16 | FLAGS = tf.app.flags.FLAGS 17 | 18 | # Small graph for image decoding 19 | decoder_sess = tf.Session() 20 | image_placeholder = tf.placeholder(dtype=tf.string) 21 | decoded_jpeg = tf.image.decode_jpeg(image_placeholder, channels=3) 22 | 23 | 24 | def get_data_set(): 25 | data_set = [] 26 | if (FLAGS.year == 'all'): 27 | if (FLAGS.type == 'all'): 28 | data_set = [('2007', 'train'), ('2007', 'val'), ('2007', 'test'), ('2012', 'train'), ('2012', 'val')] 29 | elif ((FLAGS.type == 'train') | (FLAGS.type == 'val')): 30 | data_set.append(('2007', FLAGS.type)) 31 | data_set.append(('2012', FLAGS.type)) 32 | elif (FLAGS.type == 'test'): 33 | data_set.append(('2007', 'test')) 34 | print("only 2007 has test dataset !\n") 35 | else: 36 | print("unknow dataset type !\n") 37 | elif (FLAGS.year == '2007'): 38 | if (FLAGS.type == 'all'): 39 | data_set = [(FLAGS.year, 'train'), 
(FLAGS.year, 'val'), (FLAGS.year, 'test')] 40 | elif (FLAGS.type == 'train') or (FLAGS.type == 'val') or (FLAGS.type == 'test'): 41 | data_set.append((FLAGS.year, FLAGS.type)) 42 | else: 43 | print("unknown dataset type!\n") 44 | elif FLAGS.year == '2012': 45 | if FLAGS.type == 'all': 46 | data_set = [(FLAGS.year, 'train'), (FLAGS.year, 'val')] 47 | elif (FLAGS.type == 'train') or (FLAGS.type == 'val'): 48 | data_set.append((FLAGS.year, FLAGS.type)) 49 | elif FLAGS.type == 'test': 50 | print("only VOC 2007 has a test split!\n") 51 | else: 52 | print("unknown dataset type!\n") 53 | else: 54 | print("unknown dataset year!\n") 55 | 56 | print('data_set: ' + str(data_set) + '\n') 57 | return data_set 58 | 59 | 60 | def get_ids(voc_path, year, type): 61 | ids = [] 62 | id_file = os.path.join(voc_path, 'VOC{}/ImageSets/Main/{}.txt'.format( 63 | year, type)) 64 | with open(id_file, 'r') as image_ids: 65 | ids.extend(map(str.strip, image_ids.readlines())) 66 | return ids 67 | 68 | 69 | def process_image(image_path, anno_path): 70 | image_data = tf.gfile.FastGFile(image_path, 'rb').read() 71 | 72 | with open(anno_path) as f: 73 | xml_tree = ElementTree.parse(f) 74 | root = xml_tree.getroot() 75 | 76 | # Image shape. 77 | size = root.find('size') 78 | shape = [int(size.find('height').text), 79 | int(size.find('width').text), 80 | int(size.find('depth').text)] 81 | # Find annotations. 82 | bboxes = [] 83 | labels = [] 84 | labels_text = [] 85 | difficult = [] 86 | truncated = [] 87 | for obj in root.iter('object'): 88 | 89 | label = obj.find('name').text 90 | if label not in voc_common.classes: # skip objects whose class is not in the VOC class list 91 | continue 92 | 93 | labels.append(voc_common.classes.index(label)) 94 | labels_text.append(label) 95 | 96 | if obj.find('difficult') is not None: # a bare `if obj.find(...)` is falsy for childless elements 97 | difficult.append(int(obj.find('difficult').text)) 98 | else: 99 | difficult.append(0) 100 | if obj.find('truncated') is not None: 101 | truncated.append(int(obj.find('truncated').text)) 102 | else: 103 | truncated.append(0) 104 | 105 | bbox = obj.find('bndbox') 106 | bboxes.append((float(bbox.find('ymin').text) / shape[0], 107 | float(bbox.find('xmin').text) / shape[1], 108 | float(bbox.find('ymax').text) / shape[0], 109 | float(bbox.find('xmax').text) / shape[1] 110 | )) 111 | 112 | return image_data, shape, bboxes, labels, labels_text, difficult, truncated 113 | 114 | 115 | def convert_to_example(image_data, labels, labels_text, bboxes, shape, 116 | difficult, truncated): 117 | """Build an Example proto for an image example. 118 | 119 | Args: 120 | image_data: string, JPEG encoding of RGB image; 121 | labels: list of integers, identifier for the ground truth; 122 | labels_text: list of strings, human-readable labels; 123 | bboxes: list of bounding boxes; each box is a tuple of four floats in 124 | [ymin, xmin, ymax, xmax] order, normalized by the image height and 125 | width, with one entry in `labels` per box. 126 | shape: 3 integers, image shape in pixels.
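difficult: list of integers, the PASCAL VOC 'difficult' flag (0 or 1) for each object; truncated: list of integers, the PASCAL VOC 'truncated' flag (0 or 1) for each object.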
127 | Returns: 128 | Example proto 129 | """ 130 | xmin = [] 131 | ymin = [] 132 | xmax = [] 133 | ymax = [] 134 | for b in bboxes: 135 | assert len(b) == 4 136 | # pylint: disable=expression-not-assigned 137 | [l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)] 138 | # pylint: enable=expression-not-assigned 139 | 140 | image_format = b'JPEG' 141 | example = tf.train.Example(features=tf.train.Features(feature={ 142 | 'image/height': int64_feature(shape[0]), 143 | 'image/width': int64_feature(shape[1]), 144 | 'image/channels': int64_feature(shape[2]), 145 | 'image/shape': int64_feature(shape), 146 | 'image/box_num':int64_feature(len(labels)), 147 | 'image/object/bbox/xmin': float_feature(xmin), 148 | 'image/object/bbox/xmax': float_feature(xmax), 149 | 'image/object/bbox/ymin': float_feature(ymin), 150 | 'image/object/bbox/ymax': float_feature(ymax), 151 | 'image/object/bbox/label': int64_feature(labels), 152 | 'image/object/bbox/label_text': bytes_feature(labels_text), 153 | 'image/object/bbox/difficult': int64_feature(difficult), 154 | 'image/object/bbox/truncated': int64_feature(truncated), 155 | 'image/format': bytes_feature(image_format), 156 | 'image/encoded': bytes_feature(image_data)})) 157 | return example 158 | 159 | 160 | def get_image_path(voc_path, year, image_id): 161 | """Get path to image for given year and image id.""" 162 | return os.path.join(voc_path, 'VOC{}/JPEGImages/{}.jpg'.format(year, 163 | image_id)) 164 | 165 | 166 | def get_anno_path(voc_path, year, image_id): 167 | """Get path to image annotation for given year and image id.""" 168 | return os.path.join(voc_path, 'VOC{}/Annotations/{}.xml'.format(year, 169 | image_id)) 170 | 171 | 172 | def get_save_path(voc_path, year, type): 173 | save_path = os.path.join(voc_path, 'TFRecords', year, type) 174 | if not os.path.exists(save_path): 175 | os.makedirs(save_path) 176 | else: 177 | shutil.rmtree(save_path) 178 | os.makedirs(save_path) 179 | return save_path 180 | 181 | 182 | def get_process_dataset_params(voc_path, year, type): 183 | ids = get_ids(voc_path, year, type) 184 | image_paths = [get_image_path(voc_path, year, i) for i in ids] 185 | anno_paths = [get_anno_path(voc_path, year, i) for i in ids] 186 | save_path = get_save_path(voc_path, year, type) 187 | return image_paths, anno_paths, save_path 188 | 189 | 190 | def process_dataset(name, image_paths, anno_paths, result_path, example_num_per_file): 191 | """Process selected Pascal VOC dataset to generate TFRecords files. 192 | 193 | Parameters 194 | ---------- 195 | name : string 196 | Name of resulting dataset 'train' or 'test'. 197 | image_paths : list 198 | List of paths to images to include in dataset. 199 | anno_paths : list 200 | List of paths to corresponding image annotations. 201 | result_path : string 202 | Path to put resulting TFRecord files. 203 | example_num_per_file : int 204 | how many examples one TFRecord file has. 
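Example (hypothetical paths, for illustration): process_dataset('train', image_paths, anno_paths, '/tmp/VOC/TFRecords/2007/train', 100) writes shards named like 'train-00000-of-00025.tfrecords' plus a tfrecords_list.txt manifest listing them.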
205 | """ 206 | 207 | total_example_num = len(image_paths) 208 | total_files_num = int(math.ceil(total_example_num / example_num_per_file)) 209 | tfrecords_path_list_f = open(result_path + '/tfrecords_list.txt', 'w') 210 | writer = None 211 | max_box_num = 0 212 | for i in range(0, total_example_num): 213 | if i % example_num_per_file == 0: 214 | if i != 0: 215 | writer.close() 216 | tfrecords_name = '{}-{:05d}-of-{:05d}.tfrecords'.format(name, int(math.ceil(i / example_num_per_file)), 217 | total_files_num) 218 | tfrecords_name = os.path.join(result_path, tfrecords_name) 219 | print(tfrecords_name + '\n') 220 | tfrecords_path_list_f.write(tfrecords_name + '\n') 221 | writer = tf.python_io.TFRecordWriter(tfrecords_name) 222 | 223 | image_file = image_paths[i] 224 | anno_file = anno_paths[i] 225 | 226 | image_data, shape, bboxes, labels, labels_text, difficult, truncated = process_image(image_file, anno_file) 227 | if(len(labels)> max_box_num): 228 | max_box_num = len(labels) 229 | print(max_box_num) 230 | example = convert_to_example(image_data, labels, labels_text, bboxes, shape, difficult, truncated) 231 | 232 | # write to writer 233 | writer.write(example.SerializeToString()) 234 | writer.close() 235 | tfrecords_path_list_f.close() 236 | 237 | 238 | if __name__ == '__main__': 239 | 240 | """Locate files for data sets and then generate TFRecords.""" 241 | voc_path = FLAGS.path_to_voc 242 | voc_path = os.path.expanduser(voc_path) 243 | 244 | data_set = get_data_set() 245 | for year, type in data_set: 246 | image_paths, anno_paths, save_path = get_process_dataset_params(voc_path, year, type) 247 | process_dataset(type, image_paths, anno_paths, save_path, 100) 248 | -------------------------------------------------------------------------------- /src/nets/yolo_v2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import tensorflow as tf 3 | from utils import tf_utils 4 | 5 | slim = tf.contrib.slim 6 | 7 | 8 | def yolo_v2_arg_scope(weight_decay=0.0005): 9 | with slim.arg_scope([slim.max_pool2d], kernel_size=[2, 2]): 10 | with slim.arg_scope([slim.conv2d], 11 | kernel_size=[3, 3], 12 | activation_fn=tf.nn.relu, 13 | normalizer_fn=slim.batch_norm, 14 | weights_regularizer=slim.l2_regularizer(weight_decay)) as arg_sc: 15 | return arg_sc 16 | 17 | 18 | def yolo_v2(inputs, num_classes, is_training, num_anchors=5, scope='yolo_v2'): 19 | with tf.variable_scope(scope, 'yolo_v2', [inputs]) as sc: 20 | end_points_collection = sc.name + '_end_points' 21 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 22 | outputs_collections=end_points_collection): 23 | net = slim.conv2d(inputs, 32, scope='layer_0') 24 | net = slim.max_pool2d(net, scope='layer_1') 25 | net = slim.conv2d(net, 64, scope='layer_2') 26 | net = slim.max_pool2d(net, scope='layer_3') 27 | net = slim.conv2d(net, 128, scope='layer_4') 28 | net = slim.conv2d(net, 64, kernel_size=[1, 1], scope='layer_5') 29 | net = slim.conv2d(net, 128, scope='layer_6') 30 | net = slim.max_pool2d(net, scope='layer_7') 31 | net = slim.conv2d(net, 256, scope='layer_8') 32 | net = slim.conv2d(net, 128, kernel_size=[1, 1], scope='layer_9') 33 | net = slim.conv2d(net, 256, scope='layer_10') 34 | net = slim.max_pool2d(net, scope='layer_11') 35 | net = slim.conv2d(net, 512, scope='layer_12') 36 | net = slim.conv2d(net, 256, kernel_size=[1, 1], scope='layer_13') 37 | net = slim.conv2d(net, 512, scope='layer_14') 38 | net = slim.conv2d(net, 256, kernel_size=[1, 1], scope='layer_15') 39 | 
net = slim.conv2d(net, 512, scope='layer_16') 40 | path_1 = tf.space_to_depth(net, block_size=2, name='path_1') 41 | net = slim.max_pool2d(net, scope='layer_17') 42 | net = slim.conv2d(net, 1024, scope='layer_18') 43 | net = slim.conv2d(net, 512, kernel_size=[1, 1], scope='layer_19') 44 | net = slim.conv2d(net, 1024, scope='layer_20') 45 | net = slim.conv2d(net, 512, kernel_size=[1, 1], scope='layer_21') 46 | net = slim.conv2d(net, 1024, scope='layer_22') 47 | net = slim.conv2d(net, 1024, scope='layer_23') 48 | net = slim.conv2d(net, 1024, scope='layer_24') 49 | path_2 = net 50 | net = tf.concat([path_1, path_2], 3, name='concat2path') 51 | net = slim.conv2d(net, 1024, scope='layer_25') 52 | net = slim.conv2d(net, (num_classes + 5) * num_anchors, kernel_size=[1, 1], scope='layer_26') 53 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 54 | return net, end_points 55 | 56 | 57 | def yolo_v2_head(inputs, num_classes, anchors, is_training=True): 58 | input_shape = tf.shape(inputs) 59 | anchors_num = len(anchors) 60 | 61 | preds = tf.reshape(inputs, (input_shape[0], input_shape[1], input_shape[2], anchors_num, num_classes + 5)) 62 | box_coordinate = preds[:, :, :, :, 0:4] 63 | 64 | if is_training: 65 | box_coordinate_xy = tf.sigmoid(box_coordinate[:, :, :, :, 0:2]) 66 | # box_coordinate_wh = box_coordinate[:, :, :, :, 2:4] 67 | # box_coordinate = tf.concat([box_coordinate_xy, box_coordinate_wh], 4) 68 | 69 | anchors_tensor_w = tf.constant(anchors, tf.float32)[:, 0] 70 | anchors_tensor_h = tf.constant(anchors, tf.float32)[:, 1] 71 | anchors_tensor_w = tf.tile(anchors_tensor_w, [input_shape[0] * input_shape[1] * input_shape[2]]) 72 | anchors_tensor_h = tf.tile(anchors_tensor_h, [input_shape[0] * input_shape[1] * input_shape[2]]) 73 | anchors_tensor_w = tf.reshape(anchors_tensor_w, (input_shape[0], input_shape[1], input_shape[2], anchors_num)) 74 | anchors_tensor_h = tf.reshape(anchors_tensor_h, (input_shape[0], input_shape[1], input_shape[2], anchors_num)) 75 | 76 | box_coordinate_w = tf.expand_dims(tf.exp(preds[:, :, :, :, 2]) * anchors_tensor_w, 4) 77 | box_coordinate_h = tf.expand_dims(tf.exp(preds[:, :, :, :, 3]) * anchors_tensor_h, 4) 78 | 79 | box_coordinate = tf.concat([box_coordinate_xy, box_coordinate_w, box_coordinate_h], 4) 80 | else: 81 | anchors_tensor_w = tf.constant(anchors, tf.float32)[:, 0] 82 | anchors_tensor_h = tf.constant(anchors, tf.float32)[:, 1] 83 | anchors_tensor_w = tf.tile(anchors_tensor_w, [input_shape[0] * input_shape[1] * input_shape[2]]) 84 | anchors_tensor_h = tf.tile(anchors_tensor_h, [input_shape[0] * input_shape[1] * input_shape[2]]) 85 | anchors_tensor_w = tf.reshape(anchors_tensor_w, (input_shape[0], input_shape[1], input_shape[2], anchors_num)) 86 | anchors_tensor_h = tf.reshape(anchors_tensor_h, (input_shape[0], input_shape[1], input_shape[2], anchors_num)) 87 | 88 | conv_height_index = tf.range(input_shape[1]) 89 | conv_width_index = tf.range(input_shape[2]) 90 | conv_height_index = tf.tile(conv_height_index, [input_shape[0] * input_shape[2] * anchors_num]) 91 | conv_width_index = tf.tile(conv_width_index, [input_shape[0] * input_shape[1] * anchors_num]) 92 | conv_height_index = tf.reshape(conv_height_index, (input_shape[0], input_shape[1], input_shape[2], anchors_num)) 93 | conv_width_index = tf.reshape(conv_width_index, (input_shape[0], input_shape[1], input_shape[2], anchors_num)) 94 | 95 | box_coordinate_x = tf.expand_dims(tf.sigmoid(preds[:, :, :, :, 0]) + tf.cast(conv_width_index, tf.float32), 4) 96 | box_coordinate_y = 
tf.expand_dims(tf.sigmoid(preds[:, :, :, :, 1]) + tf.cast(conv_height_index, tf.float32), 4) 97 | box_coordinate_w = tf.expand_dims(tf.exp(preds[:, :, :, :, 2]) * anchors_tensor_w, 4) 98 | box_coordinate_h = tf.expand_dims(tf.exp(preds[:, :, :, :, 3]) * anchors_tensor_h, 4) 99 | box_coordinate = tf.concat([box_coordinate_x, box_coordinate_y, box_coordinate_w, box_coordinate_h], 4) 100 | 101 | box_confidence = preds[:, :, :, :, 4] 102 | box_class_probs = preds[:, :, :, :, 5:] 103 | 104 | return box_coordinate, box_confidence, box_class_probs 105 | 106 | 107 | def yolo_v2_confidence_loss(box_coordinate, box_confidence, gbboxes_batch, object_mask, object_scale, no_object_scale): 108 | iou = tf_utils.tf_boxes_iou(box_coordinate, gbboxes_batch) 109 | object_no_detections = tf.cast(iou < 0.6, tf.float32) 110 | 111 | no_objects_loss = no_object_scale * (1 - object_mask) * tf.square(0 - box_confidence) 112 | # If a cell is marked as containing an object but the IOU between its prediction and the ground truth is below 0.6, the prediction still contributes to objects_loss. 113 | objects_loss = object_scale * object_mask * object_no_detections * tf.square(1 - box_confidence) 114 | 115 | no_objects_loss = tf.reduce_sum(no_objects_loss) 116 | objects_loss = tf.reduce_sum(objects_loss) 117 | 118 | no_objects_loss = no_objects_loss / tf.cast(tf.shape(gbboxes_batch)[0], tf.float32) 119 | objects_loss = objects_loss / tf.cast(tf.shape(gbboxes_batch)[0], tf.float32) 120 | 121 | confidence_loss = objects_loss + no_objects_loss 122 | return confidence_loss, objects_loss, no_objects_loss 123 | 124 | 125 | def yolo_v2_coordinate_loss(box_coordinate, gbboxes_batch, object_mask, coordinates_scale): 126 | xy_loss = box_coordinate[..., 0:2] - gbboxes_batch[..., 0:2] 127 | xy_loss = tf.square(xy_loss) 128 | xy_loss = object_mask * tf.reduce_sum(xy_loss, 4) 129 | xy_loss = coordinates_scale * tf.reduce_sum(xy_loss) 130 | xy_loss = xy_loss / tf.cast(tf.shape(gbboxes_batch)[0], tf.float32) 131 | 132 | wh_loss = tf.sqrt(box_coordinate[..., 2:4]) - tf.sqrt(gbboxes_batch[..., 2:4]) 133 | wh_loss = tf.square(wh_loss) 134 | wh_loss = object_mask * tf.reduce_sum(wh_loss, 4) 135 | wh_loss = coordinates_scale * tf.reduce_sum(wh_loss) 136 | wh_loss = wh_loss / tf.cast(tf.shape(gbboxes_batch)[0], tf.float32) 137 | 138 | coordinate_loss = xy_loss + wh_loss 139 | # coordinate_loss = object_mask * tf.reduce_sum(coordinate_loss, 4) 140 | # coordinate_loss = coordinates_scale * tf.reduce_sum(coordinate_loss) 141 | 142 | 143 | return coordinate_loss, xy_loss, wh_loss 144 | 145 | 146 | def yolo_v2_category_loss(box_class_probs, gbboxes_batch, object_mask, num_classes, class_scale): 147 | # TODO: improve how the zero-padded default boxes [0, 0, 0, 0, 0] contribute to the classification loss. 148 | # Possible directions: 149 | # 1. add an explicit background class 150 | # 2. make the one-hot class vector of the default boxes all zeros 151 | gbboxes_classs = tf.cast(gbboxes_batch[..., 4], tf.int32) 152 | gbboxes_classs = tf.one_hot(gbboxes_classs, num_classes) 153 | category_loss = gbboxes_classs - box_class_probs 154 | category_loss = tf.square(category_loss) 155 | category_loss = object_mask * tf.reduce_sum(category_loss, 4) 156 | category_loss = class_scale * tf.reduce_sum(category_loss) 157 | category_loss = category_loss / tf.cast(tf.shape(gbboxes_batch)[0], tf.float32) 158 | return category_loss 159 | 160 | 161 | def yolo_v2_loss(box_coordinate, box_confidence, box_class_probs, anchors, gbboxes_batch, num_classes, object_scale=5, 162 | no_object_scale=1, class_scale=1, coordinates_scale=1): 163 | object_mask = tf.reduce_sum(gbboxes_batch, 4) 164 | object_mask = tf.cast(object_mask > 0, tf.float32) 165 |
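# gbboxes_batch is zero-padded up to the dataset's max box count, and each entry stores [x, y, w, h, class]; summing over the last axis and testing > 0 therefore yields a mask that is 1 only for real ground-truth boxes. 166 | confidence_loss, objects_loss, no_objects_loss = yolo_v2_confidence_loss(box_coordinate,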
box_confidence, 167 | gbboxes_batch, object_mask, object_scale, 168 | no_object_scale) 169 | coordinate_loss, xy_loss, wh_loss = yolo_v2_coordinate_loss(box_coordinate, gbboxes_batch, object_mask, 170 | coordinates_scale) 171 | category_loss = yolo_v2_category_loss(box_class_probs, gbboxes_batch, object_mask, num_classes, class_scale) 172 | 173 | total_loss = confidence_loss + coordinate_loss + category_loss 174 | return total_loss, confidence_loss, coordinate_loss, category_loss, xy_loss, wh_loss, objects_loss, no_objects_loss 175 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | 3 | python src/train.py \ 4 | --train_dir=/raid/pengchong_data/tfmodel_test/ \ 5 | --dataset_dir=/raid/pengchong_data/Data/VOC/VOCdevkit/TFRecords/2007 \ 6 | --max_number_of_steps=100 \ 7 | --batch_size=2 8 | 9 | ''' 10 | import os 11 | import tensorflow as tf 12 | 13 | from datasets import dataset_factory 14 | from nets import nets_factory, yolo_v2 15 | from preprocessing import yolo_v2_preprocessing 16 | 17 | slim = tf.contrib.slim 18 | 19 | tf.app.flags.DEFINE_string( 20 | 'master', '', 'The address of the TensorFlow master to use.') 21 | 22 | tf.app.flags.DEFINE_string( 23 | 'train_dir', '/tmp/tfmodel/', 24 | 'Directory where checkpoints and event logs are written to.') 25 | 26 | tf.app.flags.DEFINE_integer('num_clones', 1, 27 | 'Number of model clones to deploy.') 28 | 29 | tf.app.flags.DEFINE_boolean('clone_on_cpu', False, 30 | 'Use CPUs to deploy clones.') 31 | 32 | tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.') 33 | 34 | tf.app.flags.DEFINE_integer( 35 | 'num_ps_tasks', 0, 36 | 'The number of parameter servers. If the value is 0, then the parameters ' 37 | 'are handled locally by the worker.') 38 | 39 | tf.app.flags.DEFINE_integer( 40 | 'num_readers', 4, 41 | 'The number of parallel readers that read data from the dataset.') 42 | 43 | tf.app.flags.DEFINE_integer( 44 | 'num_preprocessing_threads', 4, 45 | 'The number of threads used to create the batches.') 46 | 47 | tf.app.flags.DEFINE_integer( 48 | 'log_every_n_steps', 10, 49 | 'The frequency with which logs are printed.') 50 | 51 | tf.app.flags.DEFINE_integer( 52 | 'save_summaries_secs', 60, 53 | 'The frequency with which summaries are saved, in seconds.') 54 | 55 | tf.app.flags.DEFINE_integer( 56 | 'save_interval_secs', 600, 57 | 'The frequency with which the model is saved, in seconds.') 58 | 59 | tf.app.flags.DEFINE_integer( 60 | 'task', 0, 'Task id of the replica running the training.') 61 | 62 | ###################### 63 | # Optimization Flags # 64 | ###################### 65 | 66 | tf.app.flags.DEFINE_float( 67 | 'weight_decay', 0.0005, 'The weight decay on the model weights.') 68 | 69 | tf.app.flags.DEFINE_string( 70 | 'optimizer', 'rmsprop', 71 | 'The name of the optimizer, one of "adadelta", "adagrad", "adam", ' 72 | '"ftrl", "momentum", "sgd" or "rmsprop".') 73 | 74 | tf.app.flags.DEFINE_float( 75 | 'adadelta_rho', 0.95, 76 | 'The decay rate for adadelta.') 77 | 78 | tf.app.flags.DEFINE_float( 79 | 'adagrad_initial_accumulator_value', 0.1, 80 | 'Starting value for the AdaGrad accumulators.') 81 | 82 | tf.app.flags.DEFINE_float( 83 | 'adam_beta1', 0.9, 84 | 'The exponential decay rate for the 1st moment estimates.') 85 | 86 | tf.app.flags.DEFINE_float( 87 | 'adam_beta2', 0.999, 88 | 'The exponential decay rate for the 2nd moment estimates.') 89 | 90 | tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.') 91 | 92 | tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5, 93 | 'The learning rate power.') 94 | 95 | tf.app.flags.DEFINE_float( 96 | 'ftrl_initial_accumulator_value', 0.1, 97 | 'Starting value for the FTRL accumulators.') 98 | 99 | tf.app.flags.DEFINE_float( 100 | 'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.') 101 | 102 | tf.app.flags.DEFINE_float( 103 | 'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.') 104 | 105 | tf.app.flags.DEFINE_float( 106 |
'momentum', 0.9, 107 | 'The momentum for the MomentumOptimizer and RMSPropOptimizer.') 108 | 109 | tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.') 110 | 111 | ####################### 112 | # Learning Rate Flags # 113 | ####################### 114 | 115 | tf.app.flags.DEFINE_string( 116 | 'learning_rate_decay_type', 117 | 'exponential', 118 | 'Specifies how the learning rate is decayed. One of "fixed", "exponential",' 119 | ' or "polynomial"') 120 | 121 | tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.') 122 | 123 | tf.app.flags.DEFINE_float( 124 | 'end_learning_rate', 0.0001, 125 | 'The minimal end learning rate used by a polynomial decay learning rate.') 126 | 127 | tf.app.flags.DEFINE_float( 128 | 'label_smoothing', 0.0, 'The amount of label smoothing.') 129 | 130 | tf.app.flags.DEFINE_float( 131 | 'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.') 132 | 133 | tf.app.flags.DEFINE_float( 134 | 'num_epochs_per_decay', 2.0, 135 | 'Number of epochs after which learning rate decays.') 136 | 137 | tf.app.flags.DEFINE_bool( 138 | 'sync_replicas', False, 139 | 'Whether or not to synchronize the replicas during training.') 140 | 141 | tf.app.flags.DEFINE_integer( 142 | 'replicas_to_aggregate', 1, 143 | 'The number of gradients to collect before updating params.') 144 | 145 | tf.app.flags.DEFINE_float( 146 | 'moving_average_decay', None, 147 | 'The decay to use for the moving average. ' 148 | 'If left as None, then moving averages are not used.') 149 | 150 | ####################### 151 | # Dataset Flags # 152 | ####################### 153 | 154 | tf.app.flags.DEFINE_string( 155 | 'dataset_name', 'voc_2007', 'The name of the dataset to load.') 156 | 157 | tf.app.flags.DEFINE_string( 158 | 'dataset_split_name', 'train', 'The name of the train/test split.') 159 | 160 | tf.app.flags.DEFINE_string( 161 | 'dataset_dir', None, 'The directory where the dataset files are stored.') 162 | 163 | tf.app.flags.DEFINE_integer( 164 | 'labels_offset', 0, 165 | 'An offset for the labels in the dataset. This flag is primarily used to ' 166 | 'evaluate the VGG and ResNet architectures which do not use a background ' 167 | 'class for the ImageNet dataset.') 168 | 169 | tf.app.flags.DEFINE_string( 170 | 'model_name', 'yolo_v2', 'The name of the architecture to train.') 171 | 172 | tf.app.flags.DEFINE_string( 173 | 'preprocessing_name', None, 'The name of the preprocessing to use. If left ' 174 | 'as `None`, then the model_name flag is used.') 175 | 176 | tf.app.flags.DEFINE_integer( 177 | 'batch_size', 32, 'The number of samples in each batch.') 178 | 179 | tf.app.flags.DEFINE_integer( 180 | 'num_classes', 20, 'The number of classes.') 181 | 182 | tf.app.flags.DEFINE_integer( 183 | 'train_image_size', 416, 'Train image height and width, in pixels.') 184 | 185 | tf.app.flags.DEFINE_integer('max_number_of_steps', 10000, 186 | 'The maximum number of training steps.') 187 | 188 | ##################### 189 | # Fine-Tuning Flags # 190 | ##################### 191 | 192 | tf.app.flags.DEFINE_string( 193 | 'checkpoint_path', None, 194 | 'The path to a checkpoint from which to fine-tune.') 195 | 196 | tf.app.flags.DEFINE_string( 197 | 'checkpoint_exclude_scopes', None, 198 | 'Comma-separated list of scopes of variables to exclude when restoring ' 199 | 'from a checkpoint.') 200 | 201 | tf.app.flags.DEFINE_string( 202 | 'trainable_scopes', None, 203 | 'Comma-separated list of scopes to filter the set of variables to train. '
204 | 'By default (None), all variables are trained.') 205 | 206 | tf.app.flags.DEFINE_boolean( 207 | 'ignore_missing_vars', False, 208 | 'When restoring a checkpoint, ignore variables that are missing from it.') 209 | 210 | FLAGS = tf.app.flags.FLAGS 211 | 212 | 213 | def inference_sequential(image_batch): 214 | network_fn = nets_factory.get_network_fn( 215 | name=FLAGS.model_name, 216 | num_classes=FLAGS.num_classes, 217 | is_training=True, 218 | weight_decay=FLAGS.weight_decay, 219 | num_anchors=5) 220 | net, end_points = network_fn(image_batch) 221 | 222 | box_coordinate, box_confidence, box_class_probs = yolo_v2.yolo_v2_head(net, FLAGS.num_classes, 223 | [[1, 2], [1, 3], [2, 1], [3, 1], [1, 1]], 224 | True) 225 | 226 | # preds = tf.reduce_max(box_class_probs, 4) 227 | # preds = tf.one_hot(tf.cast(preds, tf.int32), FLAGS.num_classes) 228 | 229 | # return preds 230 | 231 | return box_coordinate, box_confidence, box_class_probs 232 | 233 | 234 | # =========================================================================== # 235 | # Main training routine. 236 | # =========================================================================== # 237 | def main(_): 238 | with tf.Graph().as_default(): 239 | summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) 240 | global_step = tf.train.create_global_step() 241 | 242 | # Select the dataset. 243 | dataset = dataset_factory.get_dataset( 244 | FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir) 245 | max_box_num_per_image = dataset_factory.get_box_num_per_image(FLAGS.dataset_name, FLAGS.dataset_split_name) 246 | provider = slim.dataset_data_provider.DatasetDataProvider( 247 | dataset, 248 | num_readers=FLAGS.num_readers, 249 | common_queue_capacity=20 * FLAGS.batch_size, 250 | common_queue_min=10 * FLAGS.batch_size) 251 | # Get input for network: image, labels, bboxes.
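# image: a decoded image tensor of varying height and width (H x W x 3); glabels: one class id per object; gbboxes: normalized [ymin, xmin, ymax, xmax] boxes as written by voc_to_tfrecords.py; box_num: the number of annotated objects in this example.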
252 | [image, glabels, gbboxes, box_num] = provider.get(['image', 'object/label', 'object/bbox', 'box_num']) 253 | 254 | train_image_size = (FLAGS.train_image_size, FLAGS.train_image_size) 255 | image, gbboxes = yolo_v2_preprocessing.preprocess_data(image, glabels, gbboxes, train_image_size, 256 | max_box_num_per_image) 257 | 258 | image_batch, gbboxes_batch = tf.train.batch( 259 | [image, gbboxes], 260 | batch_size=FLAGS.batch_size, 261 | num_threads=FLAGS.num_preprocessing_threads, 262 | capacity=5 * FLAGS.batch_size) 263 | 264 | batch_queue = slim.prefetch_queue.prefetch_queue( 265 | [image_batch, gbboxes_batch], capacity=2) 266 | 267 | image_batch, gbboxes_batch = batch_queue.dequeue() 268 | 269 | summaries.add(tf.summary.image('batch_image', image_batch)) 270 | 271 | print(gbboxes_batch)  # debug: show the shape/dtype of the batched ground-truth boxes 272 | 273 | box_coordinate, box_confidence, box_class_probs = inference_sequential(image_batch) 274 | total_loss, confidence_loss, coordinate_loss, category_loss, xy_loss, wh_loss, objects_loss, no_objects_loss = yolo_v2.yolo_v2_loss( 275 | box_coordinate, 276 | box_confidence, 277 | box_class_probs, 278 | [[1, 2], [1, 3], [2, 1], [3, 1], [1, 1]], 279 | gbboxes_batch, 280 | num_classes=FLAGS.num_classes) 281 | 282 | summaries.add(tf.summary.scalar('loss_total', total_loss)) 283 | summaries.add(tf.summary.scalar('loss_confidence', confidence_loss)) 284 | summaries.add(tf.summary.scalar('loss_confidence_object', objects_loss)) 285 | summaries.add(tf.summary.scalar('loss_confidence_no_object', no_objects_loss)) 286 | summaries.add(tf.summary.scalar('loss_coordinate', coordinate_loss)) 287 | summaries.add(tf.summary.scalar('loss_coordinate_xy', xy_loss)) 288 | summaries.add(tf.summary.scalar('loss_coordinate_wh', wh_loss)) 289 | summaries.add(tf.summary.scalar('loss_category', category_loss)) 290 | 291 | # optimizer = tf.train.GradientDescentOptimizer(0.01) 292 | optimizer = tf.train.AdamOptimizer( 293 | learning_rate=FLAGS.learning_rate, 294 | beta1=FLAGS.adam_beta1, 295 | beta2=FLAGS.adam_beta2, 296 | epsilon=FLAGS.opt_epsilon) 297 | 298 | train_op = slim.learning.create_train_op(total_loss, optimizer) 299 | 300 | summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES)) 301 | 302 | summary_op = tf.summary.merge(list(summaries), name='summary_op') 303 | 304 | sess_config = tf.ConfigProto() 305 | sess_config.gpu_options.allow_growth = True 306 | 307 | final_loss = slim.learning.train(train_op, 308 | logdir=FLAGS.train_dir, 309 | summary_op=summary_op, 310 | global_step=global_step, 311 | number_of_steps=FLAGS.max_number_of_steps, 312 | log_every_n_steps=FLAGS.log_every_n_steps, 313 | save_summaries_secs=FLAGS.save_summaries_secs, 314 | save_interval_secs=FLAGS.save_interval_secs, 315 | session_config=sess_config) 316 | 317 | print('Finished training. Last batch loss %f' % final_loss) 318 | 319 | 320 | if __name__ == '__main__': 321 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 322 | tf.app.run() 323 | -------------------------------------------------------------------------------- /src/utils/tf_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Paul Balanca. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Diverse TensorFlow utils, for training, evaluation and so on! 16 | """ 17 | import os 18 | from pprint import pprint 19 | 20 | import tensorflow as tf 21 | from tensorflow.contrib.slim.python.slim.data import parallel_reader 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | # =========================================================================== # 27 | # General tools. 28 | # =========================================================================== # 29 | def reshape_list(l, shape=None): 30 | """Reshape list of (list): 1D to 2D or the other way around. 31 | 32 | Args: 33 | l: List or List of list. 34 | shape: 1D or 2D shape. 35 | Return 36 | Reshaped list. 37 | """ 38 | r = [] 39 | if shape is None: 40 | # Flatten everything. 41 | for a in l: 42 | if isinstance(a, (list, tuple)): 43 | r = r + list(a) 44 | else: 45 | r.append(a) 46 | else: 47 | # Reshape to list of list. 48 | i = 0 49 | for s in shape: 50 | if s == 1: 51 | r.append(l[i]) 52 | else: 53 | r.append(l[i:i + s]) 54 | i += s 55 | return r 56 | 57 | 58 | def tf_overlap(x1, len1, x2, len2): 59 | len1_half = len1 / 2 60 | len2_half = len2 / 2 61 | 62 | left = tf.maximum(x1 - len1_half, x2 - len2_half) 63 | right = tf.minimum(x1 + len1_half, x2 + len2_half) 64 | 65 | overlap = right - left 66 | 67 | overlap = tf.maximum(overlap, tf.constant(0, tf.float32)) 68 | 69 | return overlap 70 | 71 | 72 | def tf_multi_overlap(x1, len1, x2, len2): 73 | len1_half = len1 / 2 74 | len2_half = len2 / 2 75 | 76 | left = tf.maximum(x1 - len1_half, x2 - len2_half) 77 | right = tf.minimum(x1 + len1_half, x2 + len2_half) 78 | 79 | overlap = right - left 80 | overlap = tf.maximum(overlap, tf.zeros_like(overlap)) 81 | 82 | return overlap 83 | 84 | 85 | def tf_box_intersection(a, b): 86 | w = tf_overlap(a[0], a[2], b[0], b[2]) 87 | h = tf_overlap(a[1], a[3], b[1], b[3]) 88 | area = w * h 89 | return area 90 | 91 | 92 | def tf_boxes_intersection(a, b): 93 | w = tf_multi_overlap(a[..., 0], a[..., 2], b[..., 0], b[..., 2]) 94 | h = tf_multi_overlap(a[..., 1], a[..., 3], b[..., 1], b[..., 3]) 95 | area = w * h 96 | return area 97 | 98 | 99 | def tf_box_union(a, b): 100 | i = tf_box_intersection(a, b) 101 | u = a[2] * a[3] + b[2] * b[3] - i 102 | return u 103 | 104 | 105 | def tf_boxes_union(a, b): 106 | i = tf_boxes_intersection(a, b) 107 | u = a[..., 2] * a[..., 3] + b[..., 2] * b[..., 3] - i 108 | return u 109 | 110 | 111 | def tf_box_iou(a, b): 112 | iou = tf_box_intersection(a, b) / tf_box_union(a, b) 113 | return iou 114 | 115 | 116 | def tf_boxes_iou(a, b): 117 | iou = tf_boxes_intersection(a, b) / tf_boxes_union(a, b) 118 | return iou 119 | 120 | 121 | def tf_anchor_iou(a, b): 122 | box_a = a[2:4] 123 | box_a = tf.concat([tf.constant([0, 0], tf.float32), box_a], 0) 124 | 125 | # box_a = tf.constant([0, 0, a[2], a[3]], tf.float32) 126 | box_b = tf.constant([0, 0, b[0], b[1]], tf.float32) 127 | iou = tf_box_iou(box_a, box_b) 128 | return iou 129 | 130 | 131 | # 
=========================================================================== # 132 | # Training utils. 133 | # =========================================================================== # 134 | ''' 135 | def print_configuration(flags, ssd_params, data_sources, save_dir=None): 136 | """Print the training configuration. 137 | """ 138 | def print_config(stream=None): 139 | print('\n# =========================================================================== #', file=stream) 140 | print('# Training | Evaluation flags:', file=stream) 141 | print('# =========================================================================== #', file=stream) 142 | pprint(flags, stream=stream) 143 | 144 | print('\n# =========================================================================== #', file=stream) 145 | print('# SSD net parameters:', file=stream) 146 | print('# =========================================================================== #', file=stream) 147 | pprint(dict(ssd_params._asdict()), stream=stream) 148 | 149 | print('\n# =========================================================================== #', file=stream) 150 | print('# Training | Evaluation dataset files:', file=stream) 151 | print('# =========================================================================== #', file=stream) 152 | data_files = parallel_reader.get_data_files(data_sources) 153 | pprint(sorted(data_files), stream=stream) 154 | print('', file=stream) 155 | 156 | print_config(None) 157 | # Save to a text file as well. 158 | if save_dir is not None: 159 | if not os.path.exists(save_dir): 160 | os.makedirs(save_dir) 161 | path = os.path.join(save_dir, 'training_config.txt') 162 | with open(path, "w") as out: 163 | print_config(out) 164 | ''' 165 | 166 | 167 | def configure_learning_rate(flags, num_samples_per_epoch, global_step): 168 | """Configures the learning rate. 169 | 170 | Args: 171 | num_samples_per_epoch: The number of samples in each epoch of training. 172 | global_step: The global_step tensor. 173 | Returns: 174 | A `Tensor` representing the learning rate. 175 | """ 176 | decay_steps = int(num_samples_per_epoch / flags.batch_size * 177 | flags.num_epochs_per_decay) 178 | 179 | if flags.learning_rate_decay_type == 'exponential': 180 | return tf.train.exponential_decay(flags.learning_rate, 181 | global_step, 182 | decay_steps, 183 | flags.learning_rate_decay_factor, 184 | staircase=True, 185 | name='exponential_decay_learning_rate') 186 | elif flags.learning_rate_decay_type == 'fixed': 187 | return tf.constant(flags.learning_rate, name='fixed_learning_rate') 188 | elif flags.learning_rate_decay_type == 'polynomial': 189 | return tf.train.polynomial_decay(flags.learning_rate, 190 | global_step, 191 | decay_steps, 192 | flags.end_learning_rate, 193 | power=1.0, 194 | cycle=False, 195 | name='polynomial_decay_learning_rate') 196 | else: 197 | raise ValueError('learning_rate_decay_type [%s] was not recognized', 198 | flags.learning_rate_decay_type) 199 | 200 | 201 | def configure_optimizer(flags, learning_rate): 202 | """Configures the optimizer used for training. 203 | 204 | Args: 205 | learning_rate: A scalar or `Tensor` learning rate. 206 | Returns: 207 | An instance of an optimizer. 
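Raises: ValueError: if `flags.optimizer` is not recognized.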
208 | """ 209 | if flags.optimizer == 'adadelta': 210 | optimizer = tf.train.AdadeltaOptimizer( 211 | learning_rate, 212 | rho=flags.adadelta_rho, 213 | epsilon=flags.opt_epsilon) 214 | elif flags.optimizer == 'adagrad': 215 | optimizer = tf.train.AdagradOptimizer( 216 | learning_rate, 217 | initial_accumulator_value=flags.adagrad_initial_accumulator_value) 218 | elif flags.optimizer == 'adam': 219 | optimizer = tf.train.AdamOptimizer( 220 | learning_rate, 221 | beta1=flags.adam_beta1, 222 | beta2=flags.adam_beta2, 223 | epsilon=flags.opt_epsilon) 224 | elif flags.optimizer == 'ftrl': 225 | optimizer = tf.train.FtrlOptimizer( 226 | learning_rate, 227 | learning_rate_power=flags.ftrl_learning_rate_power, 228 | initial_accumulator_value=flags.ftrl_initial_accumulator_value, 229 | l1_regularization_strength=flags.ftrl_l1, 230 | l2_regularization_strength=flags.ftrl_l2) 231 | elif flags.optimizer == 'momentum': 232 | optimizer = tf.train.MomentumOptimizer( 233 | learning_rate, 234 | momentum=flags.momentum, 235 | name='Momentum') 236 | elif flags.optimizer == 'rmsprop': 237 | optimizer = tf.train.RMSPropOptimizer( 238 | learning_rate, 239 | decay=flags.rmsprop_decay, 240 | momentum=flags.rmsprop_momentum, 241 | epsilon=flags.opt_epsilon) 242 | elif flags.optimizer == 'sgd': 243 | optimizer = tf.train.GradientDescentOptimizer(learning_rate) 244 | else: 245 | raise ValueError('Optimizer [%s] was not recognized', flags.optimizer) 246 | return optimizer 247 | 248 | 249 | def add_variables_summaries(learning_rate): 250 | summaries = [] 251 | for variable in slim.get_model_variables(): 252 | summaries.append(tf.summary.histogram(variable.op.name, variable)) 253 | summaries.append(tf.summary.scalar('training/Learning Rate', learning_rate)) 254 | return summaries 255 | 256 | 257 | def update_model_scope(var, ckpt_scope, new_scope): 258 | return var.op.name.replace(new_scope, 'vgg_16') 259 | 260 | 261 | def get_init_fn(flags): 262 | """Returns a function run by the chief worker to warm-start the training. 263 | Note that the init_fn is only run when initializing the model during the very 264 | first global step. 265 | 266 | Returns: 267 | An init function run by the supervisor. 268 | """ 269 | if flags.checkpoint_path is None: 270 | return None 271 | # Warn the user if a checkpoint exists in the train_dir. Then ignore. 272 | if tf.train.latest_checkpoint(flags.train_dir): 273 | tf.logging.info( 274 | 'Ignoring --checkpoint_path because a checkpoint already exists in %s' 275 | % flags.train_dir) 276 | return None 277 | 278 | exclusions = [] 279 | if flags.checkpoint_exclude_scopes: 280 | exclusions = [scope.strip() 281 | for scope in flags.checkpoint_exclude_scopes.split(',')] 282 | 283 | # TODO(sguada) variables.filter_variables() 284 | variables_to_restore = [] 285 | for var in slim.get_model_variables(): 286 | excluded = False 287 | for exclusion in exclusions: 288 | if var.op.name.startswith(exclusion): 289 | excluded = True 290 | break 291 | if not excluded: 292 | variables_to_restore.append(var) 293 | # Change model scope if necessary. 294 | if flags.checkpoint_model_scope is not None: 295 | variables_to_restore = \ 296 | {var.op.name.replace(flags.model_name, 297 | flags.checkpoint_model_scope): var 298 | for var in variables_to_restore} 299 | 300 | if tf.gfile.IsDirectory(flags.checkpoint_path): 301 | checkpoint_path = tf.train.latest_checkpoint(flags.checkpoint_path) 302 | else: 303 | checkpoint_path = flags.checkpoint_path 304 | tf.logging.info('Fine-tuning from %s. 
Ignoring missing vars: %s' % (checkpoint_path, flags.ignore_missing_vars)) 305 | 306 | return slim.assign_from_checkpoint_fn( 307 | checkpoint_path, 308 | variables_to_restore, 309 | ignore_missing_vars=flags.ignore_missing_vars) 310 | 311 | 312 | def get_variables_to_train(flags): 313 | """Returns a list of variables to train. 314 | 315 | Returns: 316 | A list of variables to train by the optimizer. 317 | """ 318 | if flags.trainable_scopes is None: 319 | return tf.trainable_variables() 320 | else: 321 | scopes = [scope.strip() for scope in flags.trainable_scopes.split(',')] 322 | 323 | variables_to_train = [] 324 | for scope in scopes: 325 | variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope) 326 | variables_to_train.extend(variables) 327 | return variables_to_train 328 | 329 | # =========================================================================== # 330 | # Evaluation utils. 331 | # =========================================================================== # 332 | -------------------------------------------------------------------------------- /src/deployment/model_deploy_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
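# To run these tests from the repository's src directory (assumed layout): python -m deployment.model_deploy_test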
14 | # ============================================================================== 15 | """Tests for model_deploy.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from deployment import model_deploy 25 | 26 | slim = tf.contrib.slim 27 | 28 | 29 | class DeploymentConfigTest(tf.test.TestCase): 30 | 31 | def testDefaults(self): 32 | deploy_config = model_deploy.DeploymentConfig() 33 | 34 | self.assertEqual(slim.get_variables(), []) 35 | self.assertEqual(deploy_config.caching_device(), None) 36 | self.assertDeviceEqual(deploy_config.clone_device(0), '') 37 | self.assertEqual(deploy_config.clone_scope(0), '') 38 | self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0') 39 | self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0') 40 | self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0') 41 | 42 | def testCPUonly(self): 43 | deploy_config = model_deploy.DeploymentConfig(clone_on_cpu=True) 44 | 45 | self.assertEqual(deploy_config.caching_device(), None) 46 | self.assertDeviceEqual(deploy_config.clone_device(0), 'CPU:0') 47 | self.assertEqual(deploy_config.clone_scope(0), '') 48 | self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0') 49 | self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0') 50 | self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0') 51 | 52 | def testMultiGPU(self): 53 | deploy_config = model_deploy.DeploymentConfig(num_clones=2) 54 | 55 | self.assertEqual(deploy_config.caching_device(), None) 56 | self.assertDeviceEqual(deploy_config.clone_device(0), 'GPU:0') 57 | self.assertDeviceEqual(deploy_config.clone_device(1), 'GPU:1') 58 | self.assertEqual(deploy_config.clone_scope(0), 'clone_0') 59 | self.assertEqual(deploy_config.clone_scope(1), 'clone_1') 60 | self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0') 61 | self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0') 62 | self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0') 63 | 64 | def testPS(self): 65 | deploy_config = model_deploy.DeploymentConfig(num_clones=1, num_ps_tasks=1) 66 | 67 | self.assertDeviceEqual(deploy_config.clone_device(0), 68 | '/job:worker') 69 | self.assertEqual(deploy_config.clone_scope(0), '') 70 | self.assertDeviceEqual(deploy_config.optimizer_device(), 71 | '/job:worker/device:CPU:0') 72 | self.assertDeviceEqual(deploy_config.inputs_device(), 73 | '/job:worker/device:CPU:0') 74 | with tf.device(deploy_config.variables_device()): 75 | a = tf.Variable(0) 76 | b = tf.Variable(0) 77 | c = tf.no_op() 78 | d = slim.variable('a', [], 79 | caching_device=deploy_config.caching_device()) 80 | self.assertDeviceEqual(a.device, '/job:ps/task:0/device:CPU:0') 81 | self.assertDeviceEqual(a.device, a.value().device) 82 | self.assertDeviceEqual(b.device, '/job:ps/task:0/device:CPU:0') 83 | self.assertDeviceEqual(b.device, b.value().device) 84 | self.assertDeviceEqual(c.device, '') 85 | self.assertDeviceEqual(d.device, '/job:ps/task:0/device:CPU:0') 86 | self.assertDeviceEqual(d.value().device, '') 87 | 88 | def testMultiGPUPS(self): 89 | deploy_config = model_deploy.DeploymentConfig(num_clones=2, num_ps_tasks=1) 90 | 91 | self.assertEqual(deploy_config.caching_device()(tf.no_op()), '') 92 | self.assertDeviceEqual(deploy_config.clone_device(0), 93 | '/job:worker/device:GPU:0') 94 | self.assertDeviceEqual(deploy_config.clone_device(1), 95 | '/job:worker/device:GPU:1') 96 | 
self.assertEqual(deploy_config.clone_scope(0), 'clone_0') 97 | self.assertEqual(deploy_config.clone_scope(1), 'clone_1') 98 | self.assertDeviceEqual(deploy_config.optimizer_device(), 99 | '/job:worker/device:CPU:0') 100 | self.assertDeviceEqual(deploy_config.inputs_device(), 101 | '/job:worker/device:CPU:0') 102 | 103 | def testReplicasPS(self): 104 | deploy_config = model_deploy.DeploymentConfig(num_replicas=2, 105 | num_ps_tasks=2) 106 | 107 | self.assertDeviceEqual(deploy_config.clone_device(0), 108 | '/job:worker') 109 | self.assertEqual(deploy_config.clone_scope(0), '') 110 | self.assertDeviceEqual(deploy_config.optimizer_device(), 111 | '/job:worker/device:CPU:0') 112 | self.assertDeviceEqual(deploy_config.inputs_device(), 113 | '/job:worker/device:CPU:0') 114 | 115 | def testReplicasMultiGPUPS(self): 116 | deploy_config = model_deploy.DeploymentConfig(num_replicas=2, 117 | num_clones=2, 118 | num_ps_tasks=2) 119 | self.assertDeviceEqual(deploy_config.clone_device(0), 120 | '/job:worker/device:GPU:0') 121 | self.assertDeviceEqual(deploy_config.clone_device(1), 122 | '/job:worker/device:GPU:1') 123 | self.assertEqual(deploy_config.clone_scope(0), 'clone_0') 124 | self.assertEqual(deploy_config.clone_scope(1), 'clone_1') 125 | self.assertDeviceEqual(deploy_config.optimizer_device(), 126 | '/job:worker/device:CPU:0') 127 | self.assertDeviceEqual(deploy_config.inputs_device(), 128 | '/job:worker/device:CPU:0') 129 | 130 | def testVariablesPS(self): 131 | deploy_config = model_deploy.DeploymentConfig(num_ps_tasks=2) 132 | 133 | with tf.device(deploy_config.variables_device()): 134 | a = tf.Variable(0) 135 | b = tf.Variable(0) 136 | c = tf.no_op() 137 | d = slim.variable('a', [], 138 | caching_device=deploy_config.caching_device()) 139 | 140 | self.assertDeviceEqual(a.device, '/job:ps/task:0/device:CPU:0') 141 | self.assertDeviceEqual(a.device, a.value().device) 142 | self.assertDeviceEqual(b.device, '/job:ps/task:1/device:CPU:0') 143 | self.assertDeviceEqual(b.device, b.value().device) 144 | self.assertDeviceEqual(c.device, '') 145 | self.assertDeviceEqual(d.device, '/job:ps/task:0/device:CPU:0') 146 | self.assertDeviceEqual(d.value().device, '') 147 | 148 | 149 | def LogisticClassifier(inputs, labels, scope=None, reuse=None): 150 | with tf.variable_scope(scope, 'LogisticClassifier', [inputs, labels], 151 | reuse=reuse): 152 | predictions = slim.fully_connected(inputs, 1, activation_fn=tf.sigmoid, 153 | scope='fully_connected') 154 | slim.losses.log_loss(predictions, labels) 155 | return predictions 156 | 157 | 158 | def BatchNormClassifier(inputs, labels, scope=None, reuse=None): 159 | with tf.variable_scope(scope, 'BatchNormClassifier', [inputs, labels], 160 | reuse=reuse): 161 | inputs = slim.batch_norm(inputs, decay=0.1) 162 | predictions = slim.fully_connected(inputs, 1, 163 | activation_fn=tf.sigmoid, 164 | scope='fully_connected') 165 | slim.losses.log_loss(predictions, labels) 166 | return predictions 167 | 168 | 169 | class CreatecloneTest(tf.test.TestCase): 170 | 171 | def setUp(self): 172 | # Create an easy training set: 173 | np.random.seed(0) 174 | 175 | self._inputs = np.zeros((16, 4)) 176 | self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32) 177 | self._logdir = self.get_temp_dir() 178 | 179 | for i in range(16): 180 | j = int(2 * self._labels[i] + np.random.randint(0, 2)) 181 | self._inputs[i, j] = 1 182 | 183 | def testCreateLogisticClassifier(self): 184 | g = tf.Graph() 185 | with g.as_default(): 186 | tf.set_random_seed(0) 187 | tf_inputs = 
tf.constant(self._inputs, dtype=tf.float32) 188 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 189 | 190 | model_fn = LogisticClassifier 191 | clone_args = (tf_inputs, tf_labels) 192 | deploy_config = model_deploy.DeploymentConfig(num_clones=1) 193 | 194 | self.assertEqual(slim.get_variables(), []) 195 | clones = model_deploy.create_clones(deploy_config, model_fn, clone_args) 196 | clone = clones[0] 197 | self.assertEqual(len(slim.get_variables()), 2) 198 | for v in slim.get_variables(): 199 | self.assertDeviceEqual(v.device, 'CPU:0') 200 | self.assertDeviceEqual(v.value().device, 'CPU:0') 201 | self.assertEqual(clone.outputs.op.name, 202 | 'LogisticClassifier/fully_connected/Sigmoid') 203 | self.assertEqual(clone.scope, '') 204 | self.assertDeviceEqual(clone.device, '') 205 | self.assertEqual(len(slim.losses.get_losses()), 1) 206 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 207 | self.assertEqual(update_ops, []) 208 | 209 | def testCreateSingleclone(self): 210 | g = tf.Graph() 211 | with g.as_default(): 212 | tf.set_random_seed(0) 213 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 214 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 215 | 216 | model_fn = BatchNormClassifier 217 | clone_args = (tf_inputs, tf_labels) 218 | deploy_config = model_deploy.DeploymentConfig(num_clones=1) 219 | 220 | self.assertEqual(slim.get_variables(), []) 221 | clones = model_deploy.create_clones(deploy_config, model_fn, clone_args) 222 | clone = clones[0] 223 | self.assertEqual(len(slim.get_variables()), 5) 224 | for v in slim.get_variables(): 225 | self.assertDeviceEqual(v.device, 'CPU:0') 226 | self.assertDeviceEqual(v.value().device, 'CPU:0') 227 | self.assertEqual(clone.outputs.op.name, 228 | 'BatchNormClassifier/fully_connected/Sigmoid') 229 | self.assertEqual(clone.scope, '') 230 | self.assertDeviceEqual(clone.device, '') 231 | self.assertEqual(len(slim.losses.get_losses()), 1) 232 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 233 | self.assertEqual(len(update_ops), 2) 234 | 235 | def testCreateMulticlone(self): 236 | g = tf.Graph() 237 | with g.as_default(): 238 | tf.set_random_seed(0) 239 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 240 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 241 | 242 | model_fn = BatchNormClassifier 243 | clone_args = (tf_inputs, tf_labels) 244 | num_clones = 4 245 | deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones) 246 | 247 | self.assertEqual(slim.get_variables(), []) 248 | clones = model_deploy.create_clones(deploy_config, model_fn, clone_args) 249 | self.assertEqual(len(slim.get_variables()), 5) 250 | for v in slim.get_variables(): 251 | self.assertDeviceEqual(v.device, 'CPU:0') 252 | self.assertDeviceEqual(v.value().device, 'CPU:0') 253 | self.assertEqual(len(clones), num_clones) 254 | for i, clone in enumerate(clones): 255 | self.assertEqual( 256 | clone.outputs.op.name, 257 | 'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % i) 258 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope) 259 | self.assertEqual(len(update_ops), 2) 260 | self.assertEqual(clone.scope, 'clone_%d/' % i) 261 | self.assertDeviceEqual(clone.device, 'GPU:%d' % i) 262 | 263 | def testCreateOnecloneWithPS(self): 264 | g = tf.Graph() 265 | with g.as_default(): 266 | tf.set_random_seed(0) 267 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 268 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 269 | 270 | model_fn = BatchNormClassifier 271 | clone_args = 
(tf_inputs, tf_labels) 272 | deploy_config = model_deploy.DeploymentConfig(num_clones=1, 273 | num_ps_tasks=1) 274 | 275 | self.assertEqual(slim.get_variables(), []) 276 | clones = model_deploy.create_clones(deploy_config, model_fn, clone_args) 277 | self.assertEqual(len(clones), 1) 278 | clone = clones[0] 279 | self.assertEqual(clone.outputs.op.name, 280 | 'BatchNormClassifier/fully_connected/Sigmoid') 281 | self.assertDeviceEqual(clone.device, '/job:worker') 282 | self.assertEqual(clone.scope, '') 283 | self.assertEqual(len(slim.get_variables()), 5) 284 | for v in slim.get_variables(): 285 | self.assertDeviceEqual(v.device, '/job:ps/task:0/CPU:0') 286 | self.assertDeviceEqual(v.device, v.value().device) 287 | 288 | def testCreateMulticloneWithPS(self): 289 | g = tf.Graph() 290 | with g.as_default(): 291 | tf.set_random_seed(0) 292 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 293 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 294 | 295 | model_fn = BatchNormClassifier 296 | clone_args = (tf_inputs, tf_labels) 297 | deploy_config = model_deploy.DeploymentConfig(num_clones=2, 298 | num_ps_tasks=2) 299 | 300 | self.assertEqual(slim.get_variables(), []) 301 | clones = model_deploy.create_clones(deploy_config, model_fn, clone_args) 302 | self.assertEqual(len(slim.get_variables()), 5) 303 | for i, v in enumerate(slim.get_variables()): 304 | t = i % 2 305 | self.assertDeviceEqual(v.device, '/job:ps/task:%d/device:CPU:0' % t) 306 | self.assertDeviceEqual(v.device, v.value().device) 307 | self.assertEqual(len(clones), 2) 308 | for i, clone in enumerate(clones): 309 | self.assertEqual( 310 | clone.outputs.op.name, 311 | 'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % i) 312 | self.assertEqual(clone.scope, 'clone_%d/' % i) 313 | self.assertDeviceEqual(clone.device, '/job:worker/device:GPU:%d' % i) 314 | 315 | 316 | class OptimizeclonesTest(tf.test.TestCase): 317 | 318 | def setUp(self): 319 | # Create an easy training set: 320 | np.random.seed(0) 321 | 322 | self._inputs = np.zeros((16, 4)) 323 | self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32) 324 | self._logdir = self.get_temp_dir() 325 | 326 | for i in range(16): 327 | j = int(2 * self._labels[i] + np.random.randint(0, 2)) 328 | self._inputs[i, j] = 1 329 | 330 | def testCreateLogisticClassifier(self): 331 | g = tf.Graph() 332 | with g.as_default(): 333 | tf.set_random_seed(0) 334 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 335 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 336 | 337 | model_fn = LogisticClassifier 338 | clone_args = (tf_inputs, tf_labels) 339 | deploy_config = model_deploy.DeploymentConfig(num_clones=1) 340 | 341 | self.assertEqual(slim.get_variables(), []) 342 | clones = model_deploy.create_clones(deploy_config, model_fn, clone_args) 343 | self.assertEqual(len(slim.get_variables()), 2) 344 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 345 | self.assertEqual(update_ops, []) 346 | 347 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) 348 | total_loss, grads_and_vars = model_deploy.optimize_clones(clones, 349 | optimizer) 350 | self.assertEqual(len(grads_and_vars), len(tf.trainable_variables())) 351 | self.assertEqual(total_loss.op.name, 'total_loss') 352 | for g, v in grads_and_vars: 353 | self.assertDeviceEqual(g.device, '') 354 | self.assertDeviceEqual(v.device, 'CPU:0') 355 | 356 | def testCreateSingleclone(self): 357 | g = tf.Graph() 358 | with g.as_default(): 359 | tf.set_random_seed(0) 360 | tf_inputs = 
tf.constant(self._inputs, dtype=tf.float32) 361 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 362 | 363 | model_fn = BatchNormClassifier 364 | clone_args = (tf_inputs, tf_labels) 365 | deploy_config = model_deploy.DeploymentConfig(num_clones=1) 366 | 367 | self.assertEqual(slim.get_variables(), []) 368 | clones = model_deploy.create_clones(deploy_config, model_fn, clone_args) 369 | self.assertEqual(len(slim.get_variables()), 5) 370 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 371 | self.assertEqual(len(update_ops), 2) 372 | 373 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) 374 | total_loss, grads_and_vars = model_deploy.optimize_clones(clones, 375 | optimizer) 376 | self.assertEqual(len(grads_and_vars), len(tf.trainable_variables())) 377 | self.assertEqual(total_loss.op.name, 'total_loss') 378 | for g, v in grads_and_vars: 379 | self.assertDeviceEqual(g.device, '') 380 | self.assertDeviceEqual(v.device, 'CPU:0') 381 | 382 | def testCreateMulticlone(self): 383 | g = tf.Graph() 384 | with g.as_default(): 385 | tf.set_random_seed(0) 386 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 387 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 388 | 389 | model_fn = BatchNormClassifier 390 | clone_args = (tf_inputs, tf_labels) 391 | num_clones = 4 392 | deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones) 393 | 394 | self.assertEqual(slim.get_variables(), []) 395 | clones = model_deploy.create_clones(deploy_config, model_fn, clone_args) 396 | self.assertEqual(len(slim.get_variables()), 5) 397 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 398 | self.assertEqual(len(update_ops), num_clones * 2) 399 | 400 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) 401 | total_loss, grads_and_vars = model_deploy.optimize_clones(clones, 402 | optimizer) 403 | self.assertEqual(len(grads_and_vars), len(tf.trainable_variables())) 404 | self.assertEqual(total_loss.op.name, 'total_loss') 405 | for g, v in grads_and_vars: 406 | self.assertDeviceEqual(g.device, '') 407 | self.assertDeviceEqual(v.device, 'CPU:0') 408 | 409 | def testCreateMulticloneCPU(self): 410 | g = tf.Graph() 411 | with g.as_default(): 412 | tf.set_random_seed(0) 413 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 414 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 415 | 416 | model_fn = BatchNormClassifier 417 | model_args = (tf_inputs, tf_labels) 418 | num_clones = 4 419 | deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones, 420 | clone_on_cpu=True) 421 | 422 | self.assertEqual(slim.get_variables(), []) 423 | clones = model_deploy.create_clones(deploy_config, model_fn, model_args) 424 | self.assertEqual(len(slim.get_variables()), 5) 425 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 426 | self.assertEqual(len(update_ops), num_clones * 2) 427 | 428 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) 429 | total_loss, grads_and_vars = model_deploy.optimize_clones(clones, 430 | optimizer) 431 | self.assertEqual(len(grads_and_vars), len(tf.trainable_variables())) 432 | self.assertEqual(total_loss.op.name, 'total_loss') 433 | for g, v in grads_and_vars: 434 | self.assertDeviceEqual(g.device, '') 435 | self.assertDeviceEqual(v.device, 'CPU:0') 436 | 437 | def testCreateOnecloneWithPS(self): 438 | g = tf.Graph() 439 | with g.as_default(): 440 | tf.set_random_seed(0) 441 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 442 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 
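      # With a parameter server configured, gradients are computed on the
      # worker while the variables they update live on /job:ps, which is
      # what the device assertions at the end of this test verify.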
443 | 444 | model_fn = BatchNormClassifier 445 | model_args = (tf_inputs, tf_labels) 446 | deploy_config = model_deploy.DeploymentConfig(num_clones=1, 447 | num_ps_tasks=1) 448 | 449 | self.assertEqual(slim.get_variables(), []) 450 | clones = model_deploy.create_clones(deploy_config, model_fn, model_args) 451 | self.assertEqual(len(slim.get_variables()), 5) 452 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 453 | self.assertEqual(len(update_ops), 2) 454 | 455 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) 456 | total_loss, grads_and_vars = model_deploy.optimize_clones(clones, 457 | optimizer) 458 | self.assertEqual(len(grads_and_vars), len(tf.trainable_variables())) 459 | self.assertEqual(total_loss.op.name, 'total_loss') 460 | for g, v in grads_and_vars: 461 | self.assertDeviceEqual(g.device, '/job:worker') 462 | self.assertDeviceEqual(v.device, '/job:ps/task:0/CPU:0') 463 | 464 | 465 | class DeployTest(tf.test.TestCase): 466 | 467 | def setUp(self): 468 | # Create an easy training set: 469 | np.random.seed(0) 470 | 471 | self._inputs = np.zeros((16, 4)) 472 | self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32) 473 | self._logdir = self.get_temp_dir() 474 | 475 | for i in range(16): 476 | j = int(2 * self._labels[i] + np.random.randint(0, 2)) 477 | self._inputs[i, j] = 1 478 | 479 | def testLocalTrainOp(self): 480 | g = tf.Graph() 481 | with g.as_default(): 482 | tf.set_random_seed(0) 483 | tf_inputs = tf.constant(self._inputs, dtype=tf.float32) 484 | tf_labels = tf.constant(self._labels, dtype=tf.float32) 485 | 486 | model_fn = BatchNormClassifier 487 | model_args = (tf_inputs, tf_labels) 488 | deploy_config = model_deploy.DeploymentConfig(num_clones=2, 489 | clone_on_cpu=True) 490 | 491 | optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) 492 | 493 | self.assertEqual(slim.get_variables(), []) 494 | model = model_deploy.deploy(deploy_config, model_fn, model_args, 495 | optimizer=optimizer) 496 | 497 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 498 | self.assertEqual(len(update_ops), 4) 499 | self.assertEqual(len(model.clones), 2) 500 | self.assertEqual(model.total_loss.op.name, 'total_loss') 501 | self.assertEqual(model.summary_op.op.name, 'summary_op/summary_op') 502 | self.assertEqual(model.train_op.op.name, 'train_op') 503 | 504 | with tf.Session() as sess: 505 | sess.run(tf.global_variables_initializer()) 506 | moving_mean = tf.contrib.framework.get_variables_by_name( 507 | 'moving_mean')[0] 508 | moving_variance = tf.contrib.framework.get_variables_by_name( 509 | 'moving_variance')[0] 510 | initial_loss = sess.run(model.total_loss) 511 | initial_mean, initial_variance = sess.run([moving_mean, 512 | moving_variance]) 513 | self.assertAllClose(initial_mean, [0.0, 0.0, 0.0, 0.0]) 514 | self.assertAllClose(initial_variance, [1.0, 1.0, 1.0, 1.0]) 515 | for _ in range(10): 516 | sess.run(model.train_op) 517 | final_loss = sess.run(model.total_loss) 518 | self.assertLess(final_loss, initial_loss / 10.0) 519 | 520 | final_mean, final_variance = sess.run([moving_mean, 521 | moving_variance]) 522 | self.assertAllClose(final_mean, [0.125, 0.25, 0.375, 0.25]) 523 | self.assertAllClose(final_variance, [0.109375, 0.1875, 524 | 0.234375, 0.1875]) 525 | 526 | def testNoSummariesOnGPU(self): 527 | with tf.Graph().as_default(): 528 | deploy_config = model_deploy.DeploymentConfig(num_clones=2) 529 | 530 | # clone function creates a fully_connected layer with a regularizer loss. 
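      # The summary op is built under config.optimizer_device(), which is
      # '/device:CPU:0' here, so no summary input should land on a GPU even
      # though the two clones themselves are placed on GPUs.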
531 | def ModelFn(): 532 | inputs = tf.constant(1.0, shape=(10, 20), dtype=tf.float32) 533 | reg = tf.contrib.layers.l2_regularizer(0.001) 534 | tf.contrib.layers.fully_connected(inputs, 30, weights_regularizer=reg) 535 | 536 | model = model_deploy.deploy( 537 | deploy_config, ModelFn, 538 | optimizer=tf.train.GradientDescentOptimizer(1.0)) 539 | # The model summary op should have a few summary inputs and all of them 540 | # should be on the CPU. 541 | self.assertTrue(model.summary_op.op.inputs) 542 | for inp in model.summary_op.op.inputs: 543 | self.assertEqual('/device:CPU:0', inp.device) 544 | 545 | def testNoSummariesOnGPUForEvals(self): 546 | with tf.Graph().as_default(): 547 | deploy_config = model_deploy.DeploymentConfig(num_clones=2) 548 | 549 | # clone function creates a fully_connected layer with a regularizer loss. 550 | def ModelFn(): 551 | inputs = tf.constant(1.0, shape=(10, 20), dtype=tf.float32) 552 | reg = tf.contrib.layers.l2_regularizer(0.001) 553 | tf.contrib.layers.fully_connected(inputs, 30, weights_regularizer=reg) 554 | 555 | # No optimizer here, it's an eval. 556 | model = model_deploy.deploy(deploy_config, ModelFn) 557 | # The model summary op should have a few summary inputs and all of them 558 | # should be on the CPU. 559 | self.assertTrue(model.summary_op.op.inputs) 560 | for inp in model.summary_op.op.inputs: 561 | self.assertEqual('/device:CPU:0', inp.device) 562 | 563 | 564 | if __name__ == '__main__': 565 | tf.test.main() 566 | -------------------------------------------------------------------------------- /src/deployment/model_deploy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Deploy Slim models across multiple clones and replicas. 16 | 17 | # TODO(sguada) docstring paragraph by (a) motivating the need for the file and 18 | # (b) defining clones. 19 | 20 | # TODO(sguada) describe the high-level components of model deployment. 21 | # E.g. "each model deployment is composed of several parts: a DeploymentConfig, 22 | # which captures A, B and C, an input_fn which loads data.. etc 23 | 24 | To easily train a model on multiple GPUs or across multiple machines this 25 | module provides a set of helper functions: `create_clones`, 26 | `optimize_clones` and `deploy`. 27 | 28 | Usage: 29 | 30 | g = tf.Graph() 31 | 32 | # Set up DeploymentConfig 33 | config = model_deploy.DeploymentConfig(num_clones=2, clone_on_cpu=True) 34 | 35 | # Create the global step on the device storing the variables. 36 | with tf.device(config.variables_device()): 37 | global_step = slim.create_global_step() 38 | 39 | # Define the inputs 40 | with tf.device(config.inputs_device()): 41 | images, labels = LoadData(...) 
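  # Queue the inputs so each call to model_fn can dequeue its own batch.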
42 | inputs_queue = slim.data.prefetch_queue((images, labels)) 43 | 44 | # Define the optimizer. 45 | with tf.device(config.optimizer_device()): 46 | optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum) 47 | 48 | # Define the model including the loss. 49 | def model_fn(inputs_queue): 50 | images, labels = inputs_queue.dequeue() 51 | predictions = CreateNetwork(images) 52 | slim.losses.log_loss(predictions, labels) 53 | 54 | model_dp = model_deploy.deploy(config, model_fn, [inputs_queue], 55 | optimizer=optimizer) 56 | 57 | # Run training. 58 | slim.learning.train(model_dp.train_op, my_log_dir, 59 | summary_op=model_dp.summary_op) 60 | 61 | The Clone namedtuple holds together the values associated with each call to 62 | model_fn: 63 | * outputs: The return values of the calls to `model_fn()`. 64 | * scope: The scope used to create the clone. 65 | * device: The device used to create the clone. 66 | 67 | The DeployedModel namedtuple holds together the values needed to train multiple 68 | clones: 69 | * train_op: An operation that runs the optimizer training op and includes 70 | all the update ops created by `model_fn`. Present only if an optimizer 71 | was specified. 72 | * summary_op: An operation that runs the summaries created by `model_fn` 73 | and process_gradients. 74 | * total_loss: A `Tensor` that contains the sum of all losses created by 75 | `model_fn` plus the regularization losses. 76 | * clones: List of `Clone` tuples returned by `create_clones()`. 77 | 78 | DeploymentConfig parameters: 79 | * num_clones: Number of model clones to deploy in each replica. 80 | * clone_on_cpu: True if clones should be placed on CPU. 81 | * replica_id: Integer. Index of the replica for which the model is 82 | deployed. Usually 0 for the chief replica. 83 | * num_replicas: Number of replicas to use. 84 | * num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas. 85 | * worker_job_name: A name for the worker job. 86 | * ps_job_name: A name for the parameter server job. 87 | 88 | TODO(sguada): 89 | - describe side effects to the graph. 90 | - what happens to summaries and update_ops. 91 | - which graph collections are altered. 92 | - write a tutorial on how to use this. 93 | - analyze the possibility of calling deploy more than once. 94 | 95 | 96 | """ 97 | 98 | from __future__ import absolute_import 99 | from __future__ import division 100 | from __future__ import print_function 101 | 102 | import collections 103 | 104 | import tensorflow as tf 105 | 106 | from tensorflow.python.ops import control_flow_ops 107 | 108 | slim = tf.contrib.slim 109 | 110 | 111 | __all__ = ['create_clones', 112 | 'deploy', 113 | 'optimize_clones', 114 | 'DeployedModel', 115 | 'DeploymentConfig', 116 | 'Clone', 117 | ] 118 | 119 | 120 | # Namedtuple used to represent a clone during deployment. 121 | Clone = collections.namedtuple('Clone', 122 | ['outputs', # Whatever model_fn() returned. 123 | 'scope', # The scope used to create it. 124 | 'device', # The device used to create it. 125 | ]) 126 | 127 | # Namedtuple used to represent a DeployedModel, returned by deploy(). 128 | DeployedModel = collections.namedtuple('DeployedModel', 129 | ['train_op', # The `train_op` 130 | 'summary_op', # The `summary_op` 131 | 'total_loss', # The loss `Tensor` 132 | 'clones', # A list of `Clone` tuples.
133 | ]) 134 | 135 | # Default parameters for DeploymentConfig 136 | _deployment_params = {'num_clones': 1, 137 | 'clone_on_cpu': False, 138 | 'replica_id': 0, 139 | 'num_replicas': 1, 140 | 'num_ps_tasks': 0, 141 | 'worker_job_name': 'worker', 142 | 'ps_job_name': 'ps'} 143 | 144 | 145 | def create_clones(config, model_fn, args=None, kwargs=None): 146 | """Creates multiple clones according to config using a `model_fn`. 147 | 148 | The returned values of `model_fn(*args, **kwargs)` are collected along with 149 | the scope and device used to create them in a namedtuple 150 | `Clone(outputs, scope, device)`. 151 | 152 | Note: it is assumed that any loss created by `model_fn` is collected in 153 | the tf.GraphKeys.LOSSES collection. 154 | 155 | To recover the losses, summaries or update_ops created by the clone, use: 156 | ```python 157 | losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope) 158 | summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, clone.scope) 159 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope) 160 | ``` 161 | 162 | The deployment options are specified by the config object and support 163 | deploying one or several clones on different GPUs and one or several replicas 164 | of such clones. 165 | 166 | The argument `model_fn` is called `config.num_clones` times to create the 167 | model clones as `model_fn(*args, **kwargs)`. 168 | 169 | If `config` specifies deployment on multiple replicas then the default 170 | tensorflow device is set appropriately for each call to `model_fn` and for the 171 | slim variable creation functions: model and global variables will be created 172 | on the `ps` device, the clone operations will be on the `worker` device. 173 | 174 | Args: 175 | config: A DeploymentConfig object. 176 | model_fn: A callable. Called as `model_fn(*args, **kwargs)`. 177 | args: Optional list of arguments to pass to `model_fn`. 178 | kwargs: Optional dict of keyword arguments to pass to `model_fn`. 179 | 180 | Returns: 181 | A list of namedtuples `Clone`. 182 | """ 183 | clones = [] 184 | args = args or [] 185 | kwargs = kwargs or {} 186 | with slim.arg_scope([slim.model_variable, slim.variable], 187 | device=config.variables_device()): 188 | # Create clones. 189 | for i in range(0, config.num_clones): 190 | with tf.name_scope(config.clone_scope(i)) as clone_scope: 191 | clone_device = config.clone_device(i) 192 | with tf.device(clone_device): 193 | with tf.variable_scope(tf.get_variable_scope(), 194 | reuse=True if i > 0 else None): 195 | outputs = model_fn(*args, **kwargs) 196 | clones.append(Clone(outputs, clone_scope, clone_device)) 197 | return clones 198 | 199 | 200 | def _gather_clone_loss(clone, num_clones, regularization_losses): 201 | """Gather the loss for a single clone. 202 | 203 | Args: 204 | clone: A Clone namedtuple. 205 | num_clones: The number of clones being deployed. 206 | regularization_losses: Possibly empty list of regularization_losses 207 | to add to the clone losses. 208 | 209 | Returns: 210 | A tensor for the total loss for the clone. Can be None. 211 | """ 212 | # The return value. 213 | sum_loss = None 214 | # Individual components of the loss that will need summaries. 215 | clone_loss = None 216 | regularization_loss = None 217 | # Compute and aggregate losses on the clone device.
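  # Each clone loss is scaled by 1 / num_clones below, so summing the
  # per-clone losses in optimize_clones() yields an average; the
  # regularization losses are added only once (callers pass them for the
  # first clone and None afterwards).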
218 | with tf.device(clone.device): 219 | all_losses = [] 220 | clone_losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope) 221 | if clone_losses: 222 | clone_loss = tf.add_n(clone_losses, name='clone_loss') 223 | if num_clones > 1: 224 | clone_loss = tf.div(clone_loss, 1.0 * num_clones, 225 | name='scaled_clone_loss') 226 | all_losses.append(clone_loss) 227 | if regularization_losses: 228 | regularization_loss = tf.add_n(regularization_losses, 229 | name='regularization_loss') 230 | all_losses.append(regularization_loss) 231 | if all_losses: 232 | sum_loss = tf.add_n(all_losses) 233 | # Add the summaries out of the clone device block. 234 | if clone_loss is not None: 235 | tf.summary.scalar(clone.scope + '/clone_loss', clone_loss) 236 | if regularization_loss is not None: 237 | tf.summary.scalar('regularization_loss', regularization_loss) 238 | return sum_loss 239 | 240 | 241 | def _optimize_clone(optimizer, clone, num_clones, regularization_losses, 242 | **kwargs): 243 | """Compute losses and gradients for a single clone. 244 | 245 | Args: 246 | optimizer: A tf.Optimizer object. 247 | clone: A Clone namedtuple. 248 | num_clones: The number of clones being deployed. 249 | regularization_losses: Possibly empty list of regularization_losses 250 | to add to the clone losses. 251 | **kwargs: Dict of kwargs to pass to compute_gradients(). 252 | 253 | Returns: 254 | A tuple (clone_loss, clone_grads_and_vars). 255 | - clone_loss: A tensor for the total loss for the clone. Can be None. 256 | - clone_grads_and_vars: List of (gradient, variable) for the clone. 257 | Can be empty. 258 | """ 259 | sum_loss = _gather_clone_loss(clone, num_clones, regularization_losses) 260 | clone_grad = None 261 | if sum_loss is not None: 262 | with tf.device(clone.device): 263 | clone_grad = optimizer.compute_gradients(sum_loss, **kwargs) 264 | return sum_loss, clone_grad 265 | 266 | 267 | def optimize_clones(clones, optimizer, 268 | regularization_losses=None, 269 | **kwargs): 270 | """Compute clone losses and gradients for the given list of `Clones`. 271 | 272 | Note: The regularization_losses are added to the first clone's losses. 273 | 274 | Args: 275 | clones: List of `Clones` created by `create_clones()`. 276 | optimizer: An `Optimizer` object. 277 | regularization_losses: Optional list of regularization losses. If None, it 278 | will gather them from tf.GraphKeys.REGULARIZATION_LOSSES. Pass `[]` to 279 | exclude them. 280 | **kwargs: Optional dict of keyword arguments to pass to `compute_gradients`. 281 | 282 | Returns: 283 | A tuple (total_loss, grads_and_vars). 284 | - total_loss: A Tensor containing the average of the clone losses including 285 | the regularization loss. 286 | - grads_and_vars: A List of tuples (gradient, variable) containing the sum 287 | of the gradients for each variable. 288 | 289 | """ 290 | grads_and_vars = [] 291 | clones_losses = [] 292 | num_clones = len(clones) 293 | if regularization_losses is None: 294 | regularization_losses = tf.get_collection( 295 | tf.GraphKeys.REGULARIZATION_LOSSES) 296 | for clone in clones: 297 | with tf.name_scope(clone.scope): 298 | clone_loss, clone_grad = _optimize_clone( 299 | optimizer, clone, num_clones, regularization_losses, **kwargs) 300 | if clone_loss is not None: 301 | clones_losses.append(clone_loss) 302 | grads_and_vars.append(clone_grad) 303 | # Only use regularization_losses for the first clone. 304 | regularization_losses = None 305 | # Compute the total_loss summing all the clones_losses.
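  # With two clones of loss L0 and L1 this is L0/2 + L1/2 plus the
  # regularization loss that was folded into clone 0, i.e. the average
  # clone loss plus regularization.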
306 | total_loss = tf.add_n(clones_losses, name='total_loss') 307 | # Sum the gradients across clones. 308 | grads_and_vars = _sum_clones_gradients(grads_and_vars) 309 | return total_loss, grads_and_vars 310 | 311 | 312 | def deploy(config, 313 | model_fn, 314 | args=None, 315 | kwargs=None, 316 | optimizer=None, 317 | summarize_gradients=False): 318 | """Deploys a Slim-constructed model across multiple clones. 319 | 320 | The deployment options are specified by the config object and support 321 | deploying one or several clones on different GPUs and one or several replicas 322 | of such clones. 323 | 324 | The argument `model_fn` is called `config.num_clones` times to create the 325 | model clones as `model_fn(*args, **kwargs)`. 326 | 327 | The optional argument `optimizer` is an `Optimizer` object. If not `None`, 328 | the deployed model is configured for training with that optimizer. 329 | 330 | If `config` specifies deployment on multiple replicas then the default 331 | tensorflow device is set appropriately for each call to `model_fn` and for the 332 | slim variable creation functions: model and global variables will be created 333 | on the `ps` device, the clone operations will be on the `worker` device. 334 | 335 | Args: 336 | config: A `DeploymentConfig` object. 337 | model_fn: A callable. Called as `model_fn(*args, **kwargs)`. 338 | args: Optional list of arguments to pass to `model_fn`. 339 | kwargs: Optional dict of keyword arguments to pass to `model_fn`. 340 | optimizer: Optional `Optimizer` object. If passed, the model is deployed 341 | for training with that optimizer. 342 | summarize_gradients: Whether or not to add summaries to the gradients. 343 | 344 | Returns: 345 | A `DeployedModel` namedtuple. 346 | 347 | """ 348 | # Gather initial summaries. 349 | summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES)) 350 | 351 | # Create Clones. 352 | clones = create_clones(config, model_fn, args, kwargs) 353 | first_clone = clones[0] 354 | 355 | # Gather update_ops from the first clone. These contain, for example, 356 | # the updates for the batch_norm variables created by model_fn. 357 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone.scope) 358 | 359 | train_op = None 360 | total_loss = None 361 | with tf.device(config.optimizer_device()): 362 | if optimizer: 363 | # Place the global step on the device storing the variables. 364 | with tf.device(config.variables_device()): 365 | global_step = slim.get_or_create_global_step() 366 | 367 | # Compute the gradients for the clones. 368 | total_loss, clones_gradients = optimize_clones(clones, optimizer) 369 | 370 | if clones_gradients: 371 | if summarize_gradients: 372 | # Add summaries to the gradients. 373 | summaries |= set(_add_gradients_summaries(clones_gradients)) 374 | 375 | # Create gradient updates.
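        # apply_gradients() below also increments global_step. Grouping it
        # with the batch-norm update_ops makes train_op refresh the moving
        # averages on every step before returning the total loss.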
376 | grad_updates = optimizer.apply_gradients(clones_gradients, 377 | global_step=global_step) 378 | update_ops.append(grad_updates) 379 | 380 | update_op = tf.group(*update_ops) 381 | with tf.control_dependencies([update_op]): 382 | train_op = tf.identity(total_loss, name='train_op') 383 | else: 384 | clones_losses = [] 385 | regularization_losses = tf.get_collection( 386 | tf.GraphKeys.REGULARIZATION_LOSSES) 387 | for clone in clones: 388 | with tf.name_scope(clone.scope): 389 | clone_loss = _gather_clone_loss(clone, len(clones), 390 | regularization_losses) 391 | if clone_loss is not None: 392 | clones_losses.append(clone_loss) 393 | # Only use regularization_losses for the first clone. 394 | regularization_losses = None 395 | if clones_losses: 396 | total_loss = tf.add_n(clones_losses, name='total_loss') 397 | 398 | # Add the summaries from the first clone. These contain the summaries 399 | # created by model_fn and either optimize_clones() or _gather_clone_loss(). 400 | summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES, 401 | first_clone.scope)) 402 | 403 | if total_loss is not None: 404 | # Add total_loss to summary. 405 | summaries.add(tf.summary.scalar('total_loss', total_loss)) 406 | 407 | if summaries: 408 | # Merge all summaries together. 409 | summary_op = tf.summary.merge(list(summaries), name='summary_op') 410 | else: 411 | summary_op = None 412 | 413 | return DeployedModel(train_op, summary_op, total_loss, clones) 414 | 415 | 416 | def _sum_clones_gradients(clone_grads): 417 | """Calculate the sum gradient for each shared variable across all clones. 418 | 419 | This function assumes that the gradients in clone_grads have been scaled 420 | appropriately by 1 / num_clones. 421 | 422 | Args: 423 | clone_grads: A List of List of tuples (gradient, variable), one list per 424 | `Clone`. 425 | 426 | Returns: 427 | List of tuples of (gradient, variable) where the gradient has been summed 428 | across all clones. 429 | """ 430 | sum_grads = [] 431 | for grad_and_vars in zip(*clone_grads): 432 | # Note that each grad_and_vars looks like the following: 433 | # ((grad_var0_clone0, var0), ..., (grad_var0_cloneN, var0)) 434 | grads = [] 435 | var = grad_and_vars[0][1] 436 | for g, v in grad_and_vars: 437 | assert v == var 438 | if g is not None: 439 | grads.append(g) 440 | if grads: 441 | if len(grads) > 1: 442 | sum_grad = tf.add_n(grads, name=var.op.name + '/sum_grads') 443 | else: 444 | sum_grad = grads[0] 445 | sum_grads.append((sum_grad, var)) 446 | return sum_grads 447 | 448 | 449 | def _add_gradients_summaries(grads_and_vars): 450 | """Add histogram summaries to gradients. 451 | 452 | Note: The summaries are also added to the SUMMARIES collection. 453 | 454 | Args: 455 | grads_and_vars: A list of gradient-to-variable pairs (tuples). 456 | 457 | Returns: 458 | The _list_ of the added summaries for grads_and_vars. 459 | """ 460 | summaries = [] 461 | for grad, var in grads_and_vars: 462 | if grad is not None: 463 | if isinstance(grad, tf.IndexedSlices): 464 | grad_values = grad.values 465 | else: 466 | grad_values = grad 467 | summaries.append(tf.summary.histogram(var.op.name + ':gradient', 468 | grad_values)) 469 | summaries.append(tf.summary.histogram(var.op.name + ':gradient_norm', 470 | tf.global_norm([grad_values]))) 471 | else: 472 | tf.logging.info('Var %s has no gradient', var.op.name) 473 | return summaries 474 | 475 | 476 | class DeploymentConfig(object): 477 | """Configuration for deploying a model with `deploy()`.
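  For example (a sketch of two common single-machine setups):
  `DeploymentConfig(num_clones=2)` places one clone per GPU, while
  `DeploymentConfig(num_clones=2, clone_on_cpu=True)` keeps both clones on
  the CPU.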
478 | 479 | You can pass an instance of this class to `deploy()` to specify exactly 480 | how to deploy the model you build. If you do not pass one, an instance built 481 | from the default deployment_hparams will be used. 482 | """ 483 | 484 | def __init__(self, 485 | num_clones=1, 486 | clone_on_cpu=False, 487 | replica_id=0, 488 | num_replicas=1, 489 | num_ps_tasks=0, 490 | worker_job_name='worker', 491 | ps_job_name='ps'): 492 | """Create a DeploymentConfig. 493 | 494 | The config describes how to deploy a model across multiple clones and 495 | replicas. The model will be replicated `num_clones` times in each replica. 496 | If `clone_on_cpu` is True, each clone will be placed on CPU. 497 | 498 | If `num_replicas` is 1, the model is deployed via a single process. In that 499 | case `worker_device`, `num_ps_tasks`, and `ps_device` are ignored. 500 | 501 | If `num_replicas` is greater than 1, then `worker_device` and `ps_device` 502 | must specify TensorFlow devices for the `worker` and `ps` jobs and 503 | `num_ps_tasks` must be positive. 504 | 505 | Args: 506 | num_clones: Number of model clones to deploy in each replica. 507 | clone_on_cpu: If True, clones are placed on CPU. 508 | replica_id: Integer. Index of the replica for which the model is 509 | deployed. Usually 0 for the chief replica. 510 | num_replicas: Number of replicas to use. 511 | num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas. 512 | worker_job_name: A name for the worker job. 513 | ps_job_name: A name for the parameter server job. 514 | 515 | Raises: 516 | ValueError: If the arguments are invalid. 517 | """ 518 | if num_replicas > 1: 519 | if num_ps_tasks < 1: 520 | raise ValueError('When using replicas num_ps_tasks must be positive') 521 | if num_replicas > 1 or num_ps_tasks > 0: 522 | if not worker_job_name: 523 | raise ValueError('Must specify worker_job_name when using replicas') 524 | if not ps_job_name: 525 | raise ValueError('Must specify ps_job_name when using parameter server') 526 | if replica_id >= num_replicas: 527 | raise ValueError('replica_id must be less than num_replicas') 528 | self._num_clones = num_clones 529 | self._clone_on_cpu = clone_on_cpu 530 | self._replica_id = replica_id 531 | self._num_replicas = num_replicas 532 | self._num_ps_tasks = num_ps_tasks 533 | self._ps_device = '/job:' + ps_job_name if num_ps_tasks > 0 else '' 534 | self._worker_device = '/job:' + worker_job_name if num_ps_tasks > 0 else '' 535 | 536 | @property 537 | def num_clones(self): 538 | return self._num_clones 539 | 540 | @property 541 | def clone_on_cpu(self): 542 | return self._clone_on_cpu 543 | 544 | @property 545 | def replica_id(self): 546 | return self._replica_id 547 | 548 | @property 549 | def num_replicas(self): 550 | return self._num_replicas 551 | 552 | @property 553 | def num_ps_tasks(self): 554 | return self._num_ps_tasks 555 | 556 | @property 557 | def ps_device(self): 558 | return self._ps_device 559 | 560 | @property 561 | def worker_device(self): 562 | return self._worker_device 563 | 564 | def caching_device(self): 565 | """Returns the device to use for caching variables. 566 | 567 | Variables are cached on the worker CPU when using replicas. 568 | 569 | Returns: 570 | A device string or None if the variables do not need to be cached. 571 | """ 572 | if self._num_ps_tasks > 0: 573 | return lambda op: op.device 574 | else: 575 | return None 576 | 577 | def clone_device(self, clone_index): 578 | """Device used to create the clone and all the ops inside the clone.
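    For example (values mirroring the unit tests in model_deploy_test.py):
    `DeploymentConfig(num_clones=2).clone_device(1)` is '/device:GPU:1', and
    with `num_ps_tasks=1` it becomes '/job:worker/device:GPU:1'.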
579 | 580 | Args: 581 | clone_index: Int, the index of the clone. 582 | 583 | Returns: 584 | A value suitable for `tf.device()`. 585 | 586 | Raises: 587 | ValueError: If `clone_index` is greater than or equal to the number of clones. 588 | """ 589 | if clone_index >= self._num_clones: 590 | raise ValueError('clone_index must be less than num_clones') 591 | device = '' 592 | if self._num_ps_tasks > 0: 593 | device += self._worker_device 594 | if self._clone_on_cpu: 595 | device += '/device:CPU:0' 596 | else: 597 | if self._num_clones > 1: 598 | device += '/device:GPU:%d' % clone_index 599 | return device 600 | 601 | def clone_scope(self, clone_index): 602 | """Name scope to create the clone. 603 | 604 | Args: 605 | clone_index: Int, the index of the clone. 606 | 607 | Returns: 608 | A name_scope suitable for `tf.name_scope()`. 609 | 610 | Raises: 611 | ValueError: If `clone_index` is greater than or equal to the number of clones. 612 | """ 613 | if clone_index >= self._num_clones: 614 | raise ValueError('clone_index must be less than num_clones') 615 | scope = '' 616 | if self._num_clones > 1: 617 | scope = 'clone_%d' % clone_index 618 | return scope 619 | 620 | def optimizer_device(self): 621 | """Device to use with the optimizer. 622 | 623 | Returns: 624 | A value suitable for `tf.device()`. 625 | """ 626 | if self._num_ps_tasks > 0 or self._num_clones > 0: 627 | return self._worker_device + '/device:CPU:0' 628 | else: 629 | return '' 630 | 631 | def inputs_device(self): 632 | """Device to use to build the inputs. 633 | 634 | Returns: 635 | A value suitable for `tf.device()`. 636 | """ 637 | device = '' 638 | if self._num_ps_tasks > 0: 639 | device += self._worker_device 640 | device += '/device:CPU:0' 641 | return device 642 | 643 | def variables_device(self): 644 | """Returns the device to use for variables created inside the clone. 645 | 646 | Returns: 647 | A value suitable for `tf.device()`. 648 | """ 649 | device = '' 650 | if self._num_ps_tasks > 0: 651 | device += self._ps_device 652 | device += '/device:CPU:0' 653 | 654 | class _PSDeviceChooser(object): 655 | """Slim device chooser for variables when using PS.""" 656 | 657 | def __init__(self, device, tasks): 658 | self._device = device 659 | self._tasks = tasks 660 | self._task = 0 661 | 662 | def choose(self, op): 663 | if op.device: 664 | return op.device 665 | node_def = op if isinstance(op, tf.NodeDef) else op.node_def 666 | if node_def.op == 'Variable': 667 | t = self._task 668 | self._task = (self._task + 1) % self._tasks 669 | d = '%s/task:%d' % (self._device, t) 670 | return d 671 | else: 672 | return op.device 673 | 674 | if not self._num_ps_tasks: 675 | return device 676 | else: 677 | chooser = _PSDeviceChooser(device, self._num_ps_tasks) 678 | return chooser.choose 679 | --------------------------------------------------------------------------------