├── README.md ├── imagenet_main.py ├── imagenet_preprocessing.py ├── resnet_model.py ├── resnet_run_loop.py └── utils ├── __init__.py ├── export ├── __init__.py └── export.py ├── flags ├── __init__.py ├── _base.py ├── _benchmark.py ├── _conventions.py ├── _device.py ├── _misc.py ├── _performance.py └── core.py ├── logs ├── __init__.py ├── cloud_lib.py ├── hooks.py ├── hooks_helper.py ├── logger.py ├── metric_hook.py └── mlperf_helper.py └── misc ├── __init__.py ├── distribution_utils.py └── model_helpers.py /README.md: -------------------------------------------------------------------------------- 1 | # Deep-Compressive-Offloading 2 | Deep Compressive Offloading: Speeding Up Neural Network Inference by Trading Edge Computation for Network Latency 3 | -------------------------------------------------------------------------------- /imagenet_main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Runs a ResNet model on the ImageNet dataset.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | 23 | from absl import app as absl_app 24 | from absl import flags 25 | import tensorflow as tf # pylint: disable=g-bad-import-order 26 | 27 | from utils.flags import core as flags_core 28 | from utils.logs import logger 29 | import imagenet_preprocessing 30 | import resnet_model 31 | import resnet_run_loop 32 | 33 | _DEFAULT_IMAGE_SIZE = 224 34 | _NUM_CHANNELS = 3 35 | _NUM_CLASSES = 1001 36 | 37 | _NUM_IMAGES = { 38 | 'train': 1281167, 39 | 'validation': 50000, 40 | } 41 | 42 | _NUM_TRAIN_FILES = 1024 43 | _SHUFFLE_BUFFER = 10000 44 | 45 | DATASET_NAME = 'ImageNet' 46 | 47 | ############################################################################### 48 | # Data processing 49 | ############################################################################### 50 | def get_filenames(is_training, data_dir): 51 | """Return filenames for dataset.""" 52 | if is_training: 53 | return [ 54 | os.path.join(data_dir, 'train-%05d-of-01024' % i) 55 | for i in range(_NUM_TRAIN_FILES)] 56 | else: 57 | return [ 58 | os.path.join(data_dir, 'validation-%05d-of-00128' % i) 59 | for i in range(128)] 60 | 61 | 62 | def _parse_example_proto(example_serialized): 63 | """Parses an Example proto containing a training example of an image. 64 | 65 | The output of the build_image_data.py image preprocessing script is a dataset 66 | containing serialized Example protocol buffers. Each Example proto contains 67 | the following fields (values are included as examples): 68 | 69 | image/height: 462 70 | image/width: 581 71 | image/colorspace: 'RGB' 72 | image/channels: 3 73 | image/class/label: 615 74 | image/class/synset: 'n03623198' 75 | image/class/text: 'knee pad' 76 | image/object/bbox/xmin: 0.1 77 | image/object/bbox/xmax: 0.9 78 | image/object/bbox/ymin: 0.2 79 | image/object/bbox/ymax: 0.6 80 | image/object/bbox/label: 615 81 | image/format: 'JPEG' 82 | image/filename: 'ILSVRC2012_val_00041207.JPEG' 83 | image/encoded: 84 | 85 | Args: 86 | example_serialized: scalar Tensor tf.string containing a serialized 87 | Example protocol buffer. 88 | 89 | Returns: 90 | image_buffer: Tensor tf.string containing the contents of a JPEG file. 91 | label: Tensor tf.int32 containing the label. 92 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 93 | where each coordinate is [0, 1) and the coordinates are arranged as 94 | [ymin, xmin, ymax, xmax]. 95 | """ 96 | # Dense features in Example proto. 97 | feature_map = { 98 | 'image/encoded': tf.FixedLenFeature([], dtype=tf.string, 99 | default_value=''), 100 | 'image/class/label': tf.FixedLenFeature([], dtype=tf.int64, 101 | default_value=-1), 102 | 'image/class/text': tf.FixedLenFeature([], dtype=tf.string, 103 | default_value=''), 104 | } 105 | sparse_float32 = tf.VarLenFeature(dtype=tf.float32) 106 | # Sparse features in Example proto. 107 | feature_map.update( 108 | {k: sparse_float32 for k in ['image/object/bbox/xmin', 109 | 'image/object/bbox/ymin', 110 | 'image/object/bbox/xmax', 111 | 'image/object/bbox/ymax']}) 112 | 113 | features = tf.parse_single_example(example_serialized, feature_map) 114 | label = tf.cast(features['image/class/label'], dtype=tf.int32) 115 | 116 | xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0) 117 | ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0) 118 | xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0) 119 | ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0) 120 | 121 | # Note that we impose an ordering of (y, x) just to make life difficult. 122 | bbox = tf.concat([ymin, xmin, ymax, xmax], 0) 123 | 124 | # Force the variable number of bounding boxes into the shape 125 | # [1, num_boxes, coords]. 126 | bbox = tf.expand_dims(bbox, 0) 127 | bbox = tf.transpose(bbox, [0, 2, 1]) 128 | 129 | return features['image/encoded'], label, bbox 130 | 131 | 132 | def parse_record(raw_record, is_training, dtype): 133 | """Parses a record containing a training example of an image. 134 | 135 | The input record is parsed into a label and image, and the image is passed 136 | through preprocessing steps (cropping, flipping, and so on). 137 | 138 | Args: 139 | raw_record: scalar Tensor tf.string containing a serialized 140 | Example protocol buffer. 141 | is_training: A boolean denoting whether the input is for training. 142 | dtype: data type to use for images/features. 143 | 144 | Returns: 145 | Tuple with processed image tensor and one-hot-encoded label tensor. 146 | """ 147 | image_buffer, label, bbox = _parse_example_proto(raw_record) 148 | 149 | image = imagenet_preprocessing.preprocess_image( 150 | image_buffer=image_buffer, 151 | bbox=bbox, 152 | output_height=_DEFAULT_IMAGE_SIZE, 153 | output_width=_DEFAULT_IMAGE_SIZE, 154 | num_channels=_NUM_CHANNELS, 155 | is_training=is_training) 156 | image = tf.cast(image, dtype) 157 | 158 | return image, label 159 | 160 | 161 | def input_fn(is_training, data_dir, batch_size, num_epochs=1, 162 | dtype=tf.float32, datasets_num_private_threads=None, 163 | num_parallel_batches=1): 164 | """Input function which provides batches for train or eval. 165 | 166 | Args: 167 | is_training: A boolean denoting whether the input is for training. 168 | data_dir: The directory containing the input data. 169 | batch_size: The number of samples per batch. 170 | num_epochs: The number of epochs to repeat the dataset. 171 | dtype: Data type to use for images/features 172 | datasets_num_private_threads: Number of private threads for tf.data. 173 | num_parallel_batches: Number of parallel batches for tf.data. 174 | 175 | Returns: 176 | A dataset that can be used for iteration. 177 | """ 178 | filenames = get_filenames(is_training, data_dir) 179 | dataset = tf.data.Dataset.from_tensor_slices(filenames) 180 | 181 | if is_training: 182 | # Shuffle the input files 183 | dataset = dataset.shuffle(buffer_size=_NUM_TRAIN_FILES) 184 | 185 | # Convert to individual records. 186 | # cycle_length = 10 means 10 files will be read and deserialized in parallel. 187 | # This number is low enough to not cause too much contention on small systems 188 | # but high enough to provide the benefits of parallelization. You may want 189 | # to increase this number if you have a large number of CPU cores. 190 | dataset = dataset.apply(tf.contrib.data.parallel_interleave( 191 | tf.data.TFRecordDataset, cycle_length=10)) 192 | 193 | return resnet_run_loop.process_record_dataset( 194 | dataset=dataset, 195 | is_training=is_training, 196 | batch_size=batch_size, 197 | shuffle_buffer=_SHUFFLE_BUFFER, 198 | parse_record_fn=parse_record, 199 | num_epochs=num_epochs, 200 | dtype=dtype, 201 | datasets_num_private_threads=datasets_num_private_threads, 202 | num_parallel_batches=num_parallel_batches 203 | ) 204 | 205 | 206 | def get_synth_input_fn(dtype): 207 | return resnet_run_loop.get_synth_input_fn( 208 | _DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS, _NUM_CLASSES, 209 | dtype=dtype) 210 | 211 | 212 | ############################################################################### 213 | # Running the model 214 | ############################################################################### 215 | class ImagenetModel(resnet_model.Model): 216 | """Model class with appropriate defaults for Imagenet data.""" 217 | 218 | def __init__(self, resnet_size, data_format=None, spectral_norm=False, 219 | reuse=False, offload=False, compress_ratio=0.05, 220 | num_classes=_NUM_CLASSES, 221 | resnet_version=resnet_model.DEFAULT_VERSION, 222 | dtype=resnet_model.DEFAULT_DTYPE): 223 | """These are the parameters that work for Imagenet data. 224 | 225 | Args: 226 | resnet_size: The number of convolutional layers needed in the model. 227 | data_format: Either 'channels_first' or 'channels_last', specifying which 228 | data format to use when setting up the model. 229 | num_classes: The number of output classes needed from the model. This 230 | enables users to extend the same model to their own datasets. 231 | resnet_version: Integer representing which version of the ResNet network 232 | to use. See README for details. Valid values: [1, 2] 233 | dtype: The TensorFlow dtype to use for calculations. 234 | """ 235 | 236 | # For bigger models, we want to use "bottleneck" layers 237 | if resnet_size < 50: 238 | bottleneck = False 239 | else: 240 | bottleneck = True 241 | 242 | super(ImagenetModel, self).__init__( 243 | resnet_size=resnet_size, 244 | bottleneck=bottleneck, 245 | num_classes=num_classes, 246 | num_filters=64, 247 | kernel_size=7, 248 | conv_stride=2, 249 | first_pool_size=3, 250 | first_pool_stride=2, 251 | block_sizes=_get_block_sizes(resnet_size), 252 | block_strides=[1, 2, 2, 2], 253 | resnet_version=resnet_version, 254 | data_format=data_format, 255 | dtype=dtype, 256 | spectral_norm=spectral_norm, 257 | reuse=reuse, 258 | offload=offload, 259 | compress_ratio=compress_ratio 260 | ) 261 | 262 | 263 | def _get_block_sizes(resnet_size): 264 | """Retrieve the size of each block_layer in the ResNet model. 265 | 266 | The number of block layers used for the Resnet model varies according 267 | to the size of the model. This helper grabs the layer set we want, throwing 268 | an error if a non-standard size has been selected. 269 | 270 | Args: 271 | resnet_size: The number of convolutional layers needed in the model. 272 | 273 | Returns: 274 | A list of block sizes to use in building the model. 275 | 276 | Raises: 277 | KeyError: if invalid resnet_size is received. 278 | """ 279 | choices = { 280 | 18: [2, 2, 2, 2], 281 | 34: [3, 4, 6, 3], 282 | 50: [3, 4, 6, 3], 283 | 101: [3, 4, 23, 3], 284 | 152: [3, 8, 36, 3], 285 | 200: [3, 24, 36, 3] 286 | } 287 | 288 | try: 289 | return choices[resnet_size] 290 | except KeyError: 291 | err = ('Could not find layers for selected Resnet size.\n' 292 | 'Size received: {}; sizes allowed: {}.'.format( 293 | resnet_size, choices.keys())) 294 | raise ValueError(err) 295 | 296 | 297 | def imagenet_model_fn(features, labels, mode, params): 298 | """Our model_fn for ResNet to be used with our Estimator.""" 299 | 300 | # Warmup and higher lr may not be valid for fine tuning with small batches 301 | # and smaller numbers of training images. 302 | if params['fine_tune']: 303 | warmup = False 304 | if params['no_dense_init']: 305 | base_lr = 0.1 306 | # boundary_ep = [2, 4, 6, 8] 307 | boundary_ep = [4, 8, 10, 12] 308 | # boundary_ep = [5, 10, 13, 16] 309 | else: 310 | base_lr = 0.01 311 | boundary_ep = [4, 8, 10, 12] 312 | # boundary_ep = [5, 10, 13, 16] 313 | else: 314 | warmup = True 315 | base_lr = .128 316 | boundary_ep = [30, 60, 80, 90] 317 | 318 | learning_rate_fn = resnet_run_loop.learning_rate_with_decay( 319 | batch_size=params['batch_size'], batch_denom=256, 320 | num_images=_NUM_IMAGES['train'], boundary_epochs=boundary_ep, 321 | decay_rates=[1, 0.1, 0.01, 0.001, 1e-4], warmup=warmup, base_lr=base_lr) 322 | 323 | return resnet_run_loop.resnet_model_fn( 324 | features=features, 325 | labels=labels, 326 | mode=mode, 327 | model_class=ImagenetModel, 328 | resnet_size=params['resnet_size'], 329 | weight_decay=1e-4, 330 | learning_rate_fn=learning_rate_fn, 331 | momentum=0.9, 332 | data_format=params['data_format'], 333 | resnet_version=params['resnet_version'], 334 | loss_scale=params['loss_scale'], 335 | loss_filter_fn=None, 336 | dtype=params['dtype'], 337 | fine_tune=params['fine_tune'], 338 | reconst_loss_scale=params['reconst_loss_scale'], 339 | use_ce=params['use_ce'], 340 | opt_chos=params['optimizer'], 341 | clip_grad=params['clip_grad'], 342 | spectral_norm=params['spectral_norm'], 343 | ce_scale=params['ce_scale'], 344 | sep_grad_nrom=params['sep_grad_nrom'], 345 | norm_teach_feature=params['norm_teach_feature'], 346 | compress_ratio=params['compress_ratio'] 347 | ) 348 | 349 | 350 | def define_imagenet_flags(): 351 | resnet_run_loop.define_resnet_flags( 352 | resnet_size_choices=['18', '34', '50', '101', '152', '200']) 353 | flags.adopt_module_key_flags(resnet_run_loop) 354 | flags_core.set_defaults(train_epochs=90) 355 | 356 | 357 | def run_imagenet(flags_obj): 358 | """Run ResNet ImageNet training and eval loop. 359 | 360 | Args: 361 | flags_obj: An object containing parsed flag values. 362 | """ 363 | input_function = (flags_obj.use_synthetic_data and 364 | get_synth_input_fn(flags_core.get_tf_dtype(flags_obj)) or 365 | input_fn) 366 | 367 | resnet_run_loop.resnet_main( 368 | flags_obj, imagenet_model_fn, input_function, DATASET_NAME, 369 | shape=[_DEFAULT_IMAGE_SIZE, _DEFAULT_IMAGE_SIZE, _NUM_CHANNELS]) 370 | 371 | 372 | def main(_): 373 | with logger.benchmark_context(flags.FLAGS): 374 | run_imagenet(flags.FLAGS) 375 | 376 | 377 | if __name__ == '__main__': 378 | tf.logging.set_verbosity(tf.logging.INFO) 379 | define_imagenet_flags() 380 | absl_app.run(main) 381 | -------------------------------------------------------------------------------- /imagenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images. 16 | 17 | Training images are sampled using the provided bounding boxes, and subsequently 18 | cropped to the sampled bounding box. Images are additionally flipped randomly, 19 | then resized to the target output size (without aspect-ratio preservation). 20 | 21 | Images used during evaluation are resized (with aspect-ratio preservation) and 22 | centrally cropped. 23 | 24 | All images undergo mean color subtraction. 25 | 26 | Note that these steps are colloquially referred to as "ResNet preprocessing," 27 | and they differ from "VGG preprocessing," which does not use bounding boxes 28 | and instead does an aspect-preserving resize followed by random crop during 29 | training. (These both differ from "Inception preprocessing," which introduces 30 | color distortion steps.) 31 | 32 | """ 33 | 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | 38 | import tensorflow as tf 39 | 40 | _R_MEAN = 123.68 41 | _G_MEAN = 116.78 42 | _B_MEAN = 103.94 43 | _CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN] 44 | 45 | # The lower bound for the smallest side of the image for aspect-preserving 46 | # resizing. For example, if an image is 500 x 1000, it will be resized to 47 | # _RESIZE_MIN x (_RESIZE_MIN * 2). 48 | _RESIZE_MIN = 256 49 | 50 | 51 | def _decode_crop_and_flip(image_buffer, bbox, num_channels): 52 | """Crops the given image to a random part of the image, and randomly flips. 53 | 54 | We use the fused decode_and_crop op, which performs better than the two ops 55 | used separately in series, but note that this requires that the image be 56 | passed in as an un-decoded string Tensor. 57 | 58 | Args: 59 | image_buffer: scalar string Tensor representing the raw JPEG image buffer. 60 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 61 | where each coordinate is [0, 1) and the coordinates are arranged as 62 | [ymin, xmin, ymax, xmax]. 63 | num_channels: Integer depth of the image buffer for decoding. 64 | 65 | Returns: 66 | 3-D tensor with cropped image. 67 | 68 | """ 69 | # A large fraction of image datasets contain a human-annotated bounding box 70 | # delineating the region of the image containing the object of interest. We 71 | # choose to create a new bounding box for the object which is a randomly 72 | # distorted version of the human-annotated bounding box that obeys an 73 | # allowed range of aspect ratios, sizes and overlap with the human-annotated 74 | # bounding box. If no box is supplied, then we assume the bounding box is 75 | # the entire image. 76 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 77 | tf.image.extract_jpeg_shape(image_buffer), 78 | bounding_boxes=bbox, 79 | min_object_covered=0.1, 80 | aspect_ratio_range=[0.75, 1.33], 81 | area_range=[0.05, 1.0], 82 | max_attempts=100, 83 | use_image_if_no_bounding_boxes=True) 84 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box 85 | 86 | # Reassemble the bounding box in the format the crop op requires. 87 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 88 | target_height, target_width, _ = tf.unstack(bbox_size) 89 | crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) 90 | 91 | # Use the fused decode and crop op here, which is faster than each in series. 92 | cropped = tf.image.decode_and_crop_jpeg( 93 | image_buffer, crop_window, channels=num_channels) 94 | 95 | # Flip to add a little more random distortion in. 96 | cropped = tf.image.random_flip_left_right(cropped) 97 | return cropped 98 | 99 | 100 | def _central_crop(image, crop_height, crop_width): 101 | """Performs central crops of the given image list. 102 | 103 | Args: 104 | image: a 3-D image tensor 105 | crop_height: the height of the image following the crop. 106 | crop_width: the width of the image following the crop. 107 | 108 | Returns: 109 | 3-D tensor with cropped image. 110 | """ 111 | shape = tf.shape(image) 112 | height, width = shape[0], shape[1] 113 | 114 | amount_to_be_cropped_h = (height - crop_height) 115 | crop_top = amount_to_be_cropped_h // 2 116 | amount_to_be_cropped_w = (width - crop_width) 117 | crop_left = amount_to_be_cropped_w // 2 118 | return tf.slice( 119 | image, [crop_top, crop_left, 0], [crop_height, crop_width, -1]) 120 | 121 | 122 | def _mean_image_subtraction(image, means, num_channels): 123 | """Subtracts the given means from each image channel. 124 | 125 | For example: 126 | means = [123.68, 116.779, 103.939] 127 | image = _mean_image_subtraction(image, means) 128 | 129 | Note that the rank of `image` must be known. 130 | 131 | Args: 132 | image: a tensor of size [height, width, C]. 133 | means: a C-vector of values to subtract from each channel. 134 | num_channels: number of color channels in the image that will be distorted. 135 | 136 | Returns: 137 | the centered image. 138 | 139 | Raises: 140 | ValueError: If the rank of `image` is unknown, if `image` has a rank other 141 | than three or if the number of channels in `image` doesn't match the 142 | number of values in `means`. 143 | """ 144 | if image.get_shape().ndims != 3: 145 | raise ValueError('Input must be of size [height, width, C>0]') 146 | 147 | if len(means) != num_channels: 148 | raise ValueError('len(means) must match the number of channels') 149 | 150 | # We have a 1-D tensor of means; convert to 3-D. 151 | means = tf.expand_dims(tf.expand_dims(means, 0), 0) 152 | 153 | return image - means 154 | 155 | 156 | def _smallest_size_at_least(height, width, resize_min): 157 | """Computes new shape with the smallest side equal to `smallest_side`. 158 | 159 | Computes new shape with the smallest side equal to `smallest_side` while 160 | preserving the original aspect ratio. 161 | 162 | Args: 163 | height: an int32 scalar tensor indicating the current height. 164 | width: an int32 scalar tensor indicating the current width. 165 | resize_min: A python integer or scalar `Tensor` indicating the size of 166 | the smallest side after resize. 167 | 168 | Returns: 169 | new_height: an int32 scalar tensor indicating the new height. 170 | new_width: an int32 scalar tensor indicating the new width. 171 | """ 172 | resize_min = tf.cast(resize_min, tf.float32) 173 | 174 | # Convert to floats to make subsequent calculations go smoothly. 175 | height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32) 176 | 177 | smaller_dim = tf.minimum(height, width) 178 | scale_ratio = resize_min / smaller_dim 179 | 180 | # Convert back to ints to make heights and widths that TF ops will accept. 181 | new_height = tf.cast(height * scale_ratio, tf.int32) 182 | new_width = tf.cast(width * scale_ratio, tf.int32) 183 | 184 | return new_height, new_width 185 | 186 | 187 | def _aspect_preserving_resize(image, resize_min): 188 | """Resize images preserving the original aspect ratio. 189 | 190 | Args: 191 | image: A 3-D image `Tensor`. 192 | resize_min: A python integer or scalar `Tensor` indicating the size of 193 | the smallest side after resize. 194 | 195 | Returns: 196 | resized_image: A 3-D tensor containing the resized image. 197 | """ 198 | shape = tf.shape(image) 199 | height, width = shape[0], shape[1] 200 | 201 | new_height, new_width = _smallest_size_at_least(height, width, resize_min) 202 | 203 | return _resize_image(image, new_height, new_width) 204 | 205 | 206 | def _resize_image(image, height, width): 207 | """Simple wrapper around tf.resize_images. 208 | 209 | This is primarily to make sure we use the same `ResizeMethod` and other 210 | details each time. 211 | 212 | Args: 213 | image: A 3-D image `Tensor`. 214 | height: The target height for the resized image. 215 | width: The target width for the resized image. 216 | 217 | Returns: 218 | resized_image: A 3-D tensor containing the resized image. The first two 219 | dimensions have the shape [height, width]. 220 | """ 221 | return tf.image.resize_images( 222 | image, [height, width], method=tf.image.ResizeMethod.BILINEAR, 223 | align_corners=False) 224 | 225 | 226 | def preprocess_image(image_buffer, bbox, output_height, output_width, 227 | num_channels, is_training=False): 228 | """Preprocesses the given image. 229 | 230 | Preprocessing includes decoding, cropping, and resizing for both training 231 | and eval images. Training preprocessing, however, introduces some random 232 | distortion of the image to improve accuracy. 233 | 234 | Args: 235 | image_buffer: scalar string Tensor representing the raw JPEG image buffer. 236 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 237 | where each coordinate is [0, 1) and the coordinates are arranged as 238 | [ymin, xmin, ymax, xmax]. 239 | output_height: The height of the image after preprocessing. 240 | output_width: The width of the image after preprocessing. 241 | num_channels: Integer depth of the image buffer for decoding. 242 | is_training: `True` if we're preprocessing the image for training and 243 | `False` otherwise. 244 | 245 | Returns: 246 | A preprocessed image. 247 | """ 248 | if is_training: 249 | # For training, we want to randomize some of the distortions. 250 | image = _decode_crop_and_flip(image_buffer, bbox, num_channels) 251 | image = _resize_image(image, output_height, output_width) 252 | else: 253 | # For validation, we want to decode, resize, then just crop the middle. 254 | image = tf.image.decode_jpeg(image_buffer, channels=num_channels) 255 | image = _aspect_preserving_resize(image, _RESIZE_MIN) 256 | image = _central_crop(image, output_height, output_width) 257 | 258 | image.set_shape([output_height, output_width, num_channels]) 259 | 260 | return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels) 261 | -------------------------------------------------------------------------------- /resnet_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for Residual Networks. 16 | 17 | Residual networks ('v1' ResNets) were originally proposed in: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | 21 | The full preactivation 'v2' ResNet variant was introduced by: 22 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 24 | 25 | The key difference of the full preactivation 'v2' variant compared to the 26 | 'v1' variant in [1] is the use of batch normalization before every weight layer 27 | rather than after. 28 | """ 29 | 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | import numpy as np 36 | 37 | _BATCH_NORM_DECAY = 0.997 38 | _BATCH_NORM_EPSILON = 1e-5 39 | DEFAULT_VERSION = 2 40 | DEFAULT_DTYPE = tf.float32 41 | CASTABLE_TYPES = (tf.float16,) 42 | ALLOWED_TYPES = (DEFAULT_DTYPE,) + CASTABLE_TYPES 43 | 44 | 45 | ################################################################################ 46 | # Convenience functions for building the ResNet model. 47 | ################################################################################ 48 | def batch_norm(inputs, training, data_format): 49 | """Performs a batch normalization using a standard set of parameters.""" 50 | # We set fused=True for a significant performance boost. See 51 | # https://www.tensorflow.org/performance/performance_guide#common_fused_ops 52 | return tf.layers.batch_normalization( 53 | inputs=inputs, axis=1 if data_format == 'channels_first' else 3, 54 | momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON, center=True, 55 | scale=True, training=training, fused=True) 56 | 57 | 58 | def fixed_padding(inputs, kernel_size, data_format): 59 | """Pads the input along the spatial dimensions independently of input size. 60 | 61 | Args: 62 | inputs: A tensor of size [batch, channels, height_in, width_in] or 63 | [batch, height_in, width_in, channels] depending on data_format. 64 | kernel_size: The kernel to be used in the conv2d or max_pool2d operation. 65 | Should be a positive integer. 66 | data_format: The input format ('channels_last' or 'channels_first'). 67 | 68 | Returns: 69 | A tensor with the same format as the input with the data either intact 70 | (if kernel_size == 1) or padded (if kernel_size > 1). 71 | """ 72 | pad_total = kernel_size - 1 73 | pad_beg = pad_total // 2 74 | pad_end = pad_total - pad_beg 75 | 76 | if data_format == 'channels_first': 77 | padded_inputs = tf.pad(inputs, [[0, 0], [0, 0], 78 | [pad_beg, pad_end], [pad_beg, pad_end]]) 79 | else: 80 | padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], 81 | [pad_beg, pad_end], [0, 0]]) 82 | return padded_inputs 83 | 84 | 85 | def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format): 86 | """Strided 2-D convolution with explicit padding.""" 87 | # The padding is consistent and is based only on `kernel_size`, not on the 88 | # dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone). 89 | if strides > 1: 90 | inputs = fixed_padding(inputs, kernel_size, data_format) 91 | 92 | return tf.layers.conv2d( 93 | inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides, 94 | padding=('SAME' if strides == 1 else 'VALID'), use_bias=False, 95 | kernel_initializer=tf.variance_scaling_initializer(), 96 | data_format=data_format) 97 | 98 | 99 | ################################################################################ 100 | # ResNet block definitions. 101 | ################################################################################ 102 | def _building_block_v1(inputs, filters, training, projection_shortcut, strides, 103 | data_format): 104 | """A single block for ResNet v1, without a bottleneck. 105 | 106 | Convolution then batch normalization then ReLU as described by: 107 | Deep Residual Learning for Image Recognition 108 | https://arxiv.org/pdf/1512.03385.pdf 109 | by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015. 110 | 111 | Args: 112 | inputs: A tensor of size [batch, channels, height_in, width_in] or 113 | [batch, height_in, width_in, channels] depending on data_format. 114 | filters: The number of filters for the convolutions. 115 | training: A Boolean for whether the model is in training or inference 116 | mode. Needed for batch normalization. 117 | projection_shortcut: The function to use for projection shortcuts 118 | (typically a 1x1 convolution when downsampling the input). 119 | strides: The block's stride. If greater than 1, this block will ultimately 120 | downsample the input. 121 | data_format: The input format ('channels_last' or 'channels_first'). 122 | 123 | Returns: 124 | The output tensor of the block; shape should match inputs. 125 | """ 126 | shortcut = inputs 127 | 128 | if projection_shortcut is not None: 129 | shortcut = projection_shortcut(inputs) 130 | shortcut = batch_norm(inputs=shortcut, training=training, 131 | data_format=data_format) 132 | 133 | inputs = conv2d_fixed_padding( 134 | inputs=inputs, filters=filters, kernel_size=3, strides=strides, 135 | data_format=data_format) 136 | inputs = batch_norm(inputs, training, data_format) 137 | inputs = tf.nn.relu(inputs) 138 | 139 | inputs = conv2d_fixed_padding( 140 | inputs=inputs, filters=filters, kernel_size=3, strides=1, 141 | data_format=data_format) 142 | inputs = batch_norm(inputs, training, data_format) 143 | inputs += shortcut 144 | inputs = tf.nn.relu(inputs) 145 | 146 | return inputs 147 | 148 | 149 | def _building_block_v2(inputs, filters, training, projection_shortcut, strides, 150 | data_format): 151 | """A single block for ResNet v2, without a bottleneck. 152 | 153 | Batch normalization then ReLu then convolution as described by: 154 | Identity Mappings in Deep Residual Networks 155 | https://arxiv.org/pdf/1603.05027.pdf 156 | by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016. 157 | 158 | Args: 159 | inputs: A tensor of size [batch, channels, height_in, width_in] or 160 | [batch, height_in, width_in, channels] depending on data_format. 161 | filters: The number of filters for the convolutions. 162 | training: A Boolean for whether the model is in training or inference 163 | mode. Needed for batch normalization. 164 | projection_shortcut: The function to use for projection shortcuts 165 | (typically a 1x1 convolution when downsampling the input). 166 | strides: The block's stride. If greater than 1, this block will ultimately 167 | downsample the input. 168 | data_format: The input format ('channels_last' or 'channels_first'). 169 | 170 | Returns: 171 | The output tensor of the block; shape should match inputs. 172 | """ 173 | shortcut = inputs 174 | inputs = batch_norm(inputs, training, data_format) 175 | inputs = tf.nn.relu(inputs) 176 | 177 | # The projection shortcut should come after the first batch norm and ReLU 178 | # since it performs a 1x1 convolution. 179 | if projection_shortcut is not None: 180 | shortcut = projection_shortcut(inputs) 181 | 182 | inputs = conv2d_fixed_padding( 183 | inputs=inputs, filters=filters, kernel_size=3, strides=strides, 184 | data_format=data_format) 185 | 186 | inputs = batch_norm(inputs, training, data_format) 187 | inputs = tf.nn.relu(inputs) 188 | inputs = conv2d_fixed_padding( 189 | inputs=inputs, filters=filters, kernel_size=3, strides=1, 190 | data_format=data_format) 191 | 192 | return inputs + shortcut 193 | 194 | 195 | def _bottleneck_block_v1(inputs, filters, training, projection_shortcut, 196 | strides, data_format): 197 | """A single block for ResNet v1, with a bottleneck. 198 | 199 | Similar to _building_block_v1(), except using the "bottleneck" blocks 200 | described in: 201 | Convolution then batch normalization then ReLU as described by: 202 | Deep Residual Learning for Image Recognition 203 | https://arxiv.org/pdf/1512.03385.pdf 204 | by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015. 205 | 206 | Args: 207 | inputs: A tensor of size [batch, channels, height_in, width_in] or 208 | [batch, height_in, width_in, channels] depending on data_format. 209 | filters: The number of filters for the convolutions. 210 | training: A Boolean for whether the model is in training or inference 211 | mode. Needed for batch normalization. 212 | projection_shortcut: The function to use for projection shortcuts 213 | (typically a 1x1 convolution when downsampling the input). 214 | strides: The block's stride. If greater than 1, this block will ultimately 215 | downsample the input. 216 | data_format: The input format ('channels_last' or 'channels_first'). 217 | 218 | Returns: 219 | The output tensor of the block; shape should match inputs. 220 | """ 221 | shortcut = inputs 222 | 223 | if projection_shortcut is not None: 224 | shortcut = projection_shortcut(inputs) 225 | shortcut = batch_norm(inputs=shortcut, training=training, 226 | data_format=data_format) 227 | 228 | inputs = conv2d_fixed_padding( 229 | inputs=inputs, filters=filters, kernel_size=1, strides=1, 230 | data_format=data_format) 231 | inputs = batch_norm(inputs, training, data_format) 232 | inputs = tf.nn.relu(inputs) 233 | 234 | inputs = conv2d_fixed_padding( 235 | inputs=inputs, filters=filters, kernel_size=3, strides=strides, 236 | data_format=data_format) 237 | inputs = batch_norm(inputs, training, data_format) 238 | inputs = tf.nn.relu(inputs) 239 | 240 | inputs = conv2d_fixed_padding( 241 | inputs=inputs, filters=4 * filters, kernel_size=1, strides=1, 242 | data_format=data_format) 243 | inputs = batch_norm(inputs, training, data_format) 244 | inputs += shortcut 245 | inputs = tf.nn.relu(inputs) 246 | 247 | return inputs 248 | 249 | 250 | def _bottleneck_block_v2(inputs, filters, training, projection_shortcut, 251 | strides, data_format): 252 | """A single block for ResNet v2, with a bottleneck. 253 | 254 | Similar to _building_block_v2(), except using the "bottleneck" blocks 255 | described in: 256 | Convolution then batch normalization then ReLU as described by: 257 | Deep Residual Learning for Image Recognition 258 | https://arxiv.org/pdf/1512.03385.pdf 259 | by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015. 260 | 261 | Adapted to the ordering conventions of: 262 | Batch normalization then ReLu then convolution as described by: 263 | Identity Mappings in Deep Residual Networks 264 | https://arxiv.org/pdf/1603.05027.pdf 265 | by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016. 266 | 267 | Args: 268 | inputs: A tensor of size [batch, channels, height_in, width_in] or 269 | [batch, height_in, width_in, channels] depending on data_format. 270 | filters: The number of filters for the convolutions. 271 | training: A Boolean for whether the model is in training or inference 272 | mode. Needed for batch normalization. 273 | projection_shortcut: The function to use for projection shortcuts 274 | (typically a 1x1 convolution when downsampling the input). 275 | strides: The block's stride. If greater than 1, this block will ultimately 276 | downsample the input. 277 | data_format: The input format ('channels_last' or 'channels_first'). 278 | 279 | Returns: 280 | The output tensor of the block; shape should match inputs. 281 | """ 282 | shortcut = inputs 283 | inputs = batch_norm(inputs, training, data_format) 284 | inputs = tf.nn.relu(inputs) 285 | 286 | # The projection shortcut should come after the first batch norm and ReLU 287 | # since it performs a 1x1 convolution. 288 | if projection_shortcut is not None: 289 | shortcut = projection_shortcut(inputs) 290 | 291 | inputs = conv2d_fixed_padding( 292 | inputs=inputs, filters=filters, kernel_size=1, strides=1, 293 | data_format=data_format) 294 | 295 | inputs = batch_norm(inputs, training, data_format) 296 | inputs = tf.nn.relu(inputs) 297 | inputs = conv2d_fixed_padding( 298 | inputs=inputs, filters=filters, kernel_size=3, strides=strides, 299 | data_format=data_format) 300 | 301 | inputs = batch_norm(inputs, training, data_format) 302 | inputs = tf.nn.relu(inputs) 303 | inputs = conv2d_fixed_padding( 304 | inputs=inputs, filters=4 * filters, kernel_size=1, strides=1, 305 | data_format=data_format) 306 | 307 | return inputs + shortcut 308 | 309 | 310 | def block_layer(inputs, filters, bottleneck, block_fn, blocks, strides, 311 | training, name, data_format): 312 | """Creates one layer of blocks for the ResNet model. 313 | 314 | Args: 315 | inputs: A tensor of size [batch, channels, height_in, width_in] or 316 | [batch, height_in, width_in, channels] depending on data_format. 317 | filters: The number of filters for the first convolution of the layer. 318 | bottleneck: Is the block created a bottleneck block. 319 | block_fn: The block to use within the model, either `building_block` or 320 | `bottleneck_block`. 321 | blocks: The number of blocks contained in the layer. 322 | strides: The stride to use for the first convolution of the layer. If 323 | greater than 1, this layer will ultimately downsample the input. 324 | training: Either True or False, whether we are currently training the 325 | model. Needed for batch norm. 326 | name: A string name for the tensor output of the block layer. 327 | data_format: The input format ('channels_last' or 'channels_first'). 328 | 329 | Returns: 330 | The output tensor of the block layer. 331 | """ 332 | 333 | # Bottleneck blocks end with 4x the number of filters as they start with 334 | filters_out = filters * 4 if bottleneck else filters 335 | 336 | def projection_shortcut(inputs): 337 | return conv2d_fixed_padding( 338 | inputs=inputs, filters=filters_out, kernel_size=1, strides=strides, 339 | data_format=data_format) 340 | 341 | # Only the first block per block_layer uses projection_shortcut and strides 342 | inputs = block_fn(inputs, filters, training, projection_shortcut, strides, 343 | data_format) 344 | 345 | for _ in range(1, blocks): 346 | inputs = block_fn(inputs, filters, training, None, 1, data_format) 347 | 348 | return tf.identity(inputs, name) 349 | 350 | def orthogonal_regularizer(scale, data_format) : 351 | """ Defining the Orthogonal regularizer and return the function at last to be used in Conv layer as kernel regularizer""" 352 | 353 | def ortho_reg(w) : 354 | """ Reshaping the matrxi in to 2D tensor for enforcing orthogonality""" 355 | # if data_format == 'channels_first': 356 | # _, c, _, _ = w.get_shape().as_list() 357 | # else: 358 | # _, _, _, c = w.get_shape().as_list() 359 | if data_format == 'channels_first': 360 | w = tf.transpose(w, [0, 2, 3, 1]) 361 | _, _, _, c = w.get_shape().as_list() 362 | w = tf.reshape(w, [-1, c]) 363 | 364 | """ Declaring a Identity Tensor of appropriate size""" 365 | identity = tf.eye(c) 366 | 367 | """ Regularizer Wt*W - I """ 368 | w_transpose = tf.transpose(w) 369 | w_mul = tf.matmul(w_transpose, w) 370 | reg = tf.subtract(w_mul, identity) 371 | 372 | """Calculating the Loss Obtained""" 373 | ortho_loss = tf.nn.l2_loss(reg) 374 | 375 | return scale * ortho_loss 376 | 377 | return ortho_reg 378 | 379 | def orthogonal_regularizer_fully(scale) : 380 | def ortho_reg_fully(w) : 381 | _, c = w.get_shape().as_list() 382 | identity = tf.eye(c) 383 | w_transpose = tf.transpose(w) 384 | w_mul = tf.matmul(w_transpose, w) 385 | reg = tf.subtract(w_mul, identity) 386 | 387 | ortho_loss = tf.nn.l2_loss(reg) 388 | 389 | return scale * ortho_loss 390 | return ortho_reg_fully 391 | 392 | def spectral_norm(w, iteration=1): 393 | w_shape = w.shape.as_list() 394 | w = tf.reshape(w, [-1, w_shape[-1]]) 395 | 396 | u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.random_normal_initializer(), trainable=False) 397 | 398 | u_hat = u 399 | v_hat = None 400 | for i in range(iteration): 401 | """ 402 | power iteration 403 | Usually iteration = 1 will be enough 404 | """ 405 | 406 | v_ = tf.matmul(u_hat, tf.transpose(w)) 407 | v_hat = tf.nn.l2_normalize(v_) 408 | 409 | u_ = tf.matmul(v_hat, w) 410 | u_hat = tf.nn.l2_normalize(u_) 411 | 412 | u_hat = tf.stop_gradient(u_hat) 413 | v_hat = tf.stop_gradient(v_hat) 414 | 415 | sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat)) 416 | 417 | with tf.control_dependencies([u.assign(u_hat)]): 418 | w_norm = w / sigma 419 | w_norm = tf.reshape(w_norm, w_shape) 420 | 421 | return w_norm 422 | 423 | def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, 424 | sn=False, data_format=None, scope='conv_0'): 425 | with tf.variable_scope(scope): 426 | if pad > 0: 427 | h = x.get_shape().as_list()[2] 428 | if h % stride == 0: 429 | pad = pad * 2 430 | else: 431 | pad = max(kernel - (h % stride), 0) 432 | 433 | pad_top = pad // 2 434 | pad_bottom = pad - pad_top 435 | pad_left = pad // 2 436 | pad_right = pad - pad_left 437 | 438 | if data_format == 'channels_first': 439 | if pad_type == 'zero' : 440 | x = tf.pad(x, [[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]]) 441 | if pad_type == 'reflect' : 442 | x = tf.pad(x, [[0, 0], [0, 0], [pad_top, pad_bottom], [pad_left, pad_right]], 443 | mode='REFLECT') 444 | else: 445 | if pad_type == 'zero' : 446 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]) 447 | if pad_type == 'reflect' : 448 | x = tf.pad(x, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], 449 | mode='REFLECT') 450 | 451 | weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02) 452 | weight_regularizer = orthogonal_regularizer(0.0001, data_format) 453 | if sn : 454 | if data_format == 'channels_first': 455 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[1], channels], 456 | initializer=weight_init, regularizer=weight_regularizer) 457 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), 458 | strides=[1, 1, stride, stride], padding='VALID', data_format='NCHW') 459 | else: 460 | w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels], 461 | initializer=weight_init, regularizer=weight_regularizer) 462 | x = tf.nn.conv2d(input=x, filter=spectral_norm(w), 463 | strides=[1, stride, stride, 1], padding='VALID', data_format='NHWC') 464 | if use_bias : 465 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 466 | if data_format == 'channels_first': 467 | x = tf.nn.bias_add(x, bias, data_format='NCHW') 468 | else: 469 | x = tf.nn.bias_add(x, bias, data_format='NHWC') 470 | else: 471 | x = tf.layers.conv2d(inputs=x, filters=channels, 472 | kernel_size=kernel, 473 | kernel_initializer=weight_init, 474 | kernel_regularizer=weight_regularizer, 475 | strides=stride, use_bias=use_bias, 476 | data_format=data_format) 477 | return x 478 | 479 | 480 | def deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, 481 | sn=False, data_format=None, scope='deconv_0'): 482 | with tf.variable_scope(scope): 483 | x_shape = x.get_shape().as_list() 484 | if data_format == 'channels_first': 485 | if padding == 'SAME': 486 | output_shape = [tf.shape(x)[0], channels, x_shape[2] * stride, x_shape[3] * stride] 487 | else: 488 | output_shape =[tf.shape(x)[0], channels, x_shape[2] * stride + max(kernel - stride, 0), 489 | x_shape[3] * stride + max(kernel - stride, 0)] 490 | else: 491 | if padding == 'SAME': 492 | output_shape = [tf.shape(x)[0], x_shape[1] * stride, x_shape[2] * stride, channels] 493 | else: 494 | output_shape =[tf.shape(x)[0], x_shape[1] * stride + max(kernel - stride, 0), 495 | x_shape[2] * stride + max(kernel - stride, 0), channels] 496 | # print('output_shape', output_shape) 497 | weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02) 498 | weight_regularizer = orthogonal_regularizer(0.0001, data_format) 499 | if sn: 500 | if data_format == 'channels_first': 501 | w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape().as_list()[1]], 502 | initializer=weight_init, regularizer=weight_regularizer) 503 | x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape, 504 | strides=[1, 1, stride, stride], padding=padding, data_format='NCHW') 505 | else: 506 | w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape().as_list()[-1]], 507 | initializer=weight_init, regularizer=weight_regularizer) 508 | x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape, 509 | strides=[1, stride, stride, 1], padding=padding, data_format='NHWC') 510 | if use_bias : 511 | bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0)) 512 | if data_format == 'channels_first': 513 | x = tf.nn.bias_add(x, bias, data_format='NCHW') 514 | else: 515 | x = tf.nn.bias_add(x, bias, data_format='NHWC') 516 | else: 517 | x = tf.layers.conv2d_transpose(inputs=x, filters=channels, 518 | kernel_size=kernel, 519 | kernel_initializer=weight_init, 520 | kernel_regularizer=weight_regularizer, 521 | strides=stride, padding=padding, use_bias=use_bias, 522 | data_format=data_format) 523 | return x 524 | 525 | def fully_conneted(x, units, use_bias=True, sn=False, scope='fully_0'): 526 | with tf.variable_scope(scope): 527 | shape = x.get_shape().as_list() 528 | channels = shape[-1] 529 | 530 | weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02) 531 | weight_regularizer_fully = orthogonal_regularizer_fully(0.0001) 532 | if sn : 533 | w = tf.get_variable("kernel", [channels, units], tf.float32, 534 | initializer=weight_init, regularizer=weight_regularizer_fully) 535 | if use_bias : 536 | bias = tf.get_variable("bias", [units], initializer=tf.constant_initializer(0.0)) 537 | x = tf.matmul(x, spectral_norm(w)) + bias 538 | else: 539 | x = tf.matmul(x, spectral_norm(w)) 540 | else: 541 | x = tf.layers.dense(x, units=units, 542 | kernel_initializer=weight_init, 543 | kernel_regularizer=weight_regularizer_fully, 544 | use_bias=use_bias) 545 | return x 546 | 547 | def resblock_up(x_init, channels, use_bias=True, is_training=True, 548 | sn=False, data_format=None, scope='resblock_up'): 549 | with tf.variable_scope(scope): 550 | with tf.variable_scope('res1'): 551 | x = batch_norm(x_init, is_training, data_format) 552 | x = tf.nn.relu(x) 553 | x = deconv(x, channels, kernel=3, stride=2, 554 | use_bias=use_bias, sn=sn, data_format=data_format) 555 | with tf.variable_scope('res2') : 556 | x = batch_norm(x, is_training, data_format) 557 | x = tf.nn.relu(x) 558 | x = deconv(x, channels, kernel=3, stride=1, 559 | use_bias=use_bias, sn=sn, data_format=data_format) 560 | with tf.variable_scope('skip') : 561 | x_init = deconv(x_init, channels, kernel=3, stride=2, 562 | use_bias=use_bias, sn=sn, data_format=data_format) 563 | return x + x_init 564 | 565 | def hw_flatten(x) : 566 | x_shape = x.get_shape().as_list() 567 | return tf.reshape(x, shape=[-1, x_shape[1]*x_shape[2], x_shape[3]]) 568 | 569 | def self_attention_2(x, channels, sn=False, data_format=None, scope='self_attention'): 570 | with tf.variable_scope(scope): 571 | # print('atten_in', x.get_shape()) 572 | weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02) 573 | weight_regularizer = orthogonal_regularizer(0.0001, data_format) 574 | with tf.variable_scope('f_conv'): 575 | f = conv(x, channels // 8, kernel=1, stride=1, sn=sn, data_format=data_format) 576 | f = tf.layers.max_pooling2d(f, pool_size=8, strides=8, padding='SAME', 577 | data_format=data_format) 578 | if data_format == 'channels_first': 579 | f = tf.transpose(f, [0, 2, 3, 1]) 580 | # print('f', f.get_shape(), channels // 8) 581 | with tf.variable_scope('g_conv'): 582 | g = conv(x, channels // 8, kernel=1, stride=1, sn=sn, data_format=data_format) 583 | if data_format == 'channels_first': 584 | g = tf.transpose(g, [0, 2, 3, 1]) 585 | # print('g', g.get_shape(), channels // 8) 586 | with tf.variable_scope('h_conv'): 587 | h = conv(x, channels // 4, kernel=1, stride=1, sn=sn, data_format=data_format) 588 | # h = tf.layers.max_pooling2d(h, pool_size=6, strides=6, padding='SAME', 589 | # data_format=data_format) 590 | h = tf.layers.max_pooling2d(h, pool_size=8, strides=8, padding='SAME', 591 | data_format=data_format) 592 | if data_format == 'channels_first': 593 | h = tf.transpose(h, [0, 2, 3, 1]) 594 | # print('h', h.get_shape(), channels // 4) 595 | 596 | s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) 597 | beta = tf.nn.softmax(s) # attention map 598 | o = tf.matmul(beta, hw_flatten(h)) 599 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0)) 600 | if data_format == 'channels_first': 601 | o = tf.transpose(o, [0, 2, 1]) 602 | 603 | x_shape = x.get_shape().as_list() 604 | if data_format == 'channels_first': 605 | o = tf.reshape(o, shape=[-1, channels//4, x_shape[2], x_shape[3]]) 606 | else: 607 | o = tf.reshape(o, shape=[-1, x_shape[1], x_shape[2], channels//4]) 608 | 609 | o = conv(o, channels, kernel=1, stride=1, sn=sn, data_format=data_format, scope='attn_conv') 610 | x = gamma * o + x 611 | return x 612 | 613 | def self_attention_full(x, channels, sn=False, data_format=None, scope='self_attention'): 614 | with tf.variable_scope(scope): 615 | # print('atten_in', x.get_shape()) 616 | weight_init = tf.truncated_normal_initializer(mean=0.0, stddev=0.02) 617 | weight_regularizer = orthogonal_regularizer(0.0001, data_format) 618 | with tf.variable_scope('f_conv'): 619 | f = conv(x, channels, kernel=1, stride=1, sn=sn, data_format=data_format) 620 | f = tf.layers.max_pooling2d(f, pool_size=4, strides=4, padding='SAME', 621 | data_format=data_format) 622 | if data_format == 'channels_first': 623 | f = tf.transpose(f, [0, 2, 3, 1]) 624 | with tf.variable_scope('g_conv'): 625 | g = conv(x, channels, kernel=1, stride=1, sn=sn, data_format=data_format) 626 | if data_format == 'channels_first': 627 | g = tf.transpose(g, [0, 2, 3, 1]) 628 | with tf.variable_scope('h_conv'): 629 | h = conv(x, channels, kernel=1, stride=1, sn=sn, data_format=data_format) 630 | h = tf.layers.max_pooling2d(h, pool_size=4, strides=4, padding='SAME', 631 | data_format=data_format) 632 | if data_format == 'channels_first': 633 | h = tf.transpose(h, [0, 2, 3, 1]) 634 | 635 | s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True) 636 | beta = tf.nn.softmax(s) # attention map 637 | o = tf.matmul(beta, hw_flatten(h)) 638 | gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0)) 639 | if data_format == 'channels_first': 640 | o = tf.transpose(o, [0, 2, 1]) 641 | 642 | x_shape = x.get_shape().as_list() 643 | if data_format == 'channels_first': 644 | o = tf.reshape(o, shape=[-1, channels, x_shape[2], x_shape[3]]) 645 | else: 646 | o = tf.reshape(o, shape=[-1, x_shape[1], x_shape[2], channels]) 647 | 648 | o = conv(o, channels, kernel=1, stride=1, sn=sn, data_format=data_format, scope='attn_conv') 649 | x = gamma * o + x 650 | return x 651 | 652 | class Model(object): 653 | """Base class for building the Resnet Model.""" 654 | 655 | def __init__(self, resnet_size, bottleneck, num_classes, num_filters, 656 | kernel_size, 657 | conv_stride, first_pool_size, first_pool_stride, 658 | block_sizes, block_strides, 659 | resnet_version=DEFAULT_VERSION, data_format=None, 660 | dtype=DEFAULT_DTYPE, spectral_norm=False, 661 | reuse=False, offload=False, compress_ratio=0.05): 662 | """Creates a model for classifying an image. 663 | 664 | Args: 665 | resnet_size: A single integer for the size of the ResNet model. 666 | bottleneck: Use regular blocks or bottleneck blocks. 667 | num_classes: The number of classes used as labels. 668 | num_filters: The number of filters to use for the first block layer 669 | of the model. This number is then doubled for each subsequent block 670 | layer. 671 | kernel_size: The kernel size to use for convolution. 672 | conv_stride: stride size for the initial convolutional layer 673 | first_pool_size: Pool size to be used for the first pooling layer. 674 | If none, the first pooling layer is skipped. 675 | first_pool_stride: stride size for the first pooling layer. Not used 676 | if first_pool_size is None. 677 | block_sizes: A list containing n values, where n is the number of sets of 678 | block layers desired. Each value should be the number of blocks in the 679 | i-th set. 680 | block_strides: List of integers representing the desired stride size for 681 | each of the sets of block layers. Should be same length as block_sizes. 682 | resnet_version: Integer representing which version of the ResNet network 683 | to use. See README for details. Valid values: [1, 2] 684 | data_format: Input format ('channels_last', 'channels_first', or None). 685 | If set to None, the format is dependent on whether a GPU is available. 686 | dtype: The TensorFlow dtype to use for calculations. If not specified 687 | tf.float32 is used. 688 | 689 | Raises: 690 | ValueError: if invalid version is selected. 691 | """ 692 | self.resnet_size = resnet_size 693 | 694 | if not data_format: 695 | data_format = ( 696 | 'channels_first' if tf.test.is_built_with_cuda() else 'channels_last') 697 | 698 | self.resnet_version = resnet_version 699 | if resnet_version not in (1, 2): 700 | raise ValueError( 701 | 'Resnet version should be 1 or 2. See README for citations.') 702 | 703 | self.bottleneck = bottleneck 704 | if bottleneck: 705 | if resnet_version == 1: 706 | self.block_fn = _bottleneck_block_v1 707 | else: 708 | self.block_fn = _bottleneck_block_v2 709 | else: 710 | if resnet_version == 1: 711 | self.block_fn = _building_block_v1 712 | else: 713 | self.block_fn = _building_block_v2 714 | 715 | if dtype not in ALLOWED_TYPES: 716 | raise ValueError('dtype must be one of: {}'.format(ALLOWED_TYPES)) 717 | 718 | self.data_format = data_format 719 | self.num_classes = num_classes 720 | self.num_filters = num_filters 721 | self.kernel_size = kernel_size 722 | self.conv_stride = conv_stride 723 | self.first_pool_size = first_pool_size 724 | self.first_pool_stride = first_pool_stride 725 | self.block_sizes = block_sizes 726 | self.block_strides = block_strides 727 | self.dtype = dtype 728 | self.pre_activation = resnet_version == 2 729 | self.sn = spectral_norm 730 | self.reuse = reuse 731 | self.offload = offload 732 | self.compress_ratio = compress_ratio 733 | 734 | def _custom_dtype_getter(self, getter, name, shape=None, dtype=DEFAULT_DTYPE, 735 | *args, **kwargs): 736 | """Creates variables in fp32, then casts to fp16 if necessary. 737 | 738 | This function is a custom getter. A custom getter is a function with the 739 | same signature as tf.get_variable, except it has an additional getter 740 | parameter. Custom getters can be passed as the `custom_getter` parameter of 741 | tf.variable_scope. Then, tf.get_variable will call the custom getter, 742 | instead of directly getting a variable itself. This can be used to change 743 | the types of variables that are retrieved with tf.get_variable. 744 | The `getter` parameter is the underlying variable getter, that would have 745 | been called if no custom getter was used. Custom getters typically get a 746 | variable with `getter`, then modify it in some way. 747 | 748 | This custom getter will create an fp32 variable. If a low precision 749 | (e.g. float16) variable was requested it will then cast the variable to the 750 | requested dtype. The reason we do not directly create variables in low 751 | precision dtypes is that applying small gradients to such variables may 752 | cause the variable not to change. 753 | 754 | Args: 755 | getter: The underlying variable getter, that has the same signature as 756 | tf.get_variable and returns a variable. 757 | name: The name of the variable to get. 758 | shape: The shape of the variable to get. 759 | dtype: The dtype of the variable to get. Note that if this is a low 760 | precision dtype, the variable will be created as a tf.float32 variable, 761 | then cast to the appropriate dtype 762 | *args: Additional arguments to pass unmodified to getter. 763 | **kwargs: Additional keyword arguments to pass unmodified to getter. 764 | 765 | Returns: 766 | A variable which is cast to fp16 if necessary. 767 | """ 768 | 769 | if dtype in CASTABLE_TYPES: 770 | var = getter(name, shape, tf.float32, *args, **kwargs) 771 | return tf.cast(var, dtype=dtype, name=name + '_cast') 772 | else: 773 | return getter(name, shape, dtype, *args, **kwargs) 774 | 775 | def _model_variable_scope(self): 776 | """Returns a variable scope that the model should be created under. 777 | 778 | If self.dtype is a castable type, model variable will be created in fp32 779 | then cast to self.dtype before being used. 780 | 781 | Returns: 782 | A variable scope for the model. 783 | """ 784 | 785 | return tf.variable_scope('resnet_model', reuse=self.reuse, 786 | custom_getter=self._custom_dtype_getter) 787 | 788 | def __call__(self, inputs, training): 789 | """Add operations to classify a batch of input images. 790 | 791 | Args: 792 | inputs: A Tensor representing a batch of input images. 793 | training: A boolean. Set to True to add operations required only when 794 | training the classifier. 795 | 796 | Returns: 797 | A logits Tensor with shape [, self.num_classes]. 798 | """ 799 | 800 | def endecoder(inter_rep): 801 | with tf.variable_scope('endecoder') as scope: 802 | axes = [2, 3] if self.data_format == 'channels_first' else [1, 2] 803 | 804 | out_size = max(int(3*self.compress_ratio*4*4), 1) 805 | print('out_size', out_size) 806 | 807 | c_sample = conv(inter_rep, out_size, kernel=4, 808 | stride=4, sn=self.sn, use_bias=False, 809 | data_format=self.data_format, scope='samp_conv') 810 | 811 | num_centers = 8 812 | quant_centers = tf.get_variable( 813 | 'quant_centers', shape=(num_centers,), dtype=tf.float32, 814 | initializer=tf.random_uniform_initializer(minval=-16., 815 | maxval=16)) 816 | 817 | print('quant_centers', quant_centers) 818 | print('c_sample', c_sample) 819 | quant_dist = tf.square(tf.abs(tf.expand_dims(c_sample, axis=-1) - quant_centers)) 820 | phi_soft = tf.nn.softmax(-1. * quant_dist, dim=-1) 821 | symbols_hard = tf.argmax(phi_soft, axis=-1) 822 | phi_hard = tf.one_hot(symbols_hard, depth=num_centers, axis=-1, dtype=tf.float32) 823 | softout = tf.reduce_sum(phi_soft * quant_centers, -1) 824 | hardout = tf.reduce_sum(phi_hard * quant_centers, -1) 825 | 826 | c_sample_q = softout + tf.stop_gradient(hardout - softout) 827 | 828 | print('phi_soft', phi_soft) 829 | print('phi_hard', phi_hard) 830 | print('quant_dist', quant_dist) 831 | print('softout', softout) 832 | print('hardout', hardout) 833 | print('c_sample_q', c_sample_q) 834 | 835 | c_recon = self_attention_full(c_sample_q, channels=out_size, sn=self.sn, 836 | data_format=self.data_format, scope='self_attention1') 837 | c_recon = resblock_up(c_recon, channels=64, use_bias=False, 838 | is_training=training, sn=self.sn, 839 | data_format=self.data_format, scope='resblock_up_x2') 840 | c_recon = self_attention_2(c_recon, channels=64, sn=self.sn, 841 | data_format=self.data_format, scope='self_attention2') 842 | c_recon = resblock_up(c_recon, channels=32, use_bias=False, 843 | is_training=training, sn=self.sn, 844 | data_format=self.data_format, scope='resblock_up_x4') 845 | 846 | c_recon = batch_norm(c_recon, training, self.data_format) 847 | c_recon = tf.nn.relu(c_recon) 848 | 849 | if self.data_format == 'channels_first': 850 | c_recon = tf.pad(c_recon, tf.constant([[0, 0], [0, 0], [1, 1], [1, 1]])) 851 | else: 852 | c_recon = tf.pad(c_recon, tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]])) 853 | c_recon = conv(c_recon, channels=3, kernel=3, stride=1, 854 | use_bias=False, sn=self.sn, data_format=self.data_format, scope='G_logit') 855 | 856 | 857 | c_recon = tf.nn.tanh(c_recon) 858 | _R_MEAN = 123.68 859 | _G_MEAN = 116.78 860 | _B_MEAN = 103.94 861 | _CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN] 862 | if self.data_format == 'channels_first': 863 | ch_means = tf.expand_dims(tf.expand_dims(tf.expand_dims(_CHANNEL_MEANS, 0), 2), 3) 864 | else: 865 | ch_means = tf.expand_dims(tf.expand_dims(tf.expand_dims(_CHANNEL_MEANS, 0), 0), 0) 866 | 867 | return c_sample, (c_recon+1.0)*127.5-ch_means 868 | 869 | @tf.custom_gradient 870 | def grad1pass(x): 871 | def grad(dy): 872 | d_norm = tf.sqrt(tf.reduce_sum(dy*dy)) 873 | return dy*1.0/tf.maximum(1.0, d_norm) 874 | return x, grad 875 | 876 | with self._model_variable_scope(): 877 | if self.data_format == 'channels_first': 878 | # Convert the inputs from channels_last (NHWC) to channels_first (NCHW). 879 | # This provides a large performance boost on GPU. See 880 | # https://www.tensorflow.org/performance/performance_guide#data_formats 881 | inputs = tf.transpose(inputs, [0, 3, 1, 2]) 882 | 883 | if self.offload: 884 | c_sample, c_recon = endecoder(inputs) 885 | print('c_sample', c_sample.get_shape()) 886 | print('c_recon', c_recon.get_shape()) 887 | 888 | inputs = grad1pass(c_recon) 889 | # inputs = c_recon 890 | 891 | inter_feature = [] 892 | 893 | inputs = conv2d_fixed_padding( 894 | inputs=inputs, filters=self.num_filters, kernel_size=self.kernel_size, 895 | strides=self.conv_stride, data_format=self.data_format) 896 | 897 | # inter_feature.append(inputs) 898 | 899 | inputs = tf.identity(inputs, 'initial_conv') 900 | 901 | # We do not include batch normalization or activation functions in V2 902 | # for the initial conv1 because the first ResNet unit will perform these 903 | # for both the shortcut and non-shortcut paths as part of the first 904 | # block's projection. Cf. Appendix of [2]. 905 | if self.resnet_version == 1: 906 | inputs = batch_norm(inputs, training, self.data_format) 907 | inputs = tf.nn.relu(inputs) 908 | 909 | if self.first_pool_size: 910 | inputs = tf.layers.max_pooling2d( 911 | inputs=inputs, pool_size=self.first_pool_size, 912 | strides=self.first_pool_stride, padding='SAME', 913 | data_format=self.data_format) 914 | inputs = tf.identity(inputs, 'initial_max_pool') 915 | 916 | for i, num_blocks in enumerate(self.block_sizes): 917 | num_filters = self.num_filters * (2**i) 918 | inputs = block_layer( 919 | inputs=inputs, filters=num_filters, bottleneck=self.bottleneck, 920 | block_fn=self.block_fn, blocks=num_blocks, 921 | strides=self.block_strides[i], training=training, 922 | name='block_layer{}'.format(i + 1), data_format=self.data_format) 923 | if i == 1: 924 | inter_feature.append(inputs) 925 | 926 | # Only apply the BN and ReLU for model that does pre_activation in each 927 | # building/bottleneck block, eg resnet V2. 928 | if self.pre_activation: 929 | inputs = batch_norm(inputs, training, self.data_format) 930 | inputs = tf.nn.relu(inputs) 931 | 932 | # The current top layer has shape 933 | # `batch_size x pool_size x pool_size x final_size`. 934 | # ResNet does an Average Pooling layer over pool_size, 935 | # but that is the same as doing a reduce_mean. We do a reduce_mean 936 | # here because it performs better than AveragePooling2D. 937 | axes = [2, 3] if self.data_format == 'channels_first' else [1, 2] 938 | inputs = tf.reduce_mean(inputs, axes, keepdims=True) 939 | inputs = tf.identity(inputs, 'final_reduce_mean') 940 | 941 | inputs = tf.squeeze(inputs, axes) 942 | inputs = tf.layers.dense(inputs=inputs, units=self.num_classes) 943 | inputs = tf.identity(inputs, 'final_dense') 944 | if self.offload: 945 | if self.data_format == 'channels_first': 946 | c_recon = tf.transpose(c_recon, [0, 2, 3, 1]) 947 | return inputs, c_recon, inter_feature 948 | else: 949 | return inputs, inter_feature 950 | -------------------------------------------------------------------------------- /resnet_run_loop.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains utility and supporting functions for ResNet. 16 | 17 | This module contains ResNet code which does not directly build layers. This 18 | includes dataset management, hyperparameter and optimizer code, and argument 19 | parsing. Code for defining the ResNet layers can be found in resnet_model.py. 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import functools 27 | import math 28 | import multiprocessing 29 | import os 30 | import string 31 | 32 | # pylint: disable=g-bad-import-order 33 | from absl import flags 34 | import tensorflow as tf 35 | from tensorflow.contrib.data.python.ops import threadpool 36 | 37 | import resnet_model 38 | import imagenet_preprocessing 39 | from utils.flags import core as flags_core 40 | from utils.export import export 41 | from utils.logs import hooks_helper 42 | from utils.logs import logger 43 | from utils.misc import distribution_utils 44 | from utils.misc import model_helpers 45 | 46 | 47 | ################################################################################ 48 | # Functions for input processing. 49 | ################################################################################ 50 | def process_record_dataset(dataset, 51 | is_training, 52 | batch_size, 53 | shuffle_buffer, 54 | parse_record_fn, 55 | num_epochs=1, 56 | dtype=tf.float32, 57 | datasets_num_private_threads=None, 58 | num_parallel_batches=1): 59 | """Given a Dataset with raw records, return an iterator over the records. 60 | 61 | Args: 62 | dataset: A Dataset representing raw records 63 | is_training: A boolean denoting whether the input is for training. 64 | batch_size: The number of samples per batch. 65 | shuffle_buffer: The buffer size to use when shuffling records. A larger 66 | value results in better randomness, but smaller values reduce startup 67 | time and use less memory. 68 | parse_record_fn: A function that takes a raw record and returns the 69 | corresponding (image, label) pair. 70 | num_epochs: The number of epochs to repeat the dataset. 71 | dtype: Data type to use for images/features. 72 | datasets_num_private_threads: Number of threads for a private 73 | threadpool created for all datasets computation. 74 | num_parallel_batches: Number of parallel batches for tf.data. 75 | 76 | Returns: 77 | Dataset of (image, label) pairs ready for iteration. 78 | """ 79 | 80 | # Prefetches a batch at a time to smooth out the time taken to load input 81 | # files for shuffling and processing. 82 | dataset = dataset.prefetch(buffer_size=batch_size) 83 | if is_training: 84 | # Shuffles records before repeating to respect epoch boundaries. 85 | dataset = dataset.shuffle(buffer_size=shuffle_buffer) 86 | 87 | # Repeats the dataset for the number of epochs to train. 88 | dataset = dataset.repeat(num_epochs) 89 | 90 | # Parses the raw records into images and labels. 91 | dataset = dataset.apply( 92 | tf.contrib.data.map_and_batch( 93 | lambda value: parse_record_fn(value, is_training, dtype), 94 | batch_size=batch_size, 95 | num_parallel_batches=num_parallel_batches, 96 | drop_remainder=False)) 97 | 98 | # Operations between the final prefetch and the get_next call to the iterator 99 | # will happen synchronously during run time. We prefetch here again to 100 | # background all of the above processing work and keep it out of the 101 | # critical training path. Setting buffer_size to tf.contrib.data.AUTOTUNE 102 | # allows DistributionStrategies to adjust how many batches to fetch based 103 | # on how many devices are present. 104 | dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE) 105 | 106 | # Defines a specific size thread pool for tf.data operations. 107 | if datasets_num_private_threads: 108 | tf.logging.info('datasets_num_private_threads: %s', 109 | datasets_num_private_threads) 110 | dataset = threadpool.override_threadpool( 111 | dataset, 112 | threadpool.PrivateThreadPool( 113 | datasets_num_private_threads, 114 | display_name='input_pipeline_thread_pool')) 115 | 116 | return dataset 117 | 118 | 119 | def get_synth_input_fn(height, width, num_channels, num_classes, 120 | dtype=tf.float32): 121 | """Returns an input function that returns a dataset with random data. 122 | 123 | This input_fn returns a data set that iterates over a set of random data and 124 | bypasses all preprocessing, e.g. jpeg decode and copy. The host to device 125 | copy is still included. This used to find the upper throughput bound when 126 | tunning the full input pipeline. 127 | 128 | Args: 129 | height: Integer height that will be used to create a fake image tensor. 130 | width: Integer width that will be used to create a fake image tensor. 131 | num_channels: Integer depth that will be used to create a fake image tensor. 132 | num_classes: Number of classes that should be represented in the fake labels 133 | tensor 134 | dtype: Data type for features/images. 135 | 136 | Returns: 137 | An input_fn that can be used in place of a real one to return a dataset 138 | that can be used for iteration. 139 | """ 140 | # pylint: disable=unused-argument 141 | def input_fn(is_training, data_dir, batch_size, *args, **kwargs): 142 | """Returns dataset filled with random data.""" 143 | # Synthetic input should be within [0, 255]. 144 | inputs = tf.truncated_normal( 145 | [batch_size] + [height, width, num_channels], 146 | dtype=dtype, 147 | mean=127, 148 | stddev=60, 149 | name='synthetic_inputs') 150 | 151 | labels = tf.random_uniform( 152 | [batch_size], 153 | minval=0, 154 | maxval=num_classes - 1, 155 | dtype=tf.int32, 156 | name='synthetic_labels') 157 | data = tf.data.Dataset.from_tensors((inputs, labels)).repeat() 158 | data = data.prefetch(buffer_size=tf.contrib.data.AUTOTUNE) 159 | return data 160 | 161 | return input_fn 162 | 163 | 164 | def image_bytes_serving_input_fn(image_shape, dtype=tf.float32): 165 | """Serving input fn for raw jpeg images.""" 166 | 167 | def _preprocess_image(image_bytes): 168 | """Preprocess a single raw image.""" 169 | # Bounding box around the whole image. 170 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=dtype, shape=[1, 1, 4]) 171 | height, width, num_channels = image_shape 172 | image = imagenet_preprocessing.preprocess_image( 173 | image_bytes, bbox, height, width, num_channels, is_training=False) 174 | return image 175 | 176 | image_bytes_list = tf.placeholder( 177 | shape=[None], dtype=tf.string, name='input_tensor') 178 | images = tf.map_fn( 179 | _preprocess_image, image_bytes_list, back_prop=False, dtype=dtype) 180 | return tf.estimator.export.TensorServingInputReceiver( 181 | images, {'image_bytes': image_bytes_list}) 182 | 183 | 184 | def override_flags_and_set_envars_for_gpu_thread_pool(flags_obj): 185 | """Override flags and set env_vars for performance. 186 | 187 | These settings exist to test the difference between using stock settings 188 | and manual tuning. It also shows some of the ENV_VARS that can be tweaked to 189 | squeeze a few extra examples per second. These settings are defaulted to the 190 | current platform of interest, which changes over time. 191 | 192 | On systems with small numbers of cpu cores, e.g. under 8 logical cores, 193 | setting up a gpu thread pool with `tf_gpu_thread_mode=gpu_private` may perform 194 | poorly. 195 | 196 | Args: 197 | flags_obj: Current flags, which will be adjusted possibly overriding 198 | what has been set by the user on the command-line. 199 | """ 200 | cpu_count = multiprocessing.cpu_count() 201 | tf.logging.info('Logical CPU cores: %s', cpu_count) 202 | 203 | # Sets up thread pool for each GPU for op scheduling. 204 | per_gpu_thread_count = 1 205 | total_gpu_thread_count = per_gpu_thread_count * flags_obj.num_gpus 206 | os.environ['TF_GPU_THREAD_MODE'] = flags_obj.tf_gpu_thread_mode 207 | os.environ['TF_GPU_THREAD_COUNT'] = str(per_gpu_thread_count) 208 | tf.logging.info('TF_GPU_THREAD_COUNT: %s', os.environ['TF_GPU_THREAD_COUNT']) 209 | tf.logging.info('TF_GPU_THREAD_MODE: %s', os.environ['TF_GPU_THREAD_MODE']) 210 | 211 | # Reduces general thread pool by number of threads used for GPU pool. 212 | main_thread_count = cpu_count - total_gpu_thread_count 213 | flags_obj.inter_op_parallelism_threads = main_thread_count 214 | 215 | # Sets thread count for tf.data. Logical cores minus threads assign to the 216 | # private GPU pool along with 2 thread per GPU for event monitoring and 217 | # sending / receiving tensors. 218 | num_monitoring_threads = 2 * flags_obj.num_gpus 219 | flags_obj.datasets_num_private_threads = (cpu_count - total_gpu_thread_count 220 | - num_monitoring_threads) 221 | 222 | 223 | ################################################################################ 224 | # Functions for running training/eval/validation loops for the model. 225 | ################################################################################ 226 | def learning_rate_with_decay( 227 | batch_size, batch_denom, num_images, boundary_epochs, decay_rates, 228 | base_lr=0.1, warmup=False): 229 | """Get a learning rate that decays step-wise as training progresses. 230 | 231 | Args: 232 | batch_size: the number of examples processed in each training batch. 233 | batch_denom: this value will be used to scale the base learning rate. 234 | `0.1 * batch size` is divided by this number, such that when 235 | batch_denom == batch_size, the initial learning rate will be 0.1. 236 | num_images: total number of images that will be used for training. 237 | boundary_epochs: list of ints representing the epochs at which we 238 | decay the learning rate. 239 | decay_rates: list of floats representing the decay rates to be used 240 | for scaling the learning rate. It should have one more element 241 | than `boundary_epochs`, and all elements should have the same type. 242 | base_lr: Initial learning rate scaled based on batch_denom. 243 | warmup: Run a 5 epoch warmup to the initial lr. 244 | Returns: 245 | Returns a function that takes a single argument - the number of batches 246 | trained so far (global_step)- and returns the learning rate to be used 247 | for training the next batch. 248 | """ 249 | initial_learning_rate = base_lr * batch_size / batch_denom 250 | batches_per_epoch = num_images / batch_size 251 | 252 | # Reduce the learning rate at certain epochs. 253 | # CIFAR-10: divide by 10 at epoch 100, 150, and 200 254 | # ImageNet: divide by 10 at epoch 30, 60, 80, and 90 255 | boundaries = [int(batches_per_epoch * epoch) for epoch in boundary_epochs] 256 | vals = [initial_learning_rate * decay for decay in decay_rates] 257 | 258 | def learning_rate_fn(global_step): 259 | """Builds scaled learning rate function with 5 epoch warm up.""" 260 | lr = tf.train.piecewise_constant(global_step, boundaries, vals) 261 | if warmup: 262 | warmup_steps = int(batches_per_epoch * 5) 263 | warmup_lr = ( 264 | initial_learning_rate * tf.cast(global_step, tf.float32) / tf.cast( 265 | warmup_steps, tf.float32)) 266 | return tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr) 267 | return lr 268 | 269 | return learning_rate_fn 270 | 271 | 272 | def resnet_model_fn(features, labels, mode, model_class, 273 | resnet_size, weight_decay, learning_rate_fn, momentum, 274 | data_format, resnet_version, loss_scale, 275 | loss_filter_fn=None, dtype=resnet_model.DEFAULT_DTYPE, 276 | fine_tune=False, reconst_loss_scale=1.0, use_ce=False, 277 | opt_chos='sgd', clip_grad=False, spectral_norm=False, 278 | ce_scale=1.0, sep_grad_nrom=False, norm_teach_feature=False, 279 | compress_ratio=0.05): 280 | """Shared functionality for different resnet model_fns. 281 | 282 | Initializes the ResnetModel representing the model layers 283 | and uses that model to build the necessary EstimatorSpecs for 284 | the `mode` in question. For training, this means building losses, 285 | the optimizer, and the train op that get passed into the EstimatorSpec. 286 | For evaluation and prediction, the EstimatorSpec is returned without 287 | a train op, but with the necessary parameters for the given mode. 288 | 289 | Args: 290 | features: tensor representing input images 291 | labels: tensor representing class labels for all input images 292 | mode: current estimator mode; should be one of 293 | `tf.estimator.ModeKeys.TRAIN`, `EVALUATE`, `PREDICT` 294 | model_class: a class representing a TensorFlow model that has a __call__ 295 | function. We assume here that this is a subclass of ResnetModel. 296 | resnet_size: A single integer for the size of the ResNet model. 297 | weight_decay: weight decay loss rate used to regularize learned variables. 298 | learning_rate_fn: function that returns the current learning rate given 299 | the current global_step 300 | momentum: momentum term used for optimization 301 | data_format: Input format ('channels_last', 'channels_first', or None). 302 | If set to None, the format is dependent on whether a GPU is available. 303 | resnet_version: Integer representing which version of the ResNet network to 304 | use. See README for details. Valid values: [1, 2] 305 | loss_scale: The factor to scale the loss for numerical stability. A detailed 306 | summary is present in the arg parser help text. 307 | loss_filter_fn: function that takes a string variable name and returns 308 | True if the var should be included in loss calculation, and False 309 | otherwise. If None, batch_normalization variables will be excluded 310 | from the loss. 311 | dtype: the TensorFlow dtype to use for calculations. 312 | fine_tune: If True only train the dense layers(final layers). 313 | 314 | Returns: 315 | EstimatorSpec parameterized according to the input params and the 316 | current mode. 317 | """ 318 | 319 | # Generate a summary node for the images 320 | tf.summary.image('images', features, max_outputs=6) 321 | # tf.summary.scalar('images0_max', tf.reduce_max(features[:,:,:,0])) 322 | # tf.summary.scalar('images0_min', tf.reduce_min(features[:,:,:,0])) 323 | # tf.summary.scalar('images1_max', tf.reduce_max(features[:,:,:,1])) 324 | # tf.summary.scalar('images1_min', tf.reduce_min(features[:,:,:,1])) 325 | # tf.summary.scalar('images2_max', tf.reduce_max(features[:,:,:,2])) 326 | # tf.summary.scalar('images2_min', tf.reduce_min(features[:,:,:,2])) 327 | # Checks that features/images have same data type being used for calculations. 328 | assert features.dtype == dtype 329 | 330 | model = model_class(resnet_size, data_format, spectral_norm=spectral_norm, 331 | resnet_version=resnet_version, dtype=dtype, 332 | reuse=False, offload=True, compress_ratio=compress_ratio) 333 | model_teach = model_class(resnet_size, data_format, spectral_norm=spectral_norm, 334 | resnet_version=resnet_version, dtype=dtype, 335 | reuse=True, offload=False) 336 | 337 | logits, reconst, interF = model(features, mode == tf.estimator.ModeKeys.TRAIN) 338 | logits_teach, interF_teach = model_teach(features, mode == tf.estimator.ModeKeys.TRAIN) 339 | 340 | tf.summary.image('reconstruction', reconst, max_outputs=6) 341 | 342 | # This acts as a no-op if the logits are already in fp32 (provided logits are 343 | # not a SparseTensor). If dtype is is low precision, logits must be cast to 344 | # fp32 for numerical stability. 345 | logits = tf.cast(logits, tf.float32) 346 | 347 | predictions = { 348 | 'classes': tf.argmax(logits, axis=1), 349 | 'probabilities': tf.nn.softmax(logits, name='softmax_tensor') 350 | } 351 | 352 | if mode == tf.estimator.ModeKeys.PREDICT: 353 | # Return the predictions and the specification for serving a SavedModel 354 | return tf.estimator.EstimatorSpec( 355 | mode=mode, 356 | predictions=predictions, 357 | export_outputs={ 358 | 'predict': tf.estimator.export.PredictOutput(predictions) 359 | }) 360 | 361 | # Calculate loss, which includes softmax cross entropy and L2 regularization. 362 | # cross_entropy = ce_scale * tf.losses.sparse_softmax_cross_entropy( 363 | # logits=logits, labels=labels) 364 | # pred_porb = tf.nn.softmax(logits) 365 | # teach_prob = tf.nn.softmax(logits_teach) 366 | # cross_entropy = ce_scale*tf.reduce_mean(-tf.reduce_sum(pred_porb*teach_prob, -1)) 367 | # cross_entropy = ce_scale*tf.reduce_mean(-tf.reduce_sum(teach_prob*tf.log(pred_porb), -1)) 368 | 369 | gamma_list = ['resnet_model/batch_normalization_21/gamma'] 370 | beta_list = ['resnet_model/batch_normalization_21/beta'] 371 | cross_entropy = 0. 372 | print('data_format: '+str(data_format)) 373 | cross_entropy = ce_scale*tf.reduce_mean(tf.square(logits - logits_teach)) 374 | if fine_tune: 375 | cross_entropy = tf.losses.sparse_softmax_cross_entropy( 376 | logits=logits, labels=labels) 377 | tf.logging.info('cross_entropy: '+str(cross_entropy)+str(cross_entropy.get_shape().as_list())) 378 | 379 | # reconst_loss = tf.nn.l2_loss(reconst - features) 380 | reconst_loss = reconst_loss_scale*tf.reduce_mean(tf.square((reconst - features)/127.5)) 381 | 382 | # Solve NaN problem 383 | # cross_entropy = tf.where(tf.is_nan(cross_entropy), tf.zeros_like(cross_entropy), cross_entropy) 384 | reconst_loss = tf.where(tf.is_nan(reconst_loss), tf.zeros_like(reconst_loss), reconst_loss) 385 | 386 | # Create a tensor named cross_entropy for logging purposes. 387 | tf.identity(cross_entropy, name='cross_entropy') 388 | tf.summary.scalar('cross_entropy', cross_entropy) 389 | # tf.summary.scalar('feature_loss', feature_loss) 390 | tf.summary.scalar('reconst_loss', reconst_loss) 391 | 392 | # If no loss_filter_fn is passed, assume we want the default behavior, 393 | # which is that batch_normalization variables are excluded from loss. 394 | def exclude_batch_norm(name): 395 | if fine_tune: 396 | return 'resnet_model/dense' in name 397 | else: 398 | return 'endecoder' in name 399 | # return 'batch_normalization' not in name 400 | loss_filter_fn = loss_filter_fn or exclude_batch_norm 401 | 402 | # Add weight decay to the loss. 403 | l2_loss = weight_decay * tf.add_n( 404 | # loss is computed using fp32 for numerical stability. 405 | [tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables() 406 | if loss_filter_fn(v.name)]) 407 | 408 | # Solve NaN problem 409 | l2_loss = tf.where(tf.is_nan(l2_loss), tf.zeros_like(l2_loss), l2_loss) 410 | 411 | tf.summary.scalar('l2_loss', l2_loss) 412 | 413 | # if use_ce and not sep_grad_nrom: 414 | if fine_tune: 415 | loss = l2_loss + cross_entropy 416 | else: 417 | if use_ce: 418 | loss = reconst_loss + cross_entropy 419 | else: 420 | loss = reconst_loss 421 | 422 | for v in tf.trainable_variables(): 423 | print(v.name, str(v.get_shape().as_list())) 424 | 425 | if mode == tf.estimator.ModeKeys.TRAIN: 426 | global_step = tf.train.get_or_create_global_step() 427 | 428 | # learning_rate = learning_rate_fn(global_step) 429 | if opt_chos == 'sgd': 430 | learning_rate = learning_rate_fn(global_step) 431 | elif opt_chos == 'fast_sgd_warmup': 432 | # if sep_grad_nrom: 433 | # lr = tf.train.piecewise_constant(global_step, 434 | # [int(14*1281167/32), int(20*1281167/32)], 435 | # [0.016, 0.0016, 0.00016]) 436 | # warmup_steps = int(1281167/32 * 8) 437 | # warmup_lr = ( 438 | # 0.016 * tf.cast(global_step, tf.float32) / tf.cast( 439 | # warmup_steps, tf.float32)) 440 | # else: 441 | lr = tf.train.piecewise_constant(global_step, 442 | [int(7*1281167/32), int(10*1281167/32)], 443 | [0.016, 0.0016, 0.00016]) 444 | warmup_steps = int(1281167/32 * 4) 445 | warmup_lr = ( 446 | 0.016 * tf.cast(global_step, tf.float32) / tf.cast( 447 | warmup_steps, tf.float32)) 448 | learning_rate = tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr) 449 | elif opt_chos == 'adam': 450 | learning_rate = 0.0001 451 | if use_ce: 452 | learning_rate = 0.0001 453 | # learning_rate = 0.00005 454 | # learning_rate = 0.000001 455 | # learning_rate = 0.00001 456 | # learning_rate = 0.000005 457 | 458 | 459 | # Create a tensor named learning_rate for logging purposes 460 | tf.identity(learning_rate, name='learning_rate') 461 | tf.summary.scalar('learning_rate', learning_rate) 462 | 463 | if opt_chos == 'sgd': 464 | optimizer = tf.train.MomentumOptimizer( 465 | learning_rate=learning_rate, 466 | momentum=momentum 467 | ) 468 | elif opt_chos == 'fast_sgd_warmup': 469 | optimizer = tf.train.MomentumOptimizer( 470 | learning_rate=learning_rate, 471 | momentum=momentum 472 | ) 473 | elif opt_chos == 'adam': 474 | if fine_tune: 475 | optimizer = tf.train.AdamOptimizer( 476 | learning_rate=learning_rate) 477 | else: 478 | optimizer = tf.train.AdamOptimizer( 479 | beta1=0.0, 480 | beta2=0.9, 481 | learning_rate=learning_rate) 482 | 483 | endecoder_var = [var for var in tf.trainable_variables() 484 | if 'endecoder' in var.name] 485 | 486 | def _dense_grad_filter(gvs): 487 | """Only apply gradient updates to the final layer. 488 | 489 | This function is used for fine tuning. 490 | 491 | Args: 492 | gvs: list of tuples with gradients and variable info 493 | Returns: 494 | filtered gradients so that only the dense layer remains 495 | """ 496 | return [(g, v) for g, v in gvs if 'resnet_model/dense' in v.name] 497 | 498 | if loss_scale != 1: 499 | 500 | if fine_tune: 501 | scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale) 502 | else: 503 | scaled_grad_vars = optimizer.compute_gradients(loss * loss_scale, 504 | var_list=endecoder_var) 505 | 506 | if fine_tune: 507 | scaled_grad_vars = _dense_grad_filter(scaled_grad_vars) 508 | 509 | # Once the gradient computation is complete we can scale the gradients 510 | # back to the correct scale before passing them to the optimizer. 511 | unscaled_grad_vars = [(grad / loss_scale, var) 512 | for grad, var in scaled_grad_vars] 513 | if clip_grad: 514 | capped_unscaled_grad_vars = [] 515 | for grad, var in unscaled_grad_vars: 516 | capped_unscaled_grad_vars.append((tf.clip_by_value(grad, -1., 1.), var)) 517 | minimize_op = optimizer.apply_gradients(capped_unscaled_grad_vars, global_step) 518 | else: 519 | minimize_op = optimizer.apply_gradients(unscaled_grad_vars, global_step) 520 | else: 521 | if fine_tune: 522 | grad_vars = optimizer.compute_gradients(loss) 523 | else: 524 | grad_vars = optimizer.compute_gradients(loss, 525 | var_list=endecoder_var) 526 | if fine_tune: 527 | grad_vars = _dense_grad_filter(grad_vars) 528 | if clip_grad: 529 | capped_grad_vars = [] 530 | for grad, var in grad_vars: 531 | capped_grad_vars.append((tf.clip_by_value(grad, -1., 1.), var)) 532 | minimize_op = optimizer.apply_gradients(capped_grad_vars, global_step) 533 | else: 534 | minimize_op = optimizer.apply_gradients(grad_vars, global_step) 535 | 536 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 537 | train_op = tf.group(minimize_op, update_ops) 538 | else: 539 | train_op = None 540 | 541 | accuracy = tf.metrics.accuracy(labels, predictions['classes']) 542 | accuracy_top_5 = tf.metrics.mean(tf.nn.in_top_k(predictions=logits, 543 | targets=labels, 544 | k=5, 545 | name='top_5_op')) 546 | metrics = {'accuracy': accuracy, 547 | 'accuracy_top_5': accuracy_top_5} 548 | 549 | # Create a tensor named train_accuracy for logging purposes 550 | tf.identity(accuracy[1], name='train_accuracy') 551 | tf.identity(accuracy_top_5[1], name='train_accuracy_top_5') 552 | tf.summary.scalar('train_accuracy', accuracy[1]) 553 | tf.summary.scalar('train_accuracy_top_5', accuracy_top_5[1]) 554 | 555 | return tf.estimator.EstimatorSpec( 556 | mode=mode, 557 | predictions=predictions, 558 | loss=loss, 559 | train_op=train_op, 560 | eval_metric_ops=metrics) 561 | 562 | 563 | def resnet_main( 564 | flags_obj, model_function, input_function, dataset_name, shape=None): 565 | """Shared main loop for ResNet Models. 566 | 567 | Args: 568 | flags_obj: An object containing parsed flags. See define_resnet_flags() 569 | for details. 570 | model_function: the function that instantiates the Model and builds the 571 | ops for train/eval. This will be passed directly into the estimator. 572 | input_function: the function that processes the dataset and returns a 573 | dataset that the estimator can train on. This will be wrapped with 574 | all the relevant flags for running and passed to estimator. 575 | dataset_name: the name of the dataset for training and evaluation. This is 576 | used for logging purpose. 577 | shape: list of ints representing the shape of the images used for training. 578 | This is only used if flags_obj.export_dir is passed. 579 | """ 580 | 581 | model_helpers.apply_clean(flags.FLAGS) 582 | 583 | # Ensures flag override logic is only executed if explicitly triggered. 584 | if flags_obj.tf_gpu_thread_mode: 585 | override_flags_and_set_envars_for_gpu_thread_pool(flags_obj) 586 | 587 | # Creates session config. allow_soft_placement = True, is required for 588 | # multi-GPU and is not harmful for other modes. 589 | session_config = tf.ConfigProto( 590 | inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads, 591 | intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads, 592 | allow_soft_placement=True) 593 | 594 | distribution_strategy = distribution_utils.get_distribution_strategy( 595 | flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg) 596 | 597 | # Creates a `RunConfig` that checkpoints every 24 hours which essentially 598 | # results in checkpoints determined only by `epochs_between_evals`. 599 | run_config = tf.estimator.RunConfig( 600 | train_distribute=distribution_strategy, 601 | session_config=session_config, 602 | save_checkpoints_secs=60*60*24) 603 | 604 | # Initializes model with all but the dense layer from pretrained ResNet. 605 | if flags_obj.pretrained_model_checkpoint_path is not None: 606 | if flags_obj.fine_tune: 607 | if string.lower(flags_obj.optimizer) == 'adam': 608 | if flags_obj.no_dense_init: 609 | warm_start_settings = tf.estimator.WarmStartSettings( 610 | flags_obj.pretrained_model_checkpoint_path, 611 | vars_to_warm_start=['^(?!.*(resnet_model/dense|beta1_power|beta2_power|Adam|global_step))']) 612 | # vars_to_warm_start=['^(?!.*(resnet_model/dense|global_step))']) 613 | else: 614 | warm_start_settings = tf.estimator.WarmStartSettings( 615 | flags_obj.pretrained_model_checkpoint_path, 616 | vars_to_warm_start=['^(?!.*(resnet_model/dense/kernel/Momentum|resnet_model/dense/bias/Momentum|beta1_power|beta2_power|Adam|global_step))']) 617 | # vars_to_warm_start=['^(?!.*(resnet_model/dense|global_step))']) 618 | else: 619 | if flags_obj.no_dense_init: 620 | warm_start_settings = tf.estimator.WarmStartSettings( 621 | flags_obj.pretrained_model_checkpoint_path, 622 | vars_to_warm_start=['^(?!.*(resnet_model/dense|Momentum|global_step))']) 623 | else: 624 | warm_start_settings = tf.estimator.WarmStartSettings( 625 | flags_obj.pretrained_model_checkpoint_path, 626 | vars_to_warm_start=['^(?!.*(resnet_model/dense/kernel/Momentum|resnet_model/dense/bias/Momentum|global_step))']) 627 | # vars_to_warm_start=['^(?!.*(resnet_model/dense|global_step))']) 628 | else: 629 | if string.lower(flags_obj.optimizer) == 'adam': 630 | warm_start_settings = tf.estimator.WarmStartSettings( 631 | flags_obj.pretrained_model_checkpoint_path, 632 | vars_to_warm_start=['^(?!.*(endecoder|Momentum|beta1_power|beta2_power|global_step))']) 633 | # vars_to_warm_start='^(?!.*dense)') 634 | else: 635 | warm_start_settings = tf.estimator.WarmStartSettings( 636 | flags_obj.pretrained_model_checkpoint_path, 637 | vars_to_warm_start=['^(?!.*(endecoder|global_step))']) 638 | # vars_to_warm_start='^(?!.*dense)') 639 | else: 640 | warm_start_settings = None 641 | 642 | classifier = tf.estimator.Estimator( 643 | model_fn=model_function, model_dir=flags_obj.model_dir, config=run_config, 644 | warm_start_from=warm_start_settings, params={ 645 | 'resnet_size': int(flags_obj.resnet_size), 646 | 'data_format': flags_obj.data_format, 647 | 'batch_size': flags_obj.batch_size, 648 | 'resnet_version': int(flags_obj.resnet_version), 649 | 'loss_scale': flags_core.get_loss_scale(flags_obj), 650 | 'dtype': flags_core.get_tf_dtype(flags_obj), 651 | 'fine_tune': flags_obj.fine_tune, 652 | 'reconst_loss_scale': flags_obj.reconst_loss_scale, 653 | 'use_ce': flags_obj.use_ce, 654 | 'optimizer': string.lower(flags_obj.optimizer), 655 | 'clip_grad': flags_obj.clip_grad, 656 | 'spectral_norm': flags_obj.spectral_norm, 657 | 'ce_scale': flags_obj.ce_scale, 658 | 'sep_grad_nrom': flags_obj.sep_grad_nrom, 659 | 'norm_teach_feature':flags_obj.norm_teach_feature, 660 | 'no_dense_init':flags_obj.no_dense_init, 661 | 'compress_ratio':flags_obj.compress_ratio 662 | }) 663 | 664 | run_params = { 665 | 'batch_size': flags_obj.batch_size, 666 | 'dtype': flags_core.get_tf_dtype(flags_obj), 667 | 'resnet_size': flags_obj.resnet_size, 668 | 'resnet_version': flags_obj.resnet_version, 669 | 'synthetic_data': flags_obj.use_synthetic_data, 670 | 'train_epochs': flags_obj.train_epochs, 671 | 'fine_tune': flags_obj.fine_tune, 672 | 'reconst_loss_scale': flags_obj.reconst_loss_scale, 673 | 'use_ce': flags_obj.use_ce, 674 | 'optimizer': string.lower(flags_obj.optimizer), 675 | 'clip_grad': flags_obj.clip_grad, 676 | 'spectral_norm': flags_obj.spectral_norm, 677 | 'ce_scale':flags_obj.ce_scale, 678 | 'sep_grad_nrom': flags_obj.sep_grad_nrom, 679 | 'norm_teach_feature': flags_obj.norm_teach_feature, 680 | 'no_dense_init': flags_obj.no_dense_init, 681 | 'compress_ratio': flags_obj.compress_ratio, 682 | } 683 | if flags_obj.use_synthetic_data: 684 | dataset_name = dataset_name + '-synthetic' 685 | 686 | benchmark_logger = logger.get_benchmark_logger() 687 | benchmark_logger.log_run_info('resnet', dataset_name, run_params, 688 | test_id=flags_obj.benchmark_test_id) 689 | 690 | train_hooks = hooks_helper.get_train_hooks( 691 | flags_obj.hooks, 692 | model_dir=flags_obj.model_dir, 693 | batch_size=flags_obj.batch_size) 694 | 695 | def input_fn_train(num_epochs): 696 | return input_function( 697 | is_training=True, 698 | data_dir=flags_obj.data_dir, 699 | batch_size=distribution_utils.per_device_batch_size( 700 | flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)), 701 | num_epochs=num_epochs, 702 | dtype=flags_core.get_tf_dtype(flags_obj), 703 | datasets_num_private_threads=flags_obj.datasets_num_private_threads, 704 | num_parallel_batches=flags_obj.datasets_num_parallel_batches) 705 | 706 | def input_fn_eval(): 707 | return input_function( 708 | is_training=False, 709 | data_dir=flags_obj.data_dir, 710 | batch_size=distribution_utils.per_device_batch_size( 711 | flags_obj.batch_size, flags_core.get_num_gpus(flags_obj)), 712 | num_epochs=1, 713 | dtype=flags_core.get_tf_dtype(flags_obj)) 714 | 715 | if flags_obj.eval_only or not flags_obj.train_epochs: 716 | # If --eval_only is set, perform a single loop with zero train epochs. 717 | schedule, n_loops = [0], 1 718 | else: 719 | # Compute the number of times to loop while training. All but the last 720 | # pass will train for `epochs_between_evals` epochs, while the last will 721 | # train for the number needed to reach `training_epochs`. For instance if 722 | # train_epochs = 25 and epochs_between_evals = 10 723 | # schedule will be set to [10, 10, 5]. That is to say, the loop will: 724 | # Train for 10 epochs and then evaluate. 725 | # Train for another 10 epochs and then evaluate. 726 | # Train for a final 5 epochs (to reach 25 epochs) and then evaluate. 727 | n_loops = math.ceil(flags_obj.train_epochs / flags_obj.epochs_between_evals) 728 | schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))] 729 | schedule[-1] = flags_obj.train_epochs - sum(schedule[:-1]) # over counting. 730 | 731 | 732 | print('schedule: ', schedule, flags_obj.epochs_between_evals, flags_obj.max_train_steps) 733 | for cycle_index, num_train_epochs in enumerate(schedule): 734 | tf.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops)) 735 | 736 | if num_train_epochs: 737 | classifier.train(input_fn=lambda: input_fn_train(num_train_epochs), 738 | hooks=train_hooks, max_steps=flags_obj.max_train_steps) 739 | 740 | tf.logging.info('Starting to evaluate.') 741 | 742 | # flags_obj.max_train_steps is generally associated with testing and 743 | # profiling. As a result it is frequently called with synthetic data, which 744 | # will iterate forever. Passing steps=flags_obj.max_train_steps allows the 745 | # eval (which is generally unimportant in those circumstances) to terminate. 746 | # Note that eval will run for max_train_steps each loop, regardless of the 747 | # global_step count. 748 | eval_results = classifier.evaluate(input_fn=input_fn_eval, 749 | steps=flags_obj.max_train_steps) 750 | 751 | benchmark_logger.log_evaluation_result(eval_results) 752 | 753 | if model_helpers.past_stop_threshold( 754 | flags_obj.stop_threshold, eval_results['accuracy']): 755 | break 756 | 757 | if flags_obj.export_dir is not None: 758 | # Exports a saved model for the given classifier. 759 | export_dtype = flags_core.get_tf_dtype(flags_obj) 760 | if flags_obj.image_bytes_as_serving_input: 761 | input_receiver_fn = functools.partial( 762 | image_bytes_serving_input_fn, shape, dtype=export_dtype) 763 | else: 764 | input_receiver_fn = export.build_tensor_serving_input_receiver_fn( 765 | shape, batch_size=flags_obj.batch_size, dtype=export_dtype) 766 | classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn, 767 | strip_default_attrs=True) 768 | 769 | 770 | def define_resnet_flags(resnet_size_choices=None): 771 | """Add flags and validators for ResNet.""" 772 | flags_core.define_base() 773 | flags_core.define_performance(num_parallel_calls=False, 774 | tf_gpu_thread_mode=True, 775 | datasets_num_private_threads=True, 776 | datasets_num_parallel_batches=True) 777 | flags_core.define_image() 778 | flags_core.define_benchmark() 779 | flags.adopt_module_key_flags(flags_core) 780 | 781 | flags.DEFINE_enum( 782 | name='resnet_version', short_name='rv', default='2', 783 | enum_values=['1', '2'], 784 | help=flags_core.help_wrap( 785 | 'Version of ResNet. (1 or 2) See README.md for details.')) 786 | flags.DEFINE_bool( 787 | name='fine_tune', short_name='ft', default=False, 788 | help=flags_core.help_wrap( 789 | 'If True do not train any parameters except for the final layer.')) 790 | flags.DEFINE_string( 791 | name='pretrained_model_checkpoint_path', short_name='pmcp', default=None, 792 | help=flags_core.help_wrap( 793 | 'If not None initialize all the network except the final layer with ' 794 | 'these values')) 795 | flags.DEFINE_boolean( 796 | name='eval_only', default=False, 797 | help=flags_core.help_wrap('Skip training and only perform evaluation on ' 798 | 'the latest checkpoint.')) 799 | flags.DEFINE_boolean( 800 | name='image_bytes_as_serving_input', default=False, 801 | help=flags_core.help_wrap( 802 | 'If True exports savedmodel with serving signature that accepts ' 803 | 'JPEG image bytes instead of a fixed size [HxWxC] tensor that ' 804 | 'represents the image. The former is easier to use for serving at ' 805 | 'the expense of image resize/cropping being done as part of model ' 806 | 'inference. Note, this flag only applies to ImageNet and cannot ' 807 | 'be used for CIFAR.')) 808 | flags.DEFINE_float( 809 | name='reconst_loss_scale', default=10.0, 810 | help=flags_core.help_wrap( 811 | 'scale the reconst_loss' 812 | )) 813 | flags.DEFINE_boolean( 814 | name='use_ce', default=False, 815 | help=flags_core.help_wrap( 816 | 'use cross entropy loss for compressive sensing training')) 817 | flags.DEFINE_string( 818 | name='optimizer', short_name='opt', 819 | # default='sgd', 820 | default='adam', 821 | help=flags_core.help_wrap('Choose optimizer for training')) 822 | flags.DEFINE_boolean( 823 | name='clip_grad', default=False, 824 | help=flags_core.help_wrap( 825 | 'whether to clip weights during training')) 826 | flags.DEFINE_boolean( 827 | name='spectral_norm', short_name='sn', default=True, 828 | help=flags_core.help_wrap( 829 | 'whether to user spectral norm in the cs part')) 830 | flags.DEFINE_float( 831 | name='ce_scale', default=1.0, 832 | help=flags_core.help_wrap( 833 | 'scale the cross_entropy' 834 | )) 835 | flags.DEFINE_boolean( 836 | name='sep_grad_nrom', default=False, 837 | help=flags_core.help_wrap( 838 | 'spearate the gradients from reconstruction and ce, and norm the ce grad')) 839 | flags.DEFINE_boolean( 840 | name='norm_teach_feature', default=False, 841 | help=flags_core.help_wrap( 842 | 'norm each channel of teaching feature with BN params')) 843 | flags.DEFINE_boolean( 844 | name='no_dense_init', default=False, 845 | help=flags_core.help_wrap( 846 | 'dont init resenet/dense during fine tuning')) 847 | flags.DEFINE_float( 848 | name='compress_ratio', default=0.1, 849 | help=flags_core.help_wrap( 850 | 'the compress ratio of the offloading layer')) 851 | 852 | choice_kwargs = dict( 853 | name='resnet_size', short_name='rs', default='50', 854 | help=flags_core.help_wrap('The size of the ResNet model to use.')) 855 | 856 | if resnet_size_choices is None: 857 | flags.DEFINE_string(**choice_kwargs) 858 | else: 859 | flags.DEFINE_enum(enum_values=resnet_size_choices, **choice_kwargs) 860 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/export/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/export/export.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Convenience functions for exporting models as SavedModels or other types.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | 24 | def build_tensor_serving_input_receiver_fn(shape, dtype=tf.float32, 25 | batch_size=1): 26 | """Returns a input_receiver_fn that can be used during serving. 27 | 28 | This expects examples to come through as float tensors, and simply 29 | wraps them as TensorServingInputReceivers. 30 | 31 | Arguably, this should live in tf.estimator.export. Testing here first. 32 | 33 | Args: 34 | shape: list representing target size of a single example. 35 | dtype: the expected datatype for the input example 36 | batch_size: number of input tensors that will be passed for prediction 37 | 38 | Returns: 39 | A function that itself returns a TensorServingInputReceiver. 40 | """ 41 | def serving_input_receiver_fn(): 42 | # Prep a placeholder where the input example will be fed in 43 | features = tf.placeholder( 44 | dtype=dtype, shape=[batch_size] + shape, name='input_tensor') 45 | 46 | return tf.estimator.export.TensorServingInputReceiver( 47 | features=features, receiver_tensors=features) 48 | 49 | return serving_input_receiver_fn 50 | -------------------------------------------------------------------------------- /utils/flags/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/flags/_base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Flags which will be nearly universal across models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import flags 22 | import tensorflow as tf 23 | 24 | from utils.flags._conventions import help_wrap 25 | from utils.logs import hooks_helper 26 | 27 | 28 | def define_base(data_dir=True, model_dir=True, clean=True, train_epochs=True, 29 | epochs_between_evals=True, stop_threshold=True, batch_size=True, 30 | num_gpu=True, hooks=True, export_dir=True): 31 | """Register base flags. 32 | 33 | Args: 34 | data_dir: Create a flag for specifying the input data directory. 35 | model_dir: Create a flag for specifying the model file directory. 36 | train_epochs: Create a flag to specify the number of training epochs. 37 | epochs_between_evals: Create a flag to specify the frequency of testing. 38 | stop_threshold: Create a flag to specify a threshold accuracy or other 39 | eval metric which should trigger the end of training. 40 | batch_size: Create a flag to specify the batch size. 41 | num_gpu: Create a flag to specify the number of GPUs used. 42 | hooks: Create a flag to specify hooks for logging. 43 | export_dir: Create a flag to specify where a SavedModel should be exported. 44 | 45 | Returns: 46 | A list of flags for core.py to marks as key flags. 47 | """ 48 | key_flags = [] 49 | 50 | if data_dir: 51 | flags.DEFINE_string( 52 | name="data_dir", short_name="dd", default="/tmp", 53 | help=help_wrap("The location of the input data.")) 54 | key_flags.append("data_dir") 55 | 56 | if model_dir: 57 | flags.DEFINE_string( 58 | name="model_dir", short_name="md", default="/tmp", 59 | help=help_wrap("The location of the model checkpoint files.")) 60 | key_flags.append("model_dir") 61 | 62 | if clean: 63 | flags.DEFINE_boolean( 64 | name="clean", default=False, 65 | help=help_wrap("If set, model_dir will be removed if it exists.")) 66 | key_flags.append("clean") 67 | 68 | if train_epochs: 69 | flags.DEFINE_integer( 70 | name="train_epochs", short_name="te", default=1, 71 | help=help_wrap("The number of epochs used to train.")) 72 | key_flags.append("train_epochs") 73 | 74 | if epochs_between_evals: 75 | flags.DEFINE_integer( 76 | name="epochs_between_evals", short_name="ebe", default=1, 77 | help=help_wrap("The number of training epochs to run between " 78 | "evaluations.")) 79 | key_flags.append("epochs_between_evals") 80 | 81 | if stop_threshold: 82 | flags.DEFINE_float( 83 | name="stop_threshold", short_name="st", 84 | default=None, 85 | help=help_wrap("If passed, training will stop at the earlier of " 86 | "train_epochs and when the evaluation metric is " 87 | "greater than or equal to stop_threshold.")) 88 | 89 | if batch_size: 90 | flags.DEFINE_integer( 91 | name="batch_size", short_name="bs", default=32, 92 | help=help_wrap("Batch size for training and evaluation. When using " 93 | "multiple gpus, this is the global batch size for " 94 | "all devices. For example, if the batch size is 32 " 95 | "and there are 4 GPUs, each GPU will get 8 examples on " 96 | "each step.")) 97 | key_flags.append("batch_size") 98 | 99 | if num_gpu: 100 | flags.DEFINE_integer( 101 | name="num_gpus", short_name="ng", 102 | default=1 if tf.test.is_gpu_available() else 0, 103 | help=help_wrap( 104 | "How many GPUs to use with the DistributionStrategies API. The " 105 | "default is 1 if TensorFlow can detect a GPU, and 0 otherwise.")) 106 | 107 | if hooks: 108 | # Construct a pretty summary of hooks. 109 | hook_list_str = ( 110 | u"\ufeff Hook:\n" + u"\n".join([u"\ufeff {}".format(key) for key 111 | in hooks_helper.HOOKS])) 112 | flags.DEFINE_list( 113 | name="hooks", short_name="hk", default="LoggingTensorHook", 114 | help=help_wrap( 115 | u"A list of (case insensitive) strings to specify the names of " 116 | u"training hooks.\n{}\n\ufeff Example: `--hooks ProfilerHook," 117 | u"ExamplesPerSecondHook`\n See official.utils.logs.hooks_helper " 118 | u"for details.".format(hook_list_str)) 119 | ) 120 | key_flags.append("hooks") 121 | 122 | if export_dir: 123 | flags.DEFINE_string( 124 | name="export_dir", short_name="ed", default=None, 125 | help=help_wrap("If set, a SavedModel serialization of the model will " 126 | "be exported to this directory at the end of training. " 127 | "See the README for more details and relevant links.") 128 | ) 129 | key_flags.append("export_dir") 130 | 131 | return key_flags 132 | 133 | 134 | def get_num_gpus(flags_obj): 135 | """Treat num_gpus=-1 as 'use all'.""" 136 | if flags_obj.num_gpus != -1: 137 | return flags_obj.num_gpus 138 | 139 | from tensorflow.python.client import device_lib # pylint: disable=g-import-not-at-top 140 | local_device_protos = device_lib.list_local_devices() 141 | return sum([1 for d in local_device_protos if d.device_type == "GPU"]) 142 | -------------------------------------------------------------------------------- /utils/flags/_benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Flags for benchmarking models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import flags 22 | 23 | from utils.flags._conventions import help_wrap 24 | 25 | 26 | def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True): 27 | """Register benchmarking flags. 28 | 29 | Args: 30 | benchmark_log_dir: Create a flag to specify location for benchmark logging. 31 | bigquery_uploader: Create flags for uploading results to BigQuery. 32 | 33 | Returns: 34 | A list of flags for core.py to marks as key flags. 35 | """ 36 | 37 | key_flags = [] 38 | 39 | flags.DEFINE_enum( 40 | name="benchmark_logger_type", default="BaseBenchmarkLogger", 41 | enum_values=["BaseBenchmarkLogger", "BenchmarkFileLogger", 42 | "BenchmarkBigQueryLogger"], 43 | help=help_wrap("The type of benchmark logger to use. Defaults to using " 44 | "BaseBenchmarkLogger which logs to STDOUT. Different " 45 | "loggers will require other flags to be able to work.")) 46 | flags.DEFINE_string( 47 | name="benchmark_test_id", short_name="bti", default=None, 48 | help=help_wrap("The unique test ID of the benchmark run. It could be the " 49 | "combination of key parameters. It is hardware " 50 | "independent and could be used compare the performance " 51 | "between different test runs. This flag is designed for " 52 | "human consumption, and does not have any impact within " 53 | "the system.")) 54 | 55 | if benchmark_log_dir: 56 | flags.DEFINE_string( 57 | name="benchmark_log_dir", short_name="bld", default=None, 58 | help=help_wrap("The location of the benchmark logging.") 59 | ) 60 | 61 | if bigquery_uploader: 62 | flags.DEFINE_string( 63 | name="gcp_project", short_name="gp", default=None, 64 | help=help_wrap( 65 | "The GCP project name where the benchmark will be uploaded.")) 66 | 67 | flags.DEFINE_string( 68 | name="bigquery_data_set", short_name="bds", default="test_benchmark", 69 | help=help_wrap( 70 | "The Bigquery dataset name where the benchmark will be uploaded.")) 71 | 72 | flags.DEFINE_string( 73 | name="bigquery_run_table", short_name="brt", default="benchmark_run", 74 | help=help_wrap("The Bigquery table name where the benchmark run " 75 | "information will be uploaded.")) 76 | 77 | flags.DEFINE_string( 78 | name="bigquery_run_status_table", short_name="brst", 79 | default="benchmark_run_status", 80 | help=help_wrap("The Bigquery table name where the benchmark run " 81 | "status information will be uploaded.")) 82 | 83 | flags.DEFINE_string( 84 | name="bigquery_metric_table", short_name="bmt", 85 | default="benchmark_metric", 86 | help=help_wrap("The Bigquery table name where the benchmark metric " 87 | "information will be uploaded.")) 88 | 89 | @flags.multi_flags_validator( 90 | ["benchmark_logger_type", "benchmark_log_dir"], 91 | message="--benchmark_logger_type=BenchmarkFileLogger will require " 92 | "--benchmark_log_dir being set") 93 | def _check_benchmark_log_dir(flags_dict): 94 | benchmark_logger_type = flags_dict["benchmark_logger_type"] 95 | if benchmark_logger_type == "BenchmarkFileLogger": 96 | return flags_dict["benchmark_log_dir"] 97 | return True 98 | 99 | return key_flags 100 | -------------------------------------------------------------------------------- /utils/flags/_conventions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Central location for shared arparse convention definitions.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import codecs 22 | import functools 23 | 24 | from absl import app as absl_app 25 | from absl import flags 26 | 27 | 28 | # This codifies help string conventions and makes it easy to update them if 29 | # necessary. Currently the only major effect is that help bodies start on the 30 | # line after flags are listed. All flag definitions should wrap the text bodies 31 | # with help wrap when calling DEFINE_*. 32 | _help_wrap = functools.partial(flags.text_wrap, length=80, indent="", 33 | firstline_indent="\n") 34 | 35 | 36 | # Pretty formatting causes issues when utf-8 is not installed on a system. 37 | try: 38 | codecs.lookup("utf-8") 39 | help_wrap = _help_wrap 40 | except LookupError: 41 | def help_wrap(text, *args, **kwargs): 42 | return _help_wrap(text, *args, **kwargs).replace("\ufeff", "") 43 | 44 | 45 | # Replace None with h to also allow -h 46 | absl_app.HelpshortFlag.SHORT_NAME = "h" 47 | -------------------------------------------------------------------------------- /utils/flags/_device.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Flags for managing compute devices. Currently only contains TPU flags.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import flags 22 | import tensorflow as tf 23 | 24 | from utils.flags._conventions import help_wrap 25 | 26 | 27 | def require_cloud_storage(flag_names): 28 | """Register a validator to check directory flags. 29 | Args: 30 | flag_names: An iterable of strings containing the names of flags to be 31 | checked. 32 | """ 33 | msg = "TPU requires GCS path for {}".format(", ".join(flag_names)) 34 | @flags.multi_flags_validator(["tpu"] + flag_names, message=msg) 35 | def _path_check(flag_values): # pylint: disable=missing-docstring 36 | if flag_values["tpu"] is None: 37 | return True 38 | 39 | valid_flags = True 40 | for key in flag_names: 41 | if not flag_values[key].startswith("gs://"): 42 | tf.logging.error("{} must be a GCS path.".format(key)) 43 | valid_flags = False 44 | 45 | return valid_flags 46 | 47 | 48 | def define_device(tpu=True): 49 | """Register device specific flags. 50 | Args: 51 | tpu: Create flags to specify TPU operation. 52 | Returns: 53 | A list of flags for core.py to marks as key flags. 54 | """ 55 | 56 | key_flags = [] 57 | 58 | if tpu: 59 | flags.DEFINE_string( 60 | name="tpu", default=None, 61 | help=help_wrap( 62 | "The Cloud TPU to use for training. This should be either the name " 63 | "used when creating the Cloud TPU, or a " 64 | "grpc://ip.address.of.tpu:8470 url. Passing `local` will use the" 65 | "CPU of the local instance instead. (Good for debugging.)")) 66 | key_flags.append("tpu") 67 | 68 | flags.DEFINE_string( 69 | name="tpu_zone", default=None, 70 | help=help_wrap( 71 | "[Optional] GCE zone where the Cloud TPU is located in. If not " 72 | "specified, we will attempt to automatically detect the GCE " 73 | "project from metadata.")) 74 | 75 | flags.DEFINE_string( 76 | name="tpu_gcp_project", default=None, 77 | help=help_wrap( 78 | "[Optional] Project name for the Cloud TPU-enabled project. If not " 79 | "specified, we will attempt to automatically detect the GCE " 80 | "project from metadata.")) 81 | 82 | flags.DEFINE_integer(name="num_tpu_shards", default=8, 83 | help=help_wrap("Number of shards (TPU chips).")) 84 | 85 | return key_flags 86 | -------------------------------------------------------------------------------- /utils/flags/_misc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Misc flags.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from absl import flags 22 | 23 | from utils.flags._conventions import help_wrap 24 | 25 | 26 | def define_image(data_format=True): 27 | """Register image specific flags. 28 | 29 | Args: 30 | data_format: Create a flag to specify image axis convention. 31 | 32 | Returns: 33 | A list of flags for core.py to marks as key flags. 34 | """ 35 | 36 | key_flags = [] 37 | 38 | if data_format: 39 | flags.DEFINE_enum( 40 | name="data_format", short_name="df", default=None, 41 | enum_values=["channels_first", "channels_last"], 42 | help=help_wrap( 43 | "A flag to override the data format used in the model. " 44 | "channels_first provides a performance boost on GPU but is not " 45 | "always compatible with CPU. If left unspecified, the data format " 46 | "will be chosen automatically based on whether TensorFlow was " 47 | "built for CPU or GPU.")) 48 | key_flags.append("data_format") 49 | 50 | return key_flags 51 | -------------------------------------------------------------------------------- /utils/flags/_performance.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Register flags for optimizing performance.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import multiprocessing 22 | 23 | from absl import flags # pylint: disable=g-bad-import-order 24 | import tensorflow as tf # pylint: disable=g-bad-import-order 25 | 26 | from utils.flags._conventions import help_wrap 27 | 28 | 29 | # Map string to (TensorFlow dtype, default loss scale) 30 | DTYPE_MAP = { 31 | "fp16": (tf.float16, 128), 32 | "fp32": (tf.float32, 1), 33 | } 34 | 35 | 36 | def get_tf_dtype(flags_obj): 37 | return DTYPE_MAP[flags_obj.dtype][0] 38 | 39 | 40 | def get_loss_scale(flags_obj): 41 | if flags_obj.loss_scale is not None: 42 | return flags_obj.loss_scale 43 | return DTYPE_MAP[flags_obj.dtype][1] 44 | 45 | 46 | def define_performance(num_parallel_calls=True, inter_op=True, intra_op=True, 47 | synthetic_data=True, max_train_steps=True, dtype=True, 48 | all_reduce_alg=True, tf_gpu_thread_mode=False, 49 | datasets_num_private_threads=False, 50 | datasets_num_parallel_batches=False): 51 | """Register flags for specifying performance tuning arguments. 52 | 53 | Args: 54 | num_parallel_calls: Create a flag to specify parallelism of data loading. 55 | inter_op: Create a flag to allow specification of inter op threads. 56 | intra_op: Create a flag to allow specification of intra op threads. 57 | synthetic_data: Create a flag to allow the use of synthetic data. 58 | max_train_steps: Create a flags to allow specification of maximum number 59 | of training steps 60 | dtype: Create flags for specifying dtype. 61 | all_reduce_alg: If set forces a specific algorithm for multi-gpu. 62 | tf_gpu_thread_mode: gpu_private triggers us of private thread pool. 63 | datasets_num_private_threads: Number of private threads for datasets. 64 | datasets_num_parallel_batches: Determines how many batches to process in 65 | parallel when using map and batch from tf.data. 66 | 67 | Returns: 68 | A list of flags for core.py to marks as key flags. 69 | """ 70 | 71 | key_flags = [] 72 | if num_parallel_calls: 73 | flags.DEFINE_integer( 74 | name="num_parallel_calls", short_name="npc", 75 | default=multiprocessing.cpu_count(), 76 | help=help_wrap("The number of records that are processed in parallel " 77 | "during input processing. This can be optimized per " 78 | "data set but for generally homogeneous data sets, " 79 | "should be approximately the number of available CPU " 80 | "cores. (default behavior)")) 81 | 82 | if inter_op: 83 | flags.DEFINE_integer( 84 | name="inter_op_parallelism_threads", short_name="inter", default=0, 85 | help=help_wrap("Number of inter_op_parallelism_threads to use for CPU. " 86 | "See TensorFlow config.proto for details.") 87 | ) 88 | 89 | if intra_op: 90 | flags.DEFINE_integer( 91 | name="intra_op_parallelism_threads", short_name="intra", default=0, 92 | help=help_wrap("Number of intra_op_parallelism_threads to use for CPU. " 93 | "See TensorFlow config.proto for details.")) 94 | 95 | if synthetic_data: 96 | flags.DEFINE_bool( 97 | name="use_synthetic_data", short_name="synth", default=False, 98 | help=help_wrap( 99 | "If set, use fake data (zeroes) instead of a real dataset. " 100 | "This mode is useful for performance debugging, as it removes " 101 | "input processing steps, but will not learn anything.")) 102 | 103 | if max_train_steps: 104 | flags.DEFINE_integer( 105 | name="max_train_steps", short_name="mts", default=None, help=help_wrap( 106 | "The model will stop training if the global_step reaches this " 107 | "value. If not set, training will run until the specified number " 108 | "of epochs have run as usual. It is generally recommended to set " 109 | "--train_epochs=1 when using this flag." 110 | )) 111 | 112 | if dtype: 113 | flags.DEFINE_enum( 114 | name="dtype", short_name="dt", default="fp32", 115 | enum_values=DTYPE_MAP.keys(), 116 | help=help_wrap("The TensorFlow datatype used for calculations. " 117 | "Variables may be cast to a higher precision on a " 118 | "case-by-case basis for numerical stability.")) 119 | 120 | flags.DEFINE_integer( 121 | name="loss_scale", short_name="ls", default=None, 122 | help=help_wrap( 123 | "The amount to scale the loss by when the model is run. Before " 124 | "gradients are computed, the loss is multiplied by the loss scale, " 125 | "making all gradients loss_scale times larger. To adjust for this, " 126 | "gradients are divided by the loss scale before being applied to " 127 | "variables. This is mathematically equivalent to training without " 128 | "a loss scale, but the loss scale helps avoid some intermediate " 129 | "gradients from underflowing to zero. If not provided the default " 130 | "for fp16 is 128 and 1 for all other dtypes.")) 131 | 132 | loss_scale_val_msg = "loss_scale should be a positive integer." 133 | @flags.validator(flag_name="loss_scale", message=loss_scale_val_msg) 134 | def _check_loss_scale(loss_scale): # pylint: disable=unused-variable 135 | if loss_scale is None: 136 | return True # null case is handled in get_loss_scale() 137 | 138 | return loss_scale > 0 139 | 140 | if all_reduce_alg: 141 | flags.DEFINE_string( 142 | name="all_reduce_alg", short_name="ara", default=None, 143 | help=help_wrap("Defines the algorithm to use for performing all-reduce." 144 | "See tf.contrib.distribute.AllReduceCrossTowerOps for " 145 | "more details and available options.")) 146 | 147 | if tf_gpu_thread_mode: 148 | flags.DEFINE_string( 149 | name="tf_gpu_thread_mode", short_name="gt_mode", default=None, 150 | help=help_wrap( 151 | "Whether and how the GPU device uses its own threadpool.") 152 | ) 153 | 154 | if datasets_num_private_threads: 155 | flags.DEFINE_integer( 156 | name="datasets_num_private_threads", 157 | default=None, 158 | help=help_wrap( 159 | "Number of threads for a private threadpool created for all" 160 | "datasets computation..") 161 | ) 162 | 163 | if datasets_num_parallel_batches: 164 | flags.DEFINE_integer( 165 | name="datasets_num_parallel_batches", 166 | default=None, 167 | help=help_wrap( 168 | "Determines how many batches to process in parallel when using " 169 | "map and batch from tf.data.") 170 | ) 171 | 172 | return key_flags 173 | -------------------------------------------------------------------------------- /utils/flags/core.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Public interface for flag definition. 16 | 17 | See _example.py for detailed instructions on defining flags. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import functools 25 | import sys 26 | 27 | from absl import app as absl_app 28 | from absl import flags 29 | 30 | from utils.flags import _base 31 | from utils.flags import _benchmark 32 | from utils.flags import _conventions 33 | from utils.flags import _device 34 | from utils.flags import _misc 35 | from utils.flags import _performance 36 | 37 | 38 | def set_defaults(**kwargs): 39 | for key, value in kwargs.items(): 40 | flags.FLAGS.set_default(name=key, value=value) 41 | 42 | 43 | def parse_flags(argv=None): 44 | """Reset flags and reparse. Currently only used in testing.""" 45 | flags.FLAGS.unparse_flags() 46 | absl_app.parse_flags_with_usage(argv or sys.argv) 47 | 48 | 49 | def register_key_flags_in_core(f): 50 | """Defines a function in core.py, and registers its key flags. 51 | 52 | absl uses the location of a flags.declare_key_flag() to determine the context 53 | in which a flag is key. By making all declares in core, this allows model 54 | main functions to call flags.adopt_module_key_flags() on core and correctly 55 | chain key flags. 56 | 57 | Args: 58 | f: The function to be wrapped 59 | 60 | Returns: 61 | The "core-defined" version of the input function. 62 | """ 63 | 64 | def core_fn(*args, **kwargs): 65 | key_flags = f(*args, **kwargs) 66 | [flags.declare_key_flag(fl) for fl in key_flags] # pylint: disable=expression-not-assigned 67 | return core_fn 68 | 69 | 70 | define_base = register_key_flags_in_core(_base.define_base) 71 | # Remove options not relevant for Eager from define_base(). 72 | define_base_eager = register_key_flags_in_core(functools.partial( 73 | _base.define_base, epochs_between_evals=False, stop_threshold=False, 74 | hooks=False)) 75 | define_benchmark = register_key_flags_in_core(_benchmark.define_benchmark) 76 | define_device = register_key_flags_in_core(_device.define_device) 77 | define_image = register_key_flags_in_core(_misc.define_image) 78 | define_performance = register_key_flags_in_core(_performance.define_performance) 79 | 80 | 81 | help_wrap = _conventions.help_wrap 82 | 83 | 84 | get_num_gpus = _base.get_num_gpus 85 | get_tf_dtype = _performance.get_tf_dtype 86 | get_loss_scale = _performance.get_loss_scale 87 | DTYPE_MAP = _performance.DTYPE_MAP 88 | require_cloud_storage = _device.require_cloud_storage 89 | -------------------------------------------------------------------------------- /utils/logs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/logs/cloud_lib.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utilities that interact with cloud service. 17 | """ 18 | 19 | import requests 20 | 21 | GCP_METADATA_URL = "http://metadata/computeMetadata/v1/instance/hostname" 22 | GCP_METADATA_HEADER = {"Metadata-Flavor": "Google"} 23 | 24 | 25 | def on_gcp(): 26 | """Detect whether the current running environment is on GCP.""" 27 | try: 28 | # Timeout in 5 seconds, in case the test environment has connectivity issue. 29 | # There is not default timeout, which means it might block forever. 30 | response = requests.get( 31 | GCP_METADATA_URL, headers=GCP_METADATA_HEADER, timeout=5) 32 | return response.status_code == 200 33 | except requests.exceptions.RequestException: 34 | return False 35 | -------------------------------------------------------------------------------- /utils/logs/hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Hook that counts examples per second every N steps or seconds.""" 17 | 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf # pylint: disable=g-bad-import-order 24 | 25 | from utils.logs import logger 26 | 27 | 28 | class ExamplesPerSecondHook(tf.train.SessionRunHook): 29 | """Hook to print out examples per second. 30 | 31 | Total time is tracked and then divided by the total number of steps 32 | to get the average step time and then batch_size is used to determine 33 | the running average of examples per second. The examples per second for the 34 | most recent interval is also logged. 35 | """ 36 | 37 | def __init__(self, 38 | batch_size, 39 | every_n_steps=None, 40 | every_n_secs=None, 41 | warm_steps=0, 42 | metric_logger=None): 43 | """Initializer for ExamplesPerSecondHook. 44 | 45 | Args: 46 | batch_size: Total batch size across all workers used to calculate 47 | examples/second from global time. 48 | every_n_steps: Log stats every n steps. 49 | every_n_secs: Log stats every n seconds. Exactly one of the 50 | `every_n_steps` or `every_n_secs` should be set. 51 | warm_steps: The number of steps to be skipped before logging and running 52 | average calculation. warm_steps steps refers to global steps across all 53 | workers, not on each worker 54 | metric_logger: instance of `BenchmarkLogger`, the benchmark logger that 55 | hook should use to write the log. If None, BaseBenchmarkLogger will 56 | be used. 57 | 58 | Raises: 59 | ValueError: if neither `every_n_steps` or `every_n_secs` is set, or 60 | both are set. 61 | """ 62 | 63 | if (every_n_steps is None) == (every_n_secs is None): 64 | raise ValueError("exactly one of every_n_steps" 65 | " and every_n_secs should be provided.") 66 | 67 | self._logger = metric_logger or logger.BaseBenchmarkLogger() 68 | 69 | self._timer = tf.train.SecondOrStepTimer( 70 | every_steps=every_n_steps, every_secs=every_n_secs) 71 | 72 | self._step_train_time = 0 73 | self._total_steps = 0 74 | self._batch_size = batch_size 75 | self._warm_steps = warm_steps 76 | 77 | def begin(self): 78 | """Called once before using the session to check global step.""" 79 | self._global_step_tensor = tf.train.get_global_step() 80 | if self._global_step_tensor is None: 81 | raise RuntimeError( 82 | "Global step should be created to use StepCounterHook.") 83 | 84 | def before_run(self, run_context): # pylint: disable=unused-argument 85 | """Called before each call to run(). 86 | 87 | Args: 88 | run_context: A SessionRunContext object. 89 | 90 | Returns: 91 | A SessionRunArgs object or None if never triggered. 92 | """ 93 | return tf.train.SessionRunArgs(self._global_step_tensor) 94 | 95 | def after_run(self, run_context, run_values): # pylint: disable=unused-argument 96 | """Called after each call to run(). 97 | 98 | Args: 99 | run_context: A SessionRunContext object. 100 | run_values: A SessionRunValues object. 101 | """ 102 | global_step = run_values.results 103 | 104 | if self._timer.should_trigger_for_step( 105 | global_step) and global_step > self._warm_steps: 106 | elapsed_time, elapsed_steps = self._timer.update_last_triggered_step( 107 | global_step) 108 | if elapsed_time is not None: 109 | self._step_train_time += elapsed_time 110 | self._total_steps += elapsed_steps 111 | 112 | # average examples per second is based on the total (accumulative) 113 | # training steps and training time so far 114 | average_examples_per_sec = self._batch_size * ( 115 | self._total_steps / self._step_train_time) 116 | # current examples per second is based on the elapsed training steps 117 | # and training time per batch 118 | current_examples_per_sec = self._batch_size * ( 119 | elapsed_steps / elapsed_time) 120 | 121 | self._logger.log_metric( 122 | "average_examples_per_sec", average_examples_per_sec, 123 | global_step=global_step) 124 | 125 | self._logger.log_metric( 126 | "current_examples_per_sec", current_examples_per_sec, 127 | global_step=global_step) 128 | -------------------------------------------------------------------------------- /utils/logs/hooks_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Hooks helper to return a list of TensorFlow hooks for training by name. 17 | 18 | More hooks can be added to this set. To add a new hook, 1) add the new hook to 19 | the registry in HOOKS, 2) add a corresponding function that parses out necessary 20 | parameters. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf # pylint: disable=g-bad-import-order 28 | 29 | from utils.logs import hooks 30 | from utils.logs import logger 31 | from utils.logs import metric_hook 32 | 33 | _TENSORS_TO_LOG = dict((x, x) for x in ['learning_rate', 34 | 'cross_entropy', 35 | 'train_accuracy']) 36 | 37 | 38 | def get_train_hooks(name_list, use_tpu=False, **kwargs): 39 | """Factory for getting a list of TensorFlow hooks for training by name. 40 | 41 | Args: 42 | name_list: a list of strings to name desired hook classes. Allowed: 43 | LoggingTensorHook, ProfilerHook, ExamplesPerSecondHook, which are defined 44 | as keys in HOOKS 45 | use_tpu: Boolean of whether computation occurs on a TPU. This will disable 46 | hooks altogether. 47 | **kwargs: a dictionary of arguments to the hooks. 48 | 49 | Returns: 50 | list of instantiated hooks, ready to be used in a classifier.train call. 51 | 52 | Raises: 53 | ValueError: if an unrecognized name is passed. 54 | """ 55 | 56 | if not name_list: 57 | return [] 58 | 59 | if use_tpu: 60 | tf.logging.warning("hooks_helper received name_list `{}`, but a TPU is " 61 | "specified. No hooks will be used.".format(name_list)) 62 | return [] 63 | 64 | train_hooks = [] 65 | for name in name_list: 66 | hook_name = HOOKS.get(name.strip().lower()) 67 | if hook_name is None: 68 | raise ValueError('Unrecognized training hook requested: {}'.format(name)) 69 | else: 70 | train_hooks.append(hook_name(**kwargs)) 71 | 72 | return train_hooks 73 | 74 | 75 | def get_logging_tensor_hook(every_n_iter=100, tensors_to_log=None, **kwargs): # pylint: disable=unused-argument 76 | """Function to get LoggingTensorHook. 77 | 78 | Args: 79 | every_n_iter: `int`, print the values of `tensors` once every N local 80 | steps taken on the current worker. 81 | tensors_to_log: List of tensor names or dictionary mapping labels to tensor 82 | names. If not set, log _TENSORS_TO_LOG by default. 83 | **kwargs: a dictionary of arguments to LoggingTensorHook. 84 | 85 | Returns: 86 | Returns a LoggingTensorHook with a standard set of tensors that will be 87 | printed to stdout. 88 | """ 89 | if tensors_to_log is None: 90 | tensors_to_log = _TENSORS_TO_LOG 91 | 92 | return tf.train.LoggingTensorHook( 93 | tensors=tensors_to_log, 94 | every_n_iter=every_n_iter) 95 | 96 | 97 | def get_profiler_hook(model_dir, save_steps=1000, **kwargs): # pylint: disable=unused-argument 98 | """Function to get ProfilerHook. 99 | 100 | Args: 101 | model_dir: The directory to save the profile traces to. 102 | save_steps: `int`, print profile traces every N steps. 103 | **kwargs: a dictionary of arguments to ProfilerHook. 104 | 105 | Returns: 106 | Returns a ProfilerHook that writes out timelines that can be loaded into 107 | profiling tools like chrome://tracing. 108 | """ 109 | return tf.train.ProfilerHook(save_steps=save_steps, output_dir=model_dir) 110 | 111 | 112 | def get_examples_per_second_hook(every_n_steps=100, 113 | batch_size=128, 114 | warm_steps=5, 115 | **kwargs): # pylint: disable=unused-argument 116 | """Function to get ExamplesPerSecondHook. 117 | 118 | Args: 119 | every_n_steps: `int`, print current and average examples per second every 120 | N steps. 121 | batch_size: `int`, total batch size used to calculate examples/second from 122 | global time. 123 | warm_steps: skip this number of steps before logging and running average. 124 | **kwargs: a dictionary of arguments to ExamplesPerSecondHook. 125 | 126 | Returns: 127 | Returns a ProfilerHook that writes out timelines that can be loaded into 128 | profiling tools like chrome://tracing. 129 | """ 130 | return hooks.ExamplesPerSecondHook( 131 | batch_size=batch_size, every_n_steps=every_n_steps, 132 | warm_steps=warm_steps, metric_logger=logger.get_benchmark_logger()) 133 | 134 | 135 | def get_logging_metric_hook(tensors_to_log=None, 136 | every_n_secs=600, 137 | **kwargs): # pylint: disable=unused-argument 138 | """Function to get LoggingMetricHook. 139 | 140 | Args: 141 | tensors_to_log: List of tensor names or dictionary mapping labels to tensor 142 | names. If not set, log _TENSORS_TO_LOG by default. 143 | every_n_secs: `int`, the frequency for logging the metric. Default to every 144 | 10 mins. 145 | 146 | Returns: 147 | Returns a LoggingMetricHook that saves tensor values in a JSON format. 148 | """ 149 | if tensors_to_log is None: 150 | tensors_to_log = _TENSORS_TO_LOG 151 | return metric_hook.LoggingMetricHook( 152 | tensors=tensors_to_log, 153 | metric_logger=logger.get_benchmark_logger(), 154 | every_n_secs=every_n_secs) 155 | 156 | 157 | # A dictionary to map one hook name and its corresponding function 158 | HOOKS = { 159 | 'loggingtensorhook': get_logging_tensor_hook, 160 | 'profilerhook': get_profiler_hook, 161 | 'examplespersecondhook': get_examples_per_second_hook, 162 | 'loggingmetrichook': get_logging_metric_hook, 163 | } 164 | -------------------------------------------------------------------------------- /utils/logs/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Logging utilities for benchmark. 17 | 18 | For collecting local environment metrics like CPU and memory, certain python 19 | packages need be installed. See README for details. 20 | """ 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import contextlib 26 | import datetime 27 | import json 28 | import multiprocessing 29 | import numbers 30 | import os 31 | import threading 32 | import uuid 33 | 34 | from six.moves import _thread as thread 35 | from absl import flags 36 | import tensorflow as tf 37 | from tensorflow.python.client import device_lib 38 | 39 | from utils.logs import cloud_lib 40 | 41 | METRIC_LOG_FILE_NAME = "metric.log" 42 | BENCHMARK_RUN_LOG_FILE_NAME = "benchmark_run.log" 43 | _DATE_TIME_FORMAT_PATTERN = "%Y-%m-%dT%H:%M:%S.%fZ" 44 | GCP_TEST_ENV = "GCP" 45 | RUN_STATUS_SUCCESS = "success" 46 | RUN_STATUS_FAILURE = "failure" 47 | RUN_STATUS_RUNNING = "running" 48 | 49 | 50 | FLAGS = flags.FLAGS 51 | 52 | # Don't use it directly. Use get_benchmark_logger to access a logger. 53 | _benchmark_logger = None 54 | _logger_lock = threading.Lock() 55 | 56 | 57 | def config_benchmark_logger(flag_obj=None): 58 | """Config the global benchmark logger.""" 59 | _logger_lock.acquire() 60 | try: 61 | global _benchmark_logger 62 | if not flag_obj: 63 | flag_obj = FLAGS 64 | 65 | if (not hasattr(flag_obj, "benchmark_logger_type") or 66 | flag_obj.benchmark_logger_type == "BaseBenchmarkLogger"): 67 | _benchmark_logger = BaseBenchmarkLogger() 68 | elif flag_obj.benchmark_logger_type == "BenchmarkFileLogger": 69 | _benchmark_logger = BenchmarkFileLogger(flag_obj.benchmark_log_dir) 70 | elif flag_obj.benchmark_logger_type == "BenchmarkBigQueryLogger": 71 | from official.benchmark import benchmark_uploader as bu # pylint: disable=g-import-not-at-top 72 | bq_uploader = bu.BigQueryUploader(gcp_project=flag_obj.gcp_project) 73 | _benchmark_logger = BenchmarkBigQueryLogger( 74 | bigquery_uploader=bq_uploader, 75 | bigquery_data_set=flag_obj.bigquery_data_set, 76 | bigquery_run_table=flag_obj.bigquery_run_table, 77 | bigquery_run_status_table=flag_obj.bigquery_run_status_table, 78 | bigquery_metric_table=flag_obj.bigquery_metric_table, 79 | run_id=str(uuid.uuid4())) 80 | else: 81 | raise ValueError("Unrecognized benchmark_logger_type: %s" 82 | % flag_obj.benchmark_logger_type) 83 | 84 | finally: 85 | _logger_lock.release() 86 | return _benchmark_logger 87 | 88 | 89 | def get_benchmark_logger(): 90 | if not _benchmark_logger: 91 | config_benchmark_logger() 92 | return _benchmark_logger 93 | 94 | 95 | @contextlib.contextmanager 96 | def benchmark_context(flag_obj): 97 | """Context of benchmark, which will update status of the run accordingly.""" 98 | benchmark_logger = config_benchmark_logger(flag_obj) 99 | try: 100 | yield 101 | benchmark_logger.on_finish(RUN_STATUS_SUCCESS) 102 | except Exception: # pylint: disable=broad-except 103 | # Catch all the exception, update the run status to be failure, and re-raise 104 | benchmark_logger.on_finish(RUN_STATUS_FAILURE) 105 | raise 106 | 107 | 108 | class BaseBenchmarkLogger(object): 109 | """Class to log the benchmark information to STDOUT.""" 110 | 111 | def log_evaluation_result(self, eval_results): 112 | """Log the evaluation result. 113 | 114 | The evaluate result is a dictionary that contains metrics defined in 115 | model_fn. It also contains a entry for global_step which contains the value 116 | of the global step when evaluation was performed. 117 | 118 | Args: 119 | eval_results: dict, the result of evaluate. 120 | """ 121 | if not isinstance(eval_results, dict): 122 | tf.logging.warning("eval_results should be dictionary for logging. " 123 | "Got %s", type(eval_results)) 124 | return 125 | global_step = eval_results[tf.GraphKeys.GLOBAL_STEP] 126 | for key in sorted(eval_results): 127 | if key != tf.GraphKeys.GLOBAL_STEP: 128 | self.log_metric(key, eval_results[key], global_step=global_step) 129 | 130 | def log_metric(self, name, value, unit=None, global_step=None, extras=None): 131 | """Log the benchmark metric information to local file. 132 | 133 | Currently the logging is done in a synchronized way. This should be updated 134 | to log asynchronously. 135 | 136 | Args: 137 | name: string, the name of the metric to log. 138 | value: number, the value of the metric. The value will not be logged if it 139 | is not a number type. 140 | unit: string, the unit of the metric, E.g "image per second". 141 | global_step: int, the global_step when the metric is logged. 142 | extras: map of string:string, the extra information about the metric. 143 | """ 144 | metric = _process_metric_to_json(name, value, unit, global_step, extras) 145 | if metric: 146 | tf.logging.info("Benchmark metric: %s", metric) 147 | 148 | def log_run_info(self, model_name, dataset_name, run_params, test_id=None): 149 | tf.logging.info("Benchmark run: %s", 150 | _gather_run_info(model_name, dataset_name, run_params, 151 | test_id)) 152 | 153 | def on_finish(self, status): 154 | pass 155 | 156 | 157 | class BenchmarkFileLogger(BaseBenchmarkLogger): 158 | """Class to log the benchmark information to local disk.""" 159 | 160 | def __init__(self, logging_dir): 161 | super(BenchmarkFileLogger, self).__init__() 162 | self._logging_dir = logging_dir 163 | if not tf.gfile.IsDirectory(self._logging_dir): 164 | tf.gfile.MakeDirs(self._logging_dir) 165 | self._metric_file_handler = tf.gfile.GFile( 166 | os.path.join(self._logging_dir, METRIC_LOG_FILE_NAME), "a") 167 | 168 | def log_metric(self, name, value, unit=None, global_step=None, extras=None): 169 | """Log the benchmark metric information to local file. 170 | 171 | Currently the logging is done in a synchronized way. This should be updated 172 | to log asynchronously. 173 | 174 | Args: 175 | name: string, the name of the metric to log. 176 | value: number, the value of the metric. The value will not be logged if it 177 | is not a number type. 178 | unit: string, the unit of the metric, E.g "image per second". 179 | global_step: int, the global_step when the metric is logged. 180 | extras: map of string:string, the extra information about the metric. 181 | """ 182 | metric = _process_metric_to_json(name, value, unit, global_step, extras) 183 | if metric: 184 | try: 185 | json.dump(metric, self._metric_file_handler) 186 | self._metric_file_handler.write("\n") 187 | self._metric_file_handler.flush() 188 | except (TypeError, ValueError) as e: 189 | tf.logging.warning("Failed to dump metric to log file: " 190 | "name %s, value %s, error %s", name, value, e) 191 | 192 | def log_run_info(self, model_name, dataset_name, run_params, test_id=None): 193 | """Collect most of the TF runtime information for the local env. 194 | 195 | The schema of the run info follows official/benchmark/datastore/schema. 196 | 197 | Args: 198 | model_name: string, the name of the model. 199 | dataset_name: string, the name of dataset for training and evaluation. 200 | run_params: dict, the dictionary of parameters for the run, it could 201 | include hyperparameters or other params that are important for the run. 202 | test_id: string, the unique name of the test run by the combination of key 203 | parameters, eg batch size, num of GPU. It is hardware independent. 204 | """ 205 | run_info = _gather_run_info(model_name, dataset_name, run_params, test_id) 206 | 207 | with tf.gfile.GFile(os.path.join( 208 | self._logging_dir, BENCHMARK_RUN_LOG_FILE_NAME), "w") as f: 209 | try: 210 | json.dump(run_info, f) 211 | f.write("\n") 212 | except (TypeError, ValueError) as e: 213 | tf.logging.warning("Failed to dump benchmark run info to log file: %s", 214 | e) 215 | 216 | def on_finish(self, status): 217 | self._metric_file_handler.flush() 218 | self._metric_file_handler.close() 219 | 220 | 221 | class BenchmarkBigQueryLogger(BaseBenchmarkLogger): 222 | """Class to log the benchmark information to BigQuery data store.""" 223 | 224 | def __init__(self, 225 | bigquery_uploader, 226 | bigquery_data_set, 227 | bigquery_run_table, 228 | bigquery_run_status_table, 229 | bigquery_metric_table, 230 | run_id): 231 | super(BenchmarkBigQueryLogger, self).__init__() 232 | self._bigquery_uploader = bigquery_uploader 233 | self._bigquery_data_set = bigquery_data_set 234 | self._bigquery_run_table = bigquery_run_table 235 | self._bigquery_run_status_table = bigquery_run_status_table 236 | self._bigquery_metric_table = bigquery_metric_table 237 | self._run_id = run_id 238 | 239 | def log_metric(self, name, value, unit=None, global_step=None, extras=None): 240 | """Log the benchmark metric information to bigquery. 241 | 242 | Args: 243 | name: string, the name of the metric to log. 244 | value: number, the value of the metric. The value will not be logged if it 245 | is not a number type. 246 | unit: string, the unit of the metric, E.g "image per second". 247 | global_step: int, the global_step when the metric is logged. 248 | extras: map of string:string, the extra information about the metric. 249 | """ 250 | metric = _process_metric_to_json(name, value, unit, global_step, extras) 251 | if metric: 252 | # Starting new thread for bigquery upload in case it might take long time 253 | # and impact the benchmark and performance measurement. Starting a new 254 | # thread might have potential performance impact for model that run on 255 | # CPU. 256 | thread.start_new_thread( 257 | self._bigquery_uploader.upload_benchmark_metric_json, 258 | (self._bigquery_data_set, 259 | self._bigquery_metric_table, 260 | self._run_id, 261 | [metric])) 262 | 263 | def log_run_info(self, model_name, dataset_name, run_params, test_id=None): 264 | """Collect most of the TF runtime information for the local env. 265 | 266 | The schema of the run info follows official/benchmark/datastore/schema. 267 | 268 | Args: 269 | model_name: string, the name of the model. 270 | dataset_name: string, the name of dataset for training and evaluation. 271 | run_params: dict, the dictionary of parameters for the run, it could 272 | include hyperparameters or other params that are important for the run. 273 | test_id: string, the unique name of the test run by the combination of key 274 | parameters, eg batch size, num of GPU. It is hardware independent. 275 | """ 276 | run_info = _gather_run_info(model_name, dataset_name, run_params, test_id) 277 | # Starting new thread for bigquery upload in case it might take long time 278 | # and impact the benchmark and performance measurement. Starting a new 279 | # thread might have potential performance impact for model that run on CPU. 280 | thread.start_new_thread( 281 | self._bigquery_uploader.upload_benchmark_run_json, 282 | (self._bigquery_data_set, 283 | self._bigquery_run_table, 284 | self._run_id, 285 | run_info)) 286 | thread.start_new_thread( 287 | self._bigquery_uploader.insert_run_status, 288 | (self._bigquery_data_set, 289 | self._bigquery_run_status_table, 290 | self._run_id, 291 | RUN_STATUS_RUNNING)) 292 | 293 | def on_finish(self, status): 294 | self._bigquery_uploader.update_run_status( 295 | self._bigquery_data_set, 296 | self._bigquery_run_status_table, 297 | self._run_id, 298 | status) 299 | 300 | 301 | def _gather_run_info(model_name, dataset_name, run_params, test_id): 302 | """Collect the benchmark run information for the local environment.""" 303 | run_info = { 304 | "model_name": model_name, 305 | "dataset": {"name": dataset_name}, 306 | "machine_config": {}, 307 | "test_id": test_id, 308 | "run_date": datetime.datetime.utcnow().strftime( 309 | _DATE_TIME_FORMAT_PATTERN)} 310 | session_config = None 311 | if "session_config" in run_params: 312 | session_config = run_params["session_config"] 313 | _collect_tensorflow_info(run_info) 314 | _collect_tensorflow_environment_variables(run_info) 315 | _collect_run_params(run_info, run_params) 316 | _collect_cpu_info(run_info) 317 | _collect_gpu_info(run_info, session_config) 318 | _collect_memory_info(run_info) 319 | _collect_test_environment(run_info) 320 | return run_info 321 | 322 | 323 | def _process_metric_to_json( 324 | name, value, unit=None, global_step=None, extras=None): 325 | """Validate the metric data and generate JSON for insert.""" 326 | if not isinstance(value, numbers.Number): 327 | tf.logging.warning( 328 | "Metric value to log should be a number. Got %s", type(value)) 329 | return None 330 | 331 | extras = _convert_to_json_dict(extras) 332 | return { 333 | "name": name, 334 | "value": float(value), 335 | "unit": unit, 336 | "global_step": global_step, 337 | "timestamp": datetime.datetime.utcnow().strftime( 338 | _DATE_TIME_FORMAT_PATTERN), 339 | "extras": extras} 340 | 341 | 342 | def _collect_tensorflow_info(run_info): 343 | run_info["tensorflow_version"] = { 344 | "version": tf.VERSION, "git_hash": tf.GIT_VERSION} 345 | 346 | 347 | def _collect_run_params(run_info, run_params): 348 | """Log the parameter information for the benchmark run.""" 349 | def process_param(name, value): 350 | type_check = { 351 | str: {"name": name, "string_value": value}, 352 | int: {"name": name, "long_value": value}, 353 | bool: {"name": name, "bool_value": str(value)}, 354 | float: {"name": name, "float_value": value}, 355 | } 356 | return type_check.get(type(value), 357 | {"name": name, "string_value": str(value)}) 358 | if run_params: 359 | run_info["run_parameters"] = [ 360 | process_param(k, v) for k, v in sorted(run_params.items())] 361 | 362 | 363 | def _collect_tensorflow_environment_variables(run_info): 364 | run_info["tensorflow_environment_variables"] = [ 365 | {"name": k, "value": v} 366 | for k, v in sorted(os.environ.items()) if k.startswith("TF_")] 367 | 368 | 369 | # The following code is mirrored from tensorflow/tools/test/system_info_lib 370 | # which is not exposed for import. 371 | def _collect_cpu_info(run_info): 372 | """Collect the CPU information for the local environment.""" 373 | cpu_info = {} 374 | 375 | cpu_info["num_cores"] = multiprocessing.cpu_count() 376 | 377 | try: 378 | # Note: cpuinfo is not installed in the TensorFlow OSS tree. 379 | # It is installable via pip. 380 | import cpuinfo # pylint: disable=g-import-not-at-top 381 | 382 | info = cpuinfo.get_cpu_info() 383 | cpu_info["cpu_info"] = info["brand"] 384 | cpu_info["mhz_per_cpu"] = info["hz_advertised_raw"][0] / 1.0e6 385 | 386 | run_info["machine_config"]["cpu_info"] = cpu_info 387 | except ImportError: 388 | tf.logging.warn("'cpuinfo' not imported. CPU info will not be logged.") 389 | 390 | 391 | def _collect_gpu_info(run_info, session_config=None): 392 | """Collect local GPU information by TF device library.""" 393 | gpu_info = {} 394 | local_device_protos = device_lib.list_local_devices(session_config) 395 | 396 | gpu_info["count"] = len([d for d in local_device_protos 397 | if d.device_type == "GPU"]) 398 | # The device description usually is a JSON string, which contains the GPU 399 | # model info, eg: 400 | # "device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0" 401 | for d in local_device_protos: 402 | if d.device_type == "GPU": 403 | gpu_info["model"] = _parse_gpu_model(d.physical_device_desc) 404 | # Assume all the GPU connected are same model 405 | break 406 | run_info["machine_config"]["gpu_info"] = gpu_info 407 | 408 | 409 | def _collect_memory_info(run_info): 410 | try: 411 | # Note: psutil is not installed in the TensorFlow OSS tree. 412 | # It is installable via pip. 413 | import psutil # pylint: disable=g-import-not-at-top 414 | vmem = psutil.virtual_memory() 415 | run_info["machine_config"]["memory_total"] = vmem.total 416 | run_info["machine_config"]["memory_available"] = vmem.available 417 | except ImportError: 418 | tf.logging.warn("'psutil' not imported. Memory info will not be logged.") 419 | 420 | 421 | def _collect_test_environment(run_info): 422 | """Detect the local environment, eg GCE, AWS or DGX, etc.""" 423 | if cloud_lib.on_gcp(): 424 | run_info["test_environment"] = GCP_TEST_ENV 425 | # TODO(scottzhu): Add more testing env detection for other platform 426 | 427 | 428 | def _parse_gpu_model(physical_device_desc): 429 | # Assume all the GPU connected are same model 430 | for kv in physical_device_desc.split(","): 431 | k, _, v = kv.partition(":") 432 | if k.strip() == "name": 433 | return v.strip() 434 | return None 435 | 436 | 437 | def _convert_to_json_dict(input_dict): 438 | if input_dict: 439 | return [{"name": k, "value": v} for k, v in sorted(input_dict.items())] 440 | else: 441 | return [] 442 | -------------------------------------------------------------------------------- /utils/logs/metric_hook.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Session hook for logging benchmark metric.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf # pylint: disable=g-bad-import-order 22 | 23 | 24 | class LoggingMetricHook(tf.train.LoggingTensorHook): 25 | """Hook to log benchmark metric information. 26 | 27 | This hook is very similar as tf.train.LoggingTensorHook, which logs given 28 | tensors every N local steps, every N seconds, or at the end. The metric 29 | information will be logged to given log_dir or via metric_logger in JSON 30 | format, which can be consumed by data analysis pipeline later. 31 | 32 | Note that if `at_end` is True, `tensors` should not include any tensor 33 | whose evaluation produces a side effect such as consuming additional inputs. 34 | """ 35 | 36 | def __init__(self, tensors, metric_logger=None, 37 | every_n_iter=None, every_n_secs=None, at_end=False): 38 | """Initializer for LoggingMetricHook. 39 | 40 | Args: 41 | tensors: `dict` that maps string-valued tags to tensors/tensor names, 42 | or `iterable` of tensors/tensor names. 43 | metric_logger: instance of `BenchmarkLogger`, the benchmark logger that 44 | hook should use to write the log. 45 | every_n_iter: `int`, print the values of `tensors` once every N local 46 | steps taken on the current worker. 47 | every_n_secs: `int` or `float`, print the values of `tensors` once every N 48 | seconds. Exactly one of `every_n_iter` and `every_n_secs` should be 49 | provided. 50 | at_end: `bool` specifying whether to print the values of `tensors` at the 51 | end of the run. 52 | 53 | Raises: 54 | ValueError: 55 | 1. `every_n_iter` is non-positive, or 56 | 2. Exactly one of every_n_iter and every_n_secs should be provided. 57 | 3. Exactly one of log_dir and metric_logger should be provided. 58 | """ 59 | super(LoggingMetricHook, self).__init__( 60 | tensors=tensors, 61 | every_n_iter=every_n_iter, 62 | every_n_secs=every_n_secs, 63 | at_end=at_end) 64 | 65 | if metric_logger is None: 66 | raise ValueError("metric_logger should be provided.") 67 | self._logger = metric_logger 68 | 69 | def begin(self): 70 | super(LoggingMetricHook, self).begin() 71 | self._global_step_tensor = tf.train.get_global_step() 72 | if self._global_step_tensor is None: 73 | raise RuntimeError( 74 | "Global step should be created to use LoggingMetricHook.") 75 | if self._global_step_tensor.name not in self._current_tensors: 76 | self._current_tensors[self._global_step_tensor.name] = ( 77 | self._global_step_tensor) 78 | 79 | def after_run(self, unused_run_context, run_values): 80 | # should_trigger is a internal state that populated at before_run, and it is 81 | # using self_timer to determine whether it should trigger. 82 | if self._should_trigger: 83 | self._log_metric(run_values.results) 84 | 85 | self._iter_count += 1 86 | 87 | def end(self, session): 88 | if self._log_at_end: 89 | values = session.run(self._current_tensors) 90 | self._log_metric(values) 91 | 92 | def _log_metric(self, tensor_values): 93 | self._timer.update_last_triggered_step(self._iter_count) 94 | global_step = tensor_values[self._global_step_tensor.name] 95 | # self._tag_order is populated during the init of LoggingTensorHook 96 | for tag in self._tag_order: 97 | self._logger.log_metric(tag, tensor_values[tag], global_step=global_step) 98 | -------------------------------------------------------------------------------- /utils/logs/mlperf_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Wrapper for the mlperf logging utils. 17 | 18 | MLPerf compliance logging is only desired under a limited set of circumstances. 19 | This module is intended to keep users from needing to consider logging (or 20 | install the module) unless they are performing mlperf runs. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | from collections import namedtuple 28 | import json 29 | import os 30 | import re 31 | import subprocess 32 | import sys 33 | import typing 34 | 35 | import tensorflow as tf 36 | 37 | _MIN_VERSION = (0, 0, 6) 38 | _STACK_OFFSET = 2 39 | 40 | SUDO = "sudo" if os.geteuid() else "" 41 | 42 | # This indirection is used in docker. 43 | DROP_CACHE_LOC = os.getenv("DROP_CACHE_LOC", "/proc/sys/vm/drop_caches") 44 | 45 | _NCF_PREFIX = "NCF_RAW_" 46 | 47 | # TODO(robieta): move line parsing to mlperf util 48 | _PREFIX = r"(?:{})?:::MLPv([0-9]+).([0-9]+).([0-9]+)".format(_NCF_PREFIX) 49 | _BENCHMARK = r"([a-zA-Z0-9_]+)" 50 | _TIMESTAMP = r"([0-9]+\.[0-9]+)" 51 | _CALLSITE = r"\((.+):([0-9]+)\)" 52 | _TAG = r"([a-zA-Z0-9_]+)" 53 | _VALUE = r"(.*)" 54 | 55 | ParsedLine = namedtuple("ParsedLine", ["version", "benchmark", "timestamp", 56 | "callsite", "tag", "value"]) 57 | 58 | LINE_PATTERN = re.compile( 59 | "^{prefix} {benchmark} {timestamp} {callsite} {tag}(: |$){value}?$".format( 60 | prefix=_PREFIX, benchmark=_BENCHMARK, timestamp=_TIMESTAMP, 61 | callsite=_CALLSITE, tag=_TAG, value=_VALUE)) 62 | 63 | 64 | def parse_line(line): # type: (str) -> typing.Optional[ParsedLine] 65 | match = LINE_PATTERN.match(line.strip()) 66 | if not match: 67 | return 68 | 69 | major, minor, micro, benchmark, timestamp = match.groups()[:5] 70 | call_file, call_line, tag, _, value = match.groups()[5:] 71 | 72 | return ParsedLine(version=(int(major), int(minor), int(micro)), 73 | benchmark=benchmark, timestamp=timestamp, 74 | callsite=(call_file, call_line), tag=tag, value=value) 75 | 76 | 77 | def unparse_line(parsed_line): # type: (ParsedLine) -> str 78 | version_str = "{}.{}.{}".format(*parsed_line.version) 79 | callsite_str = "({}:{})".format(*parsed_line.callsite) 80 | value_str = ": {}".format(parsed_line.value) if parsed_line.value else "" 81 | return ":::MLPv{} {} {} {} {} {}".format( 82 | version_str, parsed_line.benchmark, parsed_line.timestamp, callsite_str, 83 | parsed_line.tag, value_str) 84 | 85 | 86 | def get_mlperf_log(): 87 | """Shielded import of mlperf_log module.""" 88 | try: 89 | import mlperf_compliance 90 | 91 | def test_mlperf_log_pip_version(): 92 | """Check that mlperf_compliance is up to date.""" 93 | import pkg_resources 94 | version = pkg_resources.get_distribution("mlperf_compliance") 95 | version = tuple(int(i) for i in version.version.split(".")) 96 | if version < _MIN_VERSION: 97 | tf.logging.warning( 98 | "mlperf_compliance is version {}, must be >= {}".format( 99 | ".".join([str(i) for i in version]), 100 | ".".join([str(i) for i in _MIN_VERSION]))) 101 | raise ImportError 102 | return mlperf_compliance.mlperf_log 103 | 104 | mlperf_log = test_mlperf_log_pip_version() 105 | 106 | except ImportError: 107 | mlperf_log = None 108 | 109 | return mlperf_log 110 | 111 | 112 | class Logger(object): 113 | """MLPerf logger indirection class. 114 | 115 | This logger only logs for MLPerf runs, and prevents various errors associated 116 | with not having the mlperf_compliance package installed. 117 | """ 118 | class Tags(object): 119 | def __init__(self, mlperf_log): 120 | self._enabled = False 121 | self._mlperf_log = mlperf_log 122 | 123 | def __getattr__(self, item): 124 | if self._mlperf_log is None or not self._enabled: 125 | return 126 | return getattr(self._mlperf_log, item) 127 | 128 | def __init__(self): 129 | self._enabled = False 130 | self._mlperf_log = get_mlperf_log() 131 | self.tags = self.Tags(self._mlperf_log) 132 | 133 | def __call__(self, enable=False): 134 | if enable and self._mlperf_log is None: 135 | raise ImportError("MLPerf logging was requested, but mlperf_compliance " 136 | "module could not be loaded.") 137 | 138 | self._enabled = enable 139 | self.tags._enabled = enable 140 | return self 141 | 142 | def __enter__(self): 143 | pass 144 | 145 | def __exit__(self, exc_type, exc_val, exc_tb): 146 | self._enabled = False 147 | self.tags._enabled = False 148 | 149 | @property 150 | def log_file(self): 151 | if self._mlperf_log is None: 152 | return 153 | return self._mlperf_log.LOG_FILE 154 | 155 | @property 156 | def enabled(self): 157 | return self._enabled 158 | 159 | def ncf_print(self, key, value=None, stack_offset=_STACK_OFFSET, 160 | deferred=False, extra_print=False, prefix=_NCF_PREFIX): 161 | if self._mlperf_log is None or not self.enabled: 162 | return 163 | self._mlperf_log.ncf_print(key=key, value=value, stack_offset=stack_offset, 164 | deferred=deferred, extra_print=extra_print, 165 | prefix=prefix) 166 | 167 | def set_ncf_root(self, path): 168 | if self._mlperf_log is None: 169 | return 170 | self._mlperf_log.ROOT_DIR_NCF = path 171 | 172 | 173 | LOGGER = Logger() 174 | ncf_print, set_ncf_root = LOGGER.ncf_print, LOGGER.set_ncf_root 175 | TAGS = LOGGER.tags 176 | 177 | 178 | def clear_system_caches(): 179 | if not LOGGER.enabled: 180 | return 181 | ret_code = subprocess.call( 182 | ["sync && echo 3 | {} tee {}".format(SUDO, DROP_CACHE_LOC)], 183 | shell=True) 184 | 185 | if ret_code: 186 | raise ValueError("Failed to clear caches") 187 | 188 | 189 | def stitch_ncf(): 190 | """Format NCF logs for MLPerf compliance.""" 191 | if not LOGGER.enabled: 192 | return 193 | 194 | if LOGGER.log_file is None or not tf.gfile.Exists(LOGGER.log_file): 195 | tf.logging.warning("Could not find log file to stitch.") 196 | return 197 | 198 | log_lines = [] 199 | num_eval_users = None 200 | start_time = None 201 | stop_time = None 202 | with tf.gfile.Open(LOGGER.log_file, "r") as f: 203 | for line in f: 204 | parsed_line = parse_line(line) 205 | if not parsed_line: 206 | tf.logging.warning("Failed to parse line: {}".format(line)) 207 | continue 208 | log_lines.append(parsed_line) 209 | 210 | if parsed_line.tag == TAGS.RUN_START: 211 | assert start_time is None 212 | start_time = float(parsed_line.timestamp) 213 | 214 | if parsed_line.tag == TAGS.RUN_STOP: 215 | assert stop_time is None 216 | stop_time = float(parsed_line.timestamp) 217 | 218 | if (parsed_line.tag == TAGS.EVAL_HP_NUM_USERS and parsed_line.value 219 | is not None and "DEFERRED" not in parsed_line.value): 220 | assert num_eval_users is None or num_eval_users == parsed_line.value 221 | num_eval_users = parsed_line.value 222 | log_lines.pop() 223 | 224 | for i, parsed_line in enumerate(log_lines): 225 | if parsed_line.tag == TAGS.EVAL_HP_NUM_USERS: 226 | log_lines[i] = ParsedLine(*parsed_line[:-1], value=num_eval_users) 227 | 228 | log_lines = sorted([unparse_line(i) for i in log_lines]) 229 | 230 | output_path = os.getenv("STITCHED_COMPLIANCE_FILE", None) 231 | if output_path: 232 | with tf.gfile.Open(output_path, "w") as f: 233 | for line in log_lines: 234 | f.write(line + "\n") 235 | else: 236 | for line in log_lines: 237 | print(line) 238 | sys.stdout.flush() 239 | 240 | if start_time is not None and stop_time is not None: 241 | tf.logging.info("MLPerf time: {:.1f} sec.".format(stop_time - start_time)) 242 | 243 | if __name__ == "__main__": 244 | tf.logging.set_verbosity(tf.logging.INFO) 245 | with LOGGER(True): 246 | ncf_print(key=TAGS.RUN_START) 247 | -------------------------------------------------------------------------------- /utils/misc/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils/misc/distribution_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Helper functions for running models in a distributed setting.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | 24 | def get_distribution_strategy(num_gpus, all_reduce_alg=None): 25 | """Return a DistributionStrategy for running the model. 26 | 27 | Args: 28 | num_gpus: Number of GPUs to run this model. 29 | all_reduce_alg: Specify which algorithm to use when performing all-reduce. 30 | See tf.contrib.distribute.AllReduceCrossDeviceOps for available 31 | algorithms. If None, DistributionStrategy will choose based on device 32 | topology. 33 | 34 | Returns: 35 | tf.contrib.distribute.DistibutionStrategy object. 36 | """ 37 | if num_gpus == 0: 38 | return tf.contrib.distribute.OneDeviceStrategy("device:CPU:0") 39 | elif num_gpus == 1: 40 | return tf.contrib.distribute.OneDeviceStrategy("device:GPU:0") 41 | else: 42 | if all_reduce_alg: 43 | return tf.contrib.distribute.MirroredStrategy( 44 | num_gpus=num_gpus, 45 | cross_tower_ops=tf.contrib.distribute.AllReduceCrossDeviceOps( 46 | all_reduce_alg, num_packs=2)) 47 | else: 48 | return tf.contrib.distribute.MirroredStrategy(num_gpus=num_gpus) 49 | 50 | 51 | def per_device_batch_size(batch_size, num_gpus): 52 | """For multi-gpu, batch-size must be a multiple of the number of GPUs. 53 | 54 | Note that this should eventually be handled by DistributionStrategies 55 | directly. Multi-GPU support is currently experimental, however, 56 | so doing the work here until that feature is in place. 57 | 58 | Args: 59 | batch_size: Global batch size to be divided among devices. This should be 60 | equal to num_gpus times the single-GPU batch_size for multi-gpu training. 61 | num_gpus: How many GPUs are used with DistributionStrategies. 62 | 63 | Returns: 64 | Batch size per device. 65 | 66 | Raises: 67 | ValueError: if batch_size is not divisible by number of devices 68 | """ 69 | if num_gpus <= 1: 70 | return batch_size 71 | 72 | remainder = batch_size % num_gpus 73 | if remainder: 74 | err = ("When running with multiple GPUs, batch size " 75 | "must be a multiple of the number of available GPUs. Found {} " 76 | "GPUs with a batch size of {}; try --batch_size={} instead." 77 | ).format(num_gpus, batch_size, batch_size - remainder) 78 | raise ValueError(err) 79 | return int(batch_size / num_gpus) 80 | -------------------------------------------------------------------------------- /utils/misc/model_helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Miscellaneous functions that can be called by models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numbers 22 | 23 | import tensorflow as tf 24 | from tensorflow.python.util import nest 25 | 26 | 27 | def past_stop_threshold(stop_threshold, eval_metric): 28 | """Return a boolean representing whether a model should be stopped. 29 | 30 | Args: 31 | stop_threshold: float, the threshold above which a model should stop 32 | training. 33 | eval_metric: float, the current value of the relevant metric to check. 34 | 35 | Returns: 36 | True if training should stop, False otherwise. 37 | 38 | Raises: 39 | ValueError: if either stop_threshold or eval_metric is not a number 40 | """ 41 | if stop_threshold is None: 42 | return False 43 | 44 | if not isinstance(stop_threshold, numbers.Number): 45 | raise ValueError("Threshold for checking stop conditions must be a number.") 46 | if not isinstance(eval_metric, numbers.Number): 47 | raise ValueError("Eval metric being checked against stop conditions " 48 | "must be a number.") 49 | 50 | if eval_metric >= stop_threshold: 51 | tf.logging.info( 52 | "Stop threshold of {} was passed with metric value {}.".format( 53 | stop_threshold, eval_metric)) 54 | return True 55 | 56 | return False 57 | 58 | 59 | def generate_synthetic_data( 60 | input_shape, input_value=0, input_dtype=None, label_shape=None, 61 | label_value=0, label_dtype=None): 62 | """Create a repeating dataset with constant values. 63 | 64 | Args: 65 | input_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape of 66 | the input data. 67 | input_value: Value of each input element. 68 | input_dtype: Input dtype. If None, will be inferred by the input value. 69 | label_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape of 70 | the label data. 71 | label_value: Value of each input element. 72 | label_dtype: Input dtype. If None, will be inferred by the target value. 73 | 74 | Returns: 75 | Dataset of tensors or tuples of tensors (if label_shape is set). 76 | """ 77 | # TODO(kathywu): Replace with SyntheticDataset once it is in contrib. 78 | element = input_element = nest.map_structure( 79 | lambda s: tf.constant(input_value, input_dtype, s), input_shape) 80 | 81 | if label_shape: 82 | label_element = nest.map_structure( 83 | lambda s: tf.constant(label_value, label_dtype, s), label_shape) 84 | element = (input_element, label_element) 85 | 86 | return tf.data.Dataset.from_tensors(element).repeat() 87 | 88 | 89 | def apply_clean(flags_obj): 90 | if flags_obj.clean and tf.gfile.Exists(flags_obj.model_dir): 91 | tf.logging.info("--clean flag set. Removing existing model dir: {}".format( 92 | flags_obj.model_dir)) 93 | tf.gfile.DeleteRecursively(flags_obj.model_dir) 94 | --------------------------------------------------------------------------------