├── .gitignore
├── README.md
├── cifar10_download_and_extract.py
├── constants.py
├── dataset.py
├── models
│   ├── __init__.py
│   ├── densenet_121.py
│   └── densenet_creator.py
├── setup.cfg
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
results

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# efficient_densenet_tensorflow
A TensorFlow 1.9+ implementation of DenseNet-121, optimized to save GPU memory.

Based on code from the following repositories:

https://github.com/Jiankai-Sun/Distributed-TensorFlow-Example/tree/master/CIFAR-10

https://github.com/titu1994/keras-squeeze-excite-network


## Motivation
While DenseNets are fairly easy to implement in deep learning frameworks, most
implementations (such as the [original](https://github.com/liuzhuang13/DenseNet)) tend to be memory-hungry.
In particular, the number of intermediate feature maps generated by batch normalization and concatenation operations
grows quadratically with network depth.

*It is worth emphasizing that this is not a property inherent to DenseNets, but rather to the implementation.*

This implementation uses a new strategy to reduce the memory consumption of DenseNets.
It is based on [efficient_densenet_pytorch](https://github.com/gpleiss/efficient_densenet_pytorch).
It checkpoints intermediate features with [`tf.contrib.layers.recompute_grad`](https://www.tensorflow.org/versions/r1.5/api_docs/python/tf/contrib/layers/recompute_grad);
an [alternate approach](https://github.com/openai/gradient-checkpointing) is also available.

This adds 15-20% time overhead to training, but **reduces feature map consumption from quadratic to linear.**

For more details, please see the [technical report](https://arxiv.org/pdf/1707.06990.pdf).

## How to checkpoint
Currently all of the dense layers are checkpointed; however, you can alter the implementation to trade off speed and memory.
For example, checkpointing only the earlier layers still avoids storing their intermediate feature maps, which are
generally the largest because they come before most of the pooling layers.
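
The speed/memory trade-off behind checkpointing can be seen without TensorFlow. The sketch below is a toy written for illustration, not code from this repository: a hypothetical chain of `sin` "layers" is differentiated once while keeping every layer input, and once while keeping only the input of every `segment`-th layer and recomputing the others during the backward pass. This is the same idea `tf.contrib.layers.recompute_grad` applies to each dense layer's BN-ReLU-Conv function here; choosing `segment` near the square root of the depth gives the commonly cited sqrt(n)-memory schedule.

```python
import math


def layer(x):
    # Hypothetical toy layer; any differentiable scalar function would do.
    return math.sin(x)


def layer_grad(x):
    # Derivative of the toy layer with respect to its input.
    return math.cos(x)


def backward_full(x0, n):
    """Keep every layer input, then backprop. Returns (grad, values_stored, forward_calls)."""
    inputs = []
    x = x0
    for _ in range(n):
        inputs.append(x)  # input of each layer, needed for its derivative
        x = layer(x)
    grad = 1.0
    for a in reversed(inputs):
        grad *= layer_grad(a)  # chain rule, last layer first
    return grad, len(inputs), n


def backward_checkpointed(x0, n, segment):
    """Keep only every `segment`-th layer input; recompute the rest during backprop.

    Returns (grad, peak_values_held, forward_calls).
    """
    checkpoints = {}
    x = x0
    forward_calls = 0
    for i in range(n):
        if i % segment == 0:
            checkpoints[i] = x  # segment boundary: stored
        x = layer(x)  # all other intermediate values are discarded
        forward_calls += 1

    grad = 1.0
    peak = len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, n)
        # Recompute the inputs of layers start..end-1 from the stored checkpoint.
        seg_inputs = [checkpoints[start]]
        for _ in range(start, end - 1):
            seg_inputs.append(layer(seg_inputs[-1]))
            forward_calls += 1
        peak = max(peak, len(checkpoints) + len(seg_inputs) - 1)
        for a in reversed(seg_inputs):
            grad *= layer_grad(a)
    return grad, peak, forward_calls


if __name__ == '__main__':
    print(backward_full(0.5, 16))             # 16 stored values, 16 forward calls
    print(backward_checkpointed(0.5, 16, 4))  # at most 7 values held, 28 forward calls
```

Both functions return the same gradient; the checkpointed version holds fewer values at once in exchange for extra forward computation, which is the trade-off described above in miniature.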

More strategies can be found in the [alternate approach](https://github.com/openai/gradient-checkpointing).

## Example setup for a 12 GB NVIDIA GPU
With gradient checkpointing:

`python train.py --batch_size 6000 --efficient True`

Without gradient checkpointing (the default):

`python train.py --batch_size 3750`

## Main piece of code:
> models/densenet_creator.py#116 (inside `DenseNetCreator._conv_block`)
```
def _x(ip):
    x = batch_normalization(ip, **self.bn_kwargs)
    x = tf.nn.relu(x)

    if self.bottleneck:
        inter_channel = nb_filter * 4

        x = conv2d(x, inter_channel, (1, 1), kernel_initializer='he_normal', padding='same', use_bias=False,
                   **self.conv_kwargs)
        x = batch_normalization(x, **self.bn_kwargs)
        x = tf.nn.relu(x)

    x = conv2d(x, nb_filter, (3, 3), kernel_initializer='he_normal', padding='same', use_bias=False,
               **self.conv_kwargs)

    if self.dropout_rate:
        x = dropout(x, self.dropout_rate, training=self.training)

    return x

if self.efficient:
    # Gradient checkpoint the layer
    _x = tf.contrib.layers.recompute_grad(_x)

```

## Requirements
- TensorFlow 1.9+
- Horovod

## Usage
If you care about speed and memory is no object, pass `efficient=False` to the `DenseNetCreator` constructor
(for `train.py` this is the default, `--efficient False`). Otherwise, pass `efficient=True` (`--efficient True`).
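
To make the quadratic-versus-linear claim concrete, the sketch below counts feature-map channels for the DenseNet-121 configuration that `models/densenet_121.py` passes to `DenseNetCreator` (64 initial filters, growth rate 32, blocks of 6, 12, 24 and 16 layers, reduction 0.5). It is an illustration under stated assumptions, not a measurement of this repository: "naive" assumes every layer materializes its own concatenated input, and "shared" assumes only the block input plus each layer's new `growth_rate` maps are kept, as in the shared-storage layout described in the technical report. It counts channels only; actual bytes scale with batch size, spatial size and dtype, and in the naive case the batch-normalization outputs add roughly as much again.

```python
def naive_concat_channels(k0, growth_rate, num_layers):
    """Channels materialized when each layer builds its own concatenated input."""
    # Layer l (0-based) receives k0 + l * growth_rate input channels.
    return sum(k0 + l * growth_rate for l in range(num_layers))


def shared_channels(k0, growth_rate, num_layers):
    """Channels kept when only the block input and each layer's new maps are stored."""
    return k0 + num_layers * growth_rate


def densenet121_blocks(nb_filter=64, growth_rate=32, layers=(6, 12, 24, 16), reduction=0.5):
    """Per-block (num_layers, input_channels, naive, shared) channel counts."""
    compression = 1.0 - reduction  # same definition as DenseNetCreator
    k0 = nb_filter
    rows = []
    for i, n in enumerate(layers):
        rows.append((n, k0, naive_concat_channels(k0, growth_rate, n),
                     shared_channels(k0, growth_rate, n)))
        k0 += n * growth_rate
        if i < len(layers) - 1:  # a transition block sits between dense blocks
            k0 = int(k0 * compression)
    return rows


if __name__ == '__main__':
    for n, k0, naive, shared in densenet121_blocks():
        print(f'{n:2d} layers, {k0:4d} input channels: naive {naive:6d}, shared {shared:5d}')
```

For the 24-layer third block this gives 14,976 naive channels versus 1,024 shared ones; the naive count grows with the square of the layer count, the shared count linearly.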

Important options:
- `--batch_size` (int) - The number of images per batch (default 3750)

- `--fp16` (bool) - Whether to run with FP16 mixed precision (default False)

- `--efficient` (bool) - Whether to run with gradient checkpointing (default False)

Boolean options are enabled only by the string `True` (case-insensitive), e.g. `--fp16 True`; any other value disables them.


## Reference

```
@article{pleiss2017memory,
  title={Memory-Efficient Implementation of DenseNets},
  author={Pleiss, Geoff and Chen, Danlu and Huang, Gao and Li, Tongcheng and van der Maaten, Laurens and Weinberger, Kilian Q},
  journal={arXiv preprint arXiv:1707.06990},
  year={2017}
}
```

--------------------------------------------------------------------------------
/cifar10_download_and_extract.py:
--------------------------------------------------------------------------------
"""Downloads and extracts the binary version of the CIFAR-10 dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys
import tarfile

from six.moves import urllib
import tensorflow as tf

DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'

parser = argparse.ArgumentParser()

parser.add_argument(
    '--data_dir', type=str, default='/tmp/cifar10_data',
    help='Directory to download data and extract the tarball')


def main(unused_argv):
    """Download and extract the tarball from Alex's website."""
    if not os.path.exists(FLAGS.data_dir):
        os.makedirs(FLAGS.data_dir)

    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(FLAGS.data_dir, filename)

    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (
                filename, 100.0 * count * block_size / total_size))
            sys.stdout.flush()

        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

    tarfile.open(filepath, 'r:gz').extractall(FLAGS.data_dir)


if __name__ == '__main__':
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(argv=[sys.argv[0]] + unparsed)

--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
HEIGHT = 32
WIDTH = 32
DEPTH = 3
NUM_CLASSES = 10
NUM_DATA_FILES = 5

# We use a weight decay of 0.0002, which performs better than the 0.0001 that
# was originally suggested.
WEIGHT_DECAY = 2e-4
MOMENTUM = 0.9

NUM_IMAGES = {
    'train': 50000,
    'validation': 10000,
}
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os

import tensorflow as tf

from constants import HEIGHT, WIDTH, DEPTH, NUM_CLASSES, NUM_DATA_FILES, NUM_IMAGES


def record_dataset(filenames):
    """Returns an input pipeline Dataset from `filenames`."""
    record_bytes = HEIGHT * WIDTH * DEPTH + 1
    return tf.data.FixedLengthRecordDataset(filenames, record_bytes)


def get_filenames(is_training, data_dir):
    """Returns a list of filenames."""
    data_dir = os.path.join(data_dir, 'cifar-10-batches-bin')

    assert os.path.exists(data_dir), (
        'Run cifar10_download_and_extract.py first to download and extract the '
        'CIFAR-10 data.')

    if is_training:
        return [
            os.path.join(data_dir, 'data_batch_%d.bin' % i)
            for i in range(1, NUM_DATA_FILES + 1)
        ]
    else:
        return [os.path.join(data_dir, 'test_batch.bin')]


def parse_record(raw_record):
    """Parse CIFAR-10 image and label from a raw record."""
    # Every record consists of a label followed by the image, with a fixed number
    # of bytes for each.
    label_bytes = 1
    image_bytes = HEIGHT * WIDTH * DEPTH
    record_bytes = label_bytes + image_bytes

    # Convert bytes to a vector of uint8 that is record_bytes long.
    record_vector = tf.decode_raw(raw_record, tf.uint8)

    # The first byte represents the label, which we convert from uint8 to int32
    # and then to one-hot.
    label = tf.cast(record_vector[0], tf.int32)
    label = tf.one_hot(label, NUM_CLASSES)

    # The remaining bytes after the label represent the image, which we reshape
    # from [depth * height * width] to [depth, height, width].
    depth_major = tf.reshape(
        record_vector[label_bytes:record_bytes], [DEPTH, HEIGHT, WIDTH])

    # Convert from [depth, height, width] to [height, width, depth], and cast as
    # float32.
    image = tf.cast(tf.transpose(depth_major, [1, 2, 0]), tf.float32)

    return image, label


def preprocess_image(image, is_training):
    """Preprocess a single image of layout [height, width, depth]."""
    if is_training:
        # Pad the image with four extra pixels on each side.
        image = tf.image.resize_image_with_crop_or_pad(
            image, HEIGHT + 8, WIDTH + 8)

        # Randomly crop a [HEIGHT, WIDTH] section of the image.
        image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])

        # Randomly flip the image horizontally.
        image = tf.image.random_flip_left_right(image)

    # Subtract off the mean and divide by the standard deviation of the pixels.
    image = tf.image.per_image_standardization(image)
    return image


def input_fn(is_training, data_dir, batch_size, num_epochs=1):
    """Input function using the tf.data input pipeline for the CIFAR-10 dataset.

    Args:
        is_training: A boolean denoting whether the input is for training.
        data_dir: The directory containing the input data.
        batch_size: The number of samples per batch.
        num_epochs: The number of epochs to repeat the dataset (None repeats indefinitely).

    Returns:
        A tuple of images and labels.
    """
    dataset = record_dataset(get_filenames(is_training, data_dir))

    if is_training:
        # When choosing shuffle buffer sizes, larger sizes result in better
        # randomness, while smaller sizes have better performance. Because CIFAR-10
        # is a relatively small dataset, we choose to shuffle the full epoch.
        dataset = dataset.shuffle(buffer_size=NUM_IMAGES['train'])

    dataset = dataset.map(parse_record)
    dataset = dataset.map(
        lambda image, label: (preprocess_image(image, is_training), label), num_parallel_calls=4)

    dataset = dataset.prefetch(2 * batch_size)

    # We call repeat after shuffling, rather than before, to prevent separate
    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)

    # Batch results by up to batch_size, and then fetch the tuple from the
    # iterator.
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    images, labels = iterator.get_next()

    return images, labels

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .densenet_creator import DenseNetCreator


__all__ = [
    'DenseNetCreator',
]

--------------------------------------------------------------------------------
/models/densenet_121.py:
--------------------------------------------------------------------------------
import tensorflow as tf

from .densenet_creator import DenseNetCreator


def get_model(img, classes, data_format, efficient):
    # The input pipeline yields [batch, height, width, depth]; transpose to NCHW for channels_first.
    if data_format == 'channels_first':
        img = tf.transpose(img, [0, 3, 1, 2])

    return DenseNetCreator(img, classes, data_format=data_format, depth=121, efficient=efficient, nb_dense_block=4,
                           growth_rate=32, nb_filter=64, nb_layers_per_block=[6, 12, 24, 16], bottleneck=True,
                           reduction=.5, dropout_rate=0., subsample_initial_block=True, include_top=True)()

--------------------------------------------------------------------------------
/models/densenet_creator.py:
--------------------------------------------------------------------------------
import tensorflow as tf

from tensorflow.layers import average_pooling2d, batch_normalization, conv2d, dense, dropout, max_pooling2d
from tensorflow.keras.layers import concatenate, GlobalAveragePooling2D


class DenseNetCreator:
    def __init__(self, img_input, nb_classes, bottleneck=False, data_format='channels_first', depth=40, dropout_rate=0.,
                 efficient=False, growth_rate=12, include_top=True, nb_dense_block=3, nb_filter=-1,
                 nb_layers_per_block=-1, training=True, trainable=True, reduction=0.0, subsample_initial_block=False):
        """ Initialise the DenseNet model creator.

        Args:
            img_input (tensor): Input tensor.
            nb_classes (int): number of classes
            bottleneck (bool, default: False): use bottleneck blocks
            data_format (str, default: 'channels_first'): The data format to use for the network
            depth (int, default: 40): number of layers (only used when nb_layers_per_block == -1)
            dropout_rate (float, default: 0.): dropout rate
            efficient (bool, default: False): Whether to use gradient checkpointing (slower but more memory efficient)
            growth_rate (int, default: 12): number of filters added by each layer of a dense block
            include_top (bool, default: True): Whether to include a dense classification head.
            nb_dense_block (int, default: 3): number of dense blocks to add to end (generally = 3)
            nb_filter (int, default: -1): initial number of filters. Default -1 indicates initial number of
                filters is 2 * growth_rate
            nb_layers_per_block: number of layers in each dense block.
                Can be a -1, positive integer or a list.
                If -1, calculates nb_layer_per_block from the depth of the network.
                If positive integer, a set number of layers per dense block.
                If list, nb_layer is used as provided. Note that list size must
                be nb_dense_block
            training (bool): Whether it is training or not.
            trainable (bool): Whether to add the variables to tf.GraphKeys.TRAINABLE_VARIABLES
            reduction (float, default: 0.): reduction factor of transition blocks. Note: compression is
                computed as 1 - reduction
            subsample_initial_block (bool, default: False): Set to True to subsample the initial convolution and
                add a MaxPool2D before the dense blocks are added.

        """
        self.axis = 1 if data_format == 'channels_first' else 3

        self.bottleneck = bottleneck
        self.bn_kwargs = {'fused': True,
                          'axis': self.axis,
                          'training': training,
                          'trainable': trainable}

        self.conv_kwargs = {'data_format': data_format, 'trainable': trainable}

        self.data_format = data_format
        self.depth = depth
        self.dropout_rate = dropout_rate

        self.efficient = efficient

        self.growth_rate = growth_rate

        self.img_input = img_input
        self.include_top = include_top

        self.nb_classes = nb_classes
        self.nb_dense_block = nb_dense_block
        self.nb_filter = nb_filter

        self.subsample_initial_block = subsample_initial_block

        self.training = training
        self.trainable = trainable

        if reduction != 0.0:
            assert 0 < reduction <= 1.0, 'reduction value must lie between 0.0 and 1.0'

        # layers in each dense block
        if isinstance(nb_layers_per_block, (list, tuple)):
            nb_layers = list(nb_layers_per_block)  # Convert tuple to list

            assert len(nb_layers) == nb_dense_block, 'If list, nb_layer is used as provided. ' \
                                                     'Note that list size must be (nb_dense_block)'
            self.final_nb_layer = nb_layers[-1]
            self.nb_layers = nb_layers[:-1]
        else:
            if nb_layers_per_block == -1:
                assert (depth - 4) % 3 == 0, 'Depth must be 3 N + 4 if nb_layers_per_block == -1'
                count = int((depth - 4) / 3)
                self.nb_layers = [count for _ in range(nb_dense_block)]
                self.final_nb_layer = count
            else:
                self.final_nb_layer = nb_layers_per_block
                self.nb_layers = [nb_layers_per_block] * nb_dense_block

        # compute initial nb_filter if -1, else accept users initial nb_filter
        if self.nb_filter <= 0:
            self.nb_filter = 2 * self.growth_rate

        # compute compression factor
        self.compression = 1.0 - reduction

        # Initial convolution
        if self.subsample_initial_block:
            self.initial_kernel = (7, 7)
            self.initial_strides = (2, 2)
        else:
            self.initial_kernel = (3, 3)
            self.initial_strides = (1, 1)

    def _conv_block(self, ip, nb_filter):
        """ Apply BatchNorm, Relu, 3x3 Conv2D, optional bottleneck block and dropout

        Args:
            ip: Input tensor
            nb_filter: number of filters

        Returns: tensor with batch_norm, relu and convolution2d added (optional bottleneck)
        """

        def _x(ip):
            x = batch_normalization(ip, **self.bn_kwargs)
            x = tf.nn.relu(x)

            if self.bottleneck:
                inter_channel = nb_filter * 4

                x = conv2d(x, inter_channel, (1, 1), kernel_initializer='he_normal', padding='same', use_bias=False,
                           **self.conv_kwargs)
                x = batch_normalization(x, **self.bn_kwargs)
                x = tf.nn.relu(x)

            x = conv2d(x, nb_filter, (3, 3), kernel_initializer='he_normal', padding='same', use_bias=False,
                       **self.conv_kwargs)

            if self.dropout_rate:
                x = dropout(x, self.dropout_rate, training=self.training)

            return x

        if self.efficient:
            # Gradient checkpoint the layer: its activations are recomputed during the backward pass.
            # recompute_grad needs resource variables, hence use_resource=True in _dense_block.
            _x = tf.contrib.layers.recompute_grad(_x)

        return _x(ip)

    def _dense_block(self, x, nb_layers, nb_filter, grow_nb_filters=True, return_concat_list=False):
        """ Build a dense_block where the output of each conv_block is fed to subsequent ones

        Args:
            x: tensor
            nb_layers: the number of layers of conv_block to append to the model.
            nb_filter: number of filters
            grow_nb_filters: flag to decide to allow number of filters to grow
            return_concat_list: return the list of feature maps along with the actual output

        Returns:
            tuple (x, nb_filter), where x has nb_layers of conv_block appended;
            x_list is appended to the tuple if return_concat_list is True
        """
        x_list = [x]

        for i in range(nb_layers):
            with tf.variable_scope('denselayer_{}'.format(i), use_resource=True):
                cb = self._conv_block(x, self.growth_rate)
                x_list.append(cb)

            x = concatenate([x, cb], self.axis)

            if grow_nb_filters:
                nb_filter += self.growth_rate

        if self.dropout_rate:
            x = dropout(x, self.dropout_rate, training=self.training)

        if return_concat_list:
            return x, nb_filter, x_list
        else:
            return x, nb_filter

    def _transition_block(self, ip, nb_filter):
        """ Apply BatchNorm, ReLU, a 1x1 Conv2D with optional compression, and 2x2 AveragePooling2D

        Args:
            ip: tensor
            nb_filter: number of input filters. The 1x1 convolution outputs int(nb_filter * self.compression)
                filters, where self.compression = 1 - reduction (set in __init__).

        Returns:
            tensor, after applying batch_norm, relu, the 1x1 convolution and average pooling
        """
        x = batch_normalization(ip, **self.bn_kwargs)
        x = tf.nn.relu(x)
        x = conv2d(x, int(nb_filter * self.compression), (1, 1), kernel_initializer='he_normal',
                   padding='same', use_bias=False, **self.conv_kwargs)
        x = average_pooling2d(x, (2, 2), strides=(2, 2), data_format=self.data_format)

        return x

    def __call__(self):
        """ Builds the network. """
        x = conv2d(self.img_input, self.nb_filter, self.initial_kernel, kernel_initializer='he_normal', padding='same',
                   strides=self.initial_strides, use_bias=False, **self.conv_kwargs)

        if self.subsample_initial_block:
            x = batch_normalization(x, **self.bn_kwargs)
            x = tf.nn.relu(x)
            x = max_pooling2d(x, (3, 3), data_format=self.data_format, strides=(2, 2), padding='same')

        # Add dense blocks
        nb_filter = self.nb_filter
        for block_idx in range(self.nb_dense_block - 1):
            with tf.variable_scope('denseblock_{}'.format(block_idx)):
                x, nb_filter = self._dense_block(x, self.nb_layers[block_idx], nb_filter)
                # add transition_block
                x = self._transition_block(x, nb_filter)
                nb_filter = int(nb_filter * self.compression)

        # The last dense_block does not have a transition_block
        x, nb_filter = self._dense_block(x, self.final_nb_layer, nb_filter)

        x = batch_normalization(x, **self.bn_kwargs)
        x = tf.nn.relu(x)

        x = GlobalAveragePooling2D(data_format=self.data_format)(x)

        if self.include_top:
            x = dense(x, self.nb_classes)

        return x

--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[bdist_wheel]
universal=1

[pep8]
max-line-length = 120

[flake8]
max-line-length = 120
ignore = F403, F405, E128

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
"""Runs a DenseNet-121 model on the CIFAR-10 dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from constants import HEIGHT, WIDTH, DEPTH, NUM_CLASSES, NUM_IMAGES, MOMENTUM, WEIGHT_DECAY
from dataset import input_fn
from models import densenet_121
from utils import float32_variable_storage_getter

import argparse
import os
import sys
import horovod.tensorflow as hvd

import tensorflow as tf

from tensorflow.contrib.mixed_precision import ExponentialUpdateLossScaleManager, LossScaleOptimizer


parser = argparse.ArgumentParser()

# Basic model parameters.
parser.add_argument('--data_dir', type=str, default='/tmp/cifar10_data',
                    help='The path to the CIFAR-10 data directory.')

parser.add_argument('--model_dir', type=str, default='/tmp/cifar10_model',
                    help='The directory where the model will be stored.')

parser.add_argument('--train_epochs', type=int, default=250,
                    help='The number of epochs to train.')

parser.add_argument('--epochs_per_eval', type=int, default=10,
                    help='The number of epochs to run in between evaluations.')

parser.add_argument('--batch_size', type=int, default=3750,
                    help='The number of images per batch.')

parser.add_argument('--fp16', type=lambda x: (str(x).lower() == 'true'), default=False,
                    help='Whether to run with FP16 or not.')

parser.add_argument('--efficient', type=lambda x: (str(x).lower() == 'true'), default=False,
                    help='Whether to run with gradient checkpointing or not.')

parser.add_argument(
    '--data_format', type=str, default='channels_first',
    choices=['channels_first', 'channels_last'],
    help='The data format used in the model. channels_first '
         'provides a performance boost on GPU but is not always compatible '
         'with CPU. Defaults to channels_first.')


def cifar10_model_fn(features, labels, params):
    """Model function for CIFAR-10; returns (train_op, loss, global_step)."""
    print('PARAMS', params['fp16'])
    tf.summary.image('images', features, max_outputs=6)

    inputs = tf.reshape(features, [-1, HEIGHT, WIDTH, DEPTH])
    if params['fp16']:
        inputs = tf.cast(inputs, tf.float16)

    logits = densenet_121.get_model(inputs, NUM_CLASSES, params['data_format'], params['efficient'])
    logits = tf.cast(logits, tf.float32)

    predictions = {
        'classes': tf.argmax(logits, axis=1),
        'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
    }

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    cross_entropy = tf.losses.softmax_cross_entropy(logits=logits, onehot_labels=labels)

    # Create a tensor named cross_entropy for logging purposes.
    tf.identity(cross_entropy, name='cross_entropy')
    tf.summary.scalar('cross_entropy', cross_entropy)

    # Add weight decay to the loss.
    loss = cross_entropy + WEIGHT_DECAY * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    # Scale the learning rate linearly with the batch size. When the batch size
    # is 128, the learning rate should be 0.1.
    initial_learning_rate = 0.1 * params['batch_size'] / 128
    batches_per_epoch = NUM_IMAGES['train'] / params['batch_size']
    global_step = tf.train.get_or_create_global_step()

    # Multiply the learning rate by 0.1 at 100, 150, and 200 epochs.
    boundaries = [int(batches_per_epoch * epoch) for epoch in [100, 150, 200]]
    values = [initial_learning_rate * decay for decay in [1, 0.1, 0.01, 0.001]]
    learning_rate = tf.train.piecewise_constant(
        tf.cast(global_step, tf.int32), boundaries, values)

    # Create a tensor named learning_rate for logging purposes
    tf.identity(learning_rate, name='learning_rate')
    tf.summary.scalar('learning_rate', learning_rate)

    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate,
        momentum=MOMENTUM)

    if params['fp16']:
        # Choose a loss scale manager which decides how to pick the right loss scale
        # throughout the training process: start at a loss scale of 128 and try to
        # increase it after every 100 consecutive steps with finite gradients.
        loss_scale_manager = ExponentialUpdateLossScaleManager(128, 100)
        # Wraps the original optimizer in a LossScaleOptimizer.
        optimizer = LossScaleOptimizer(optimizer, loss_scale_manager)

    # Compress gradients to FP16 for the Horovod allreduce when training in FP16.
    compression = hvd.Compression.fp16 if params['fp16'] else hvd.Compression.none

    optimizer = hvd.DistributedOptimizer(optimizer, compression=compression)

    # Batch norm requires update ops to be added as a dependency to the train_op
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step)

    accuracy = tf.metrics.accuracy(
        tf.argmax(labels, axis=1), predictions['classes'])
    metrics = {'accuracy': accuracy}

    # Create a tensor named train_accuracy for logging purposes
    tf.identity(accuracy[1], name='train_accuracy')
    tf.summary.scalar('train_accuracy', accuracy[1])

    return train_op, loss, global_step


def main(unused_argv):
    # Initialize Horovod.
    hvd.init()

    # Using the Winograd non-fused algorithms provides a small performance boost.
    os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

    custom_params = {
        'data_format': FLAGS.data_format,
        'batch_size': FLAGS.batch_size,
        'fp16': FLAGS.fp16,
        'efficient': FLAGS.efficient
    }

    features, labels = input_fn(True, FLAGS.data_dir, FLAGS.batch_size, None)
    with tf.variable_scope('model', custom_getter=float32_variable_storage_getter):
        train_op, loss, global_step = cifar10_model_fn(features, labels, custom_params)

    # Stop after --train_epochs epochs, counting an epoch the same way as the
    # learning-rate schedule in cifar10_model_fn (NUM_IMAGES['train'] / batch_size steps).
    last_step = int(FLAGS.train_epochs * NUM_IMAGES['train'] / FLAGS.batch_size)

    # BroadcastGlobalVariablesHook broadcasts initial variable states from rank 0
    # to all other processes. This is necessary to ensure consistent initialization
    # of all workers when training is started with random weights or restored
    # from a checkpoint.
    hooks = [hvd.BroadcastGlobalVariablesHook(0),
             tf.train.StopAtStepHook(last_step=last_step),
             tf.train.LoggingTensorHook(tensors={'step': global_step, 'loss': loss},
                                        every_n_iter=10),
             ]

    # Pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())

    # Save checkpoints (to --model_dir) only on worker 0 to prevent other workers from corrupting them.
    checkpoint_dir = FLAGS.model_dir if hvd.rank() == 0 else None

    # The MonitoredTrainingSession takes care of session initialization,
    # restoring from a checkpoint, saving to a checkpoint, and closing when done
    # or an error occurs.
    with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                           hooks=hooks,
                                           config=config) as mon_sess:
        while not mon_sess.should_stop():
            # Run a training step synchronously.
            mon_sess.run(train_op)


if __name__ == '__main__':
    tf.logging.set_verbosity(tf.logging.INFO)
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(argv=[sys.argv[0]] + unparsed)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import tensorflow as tf


def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True,
                                    *args, **kwargs):
    """Custom variable getter that forces trainable variables to be stored in
    float32 precision and then casts them to the training precision.
    """
    storage_dtype = tf.float32 if trainable else dtype
    variable = getter(name, shape, dtype=storage_dtype,
                      initializer=initializer, regularizer=regularizer,
                      trainable=trainable,
                      *args, **kwargs)
    if trainable and dtype != tf.float32:
        variable = tf.cast(variable, dtype)
    return variable
--------------------------------------------------------------------------------
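
`float32_variable_storage_getter` keeps trainable variables in float32 and casts them to the training precision only for computation, and `train.py` wraps the optimizer in a `LossScaleOptimizer` when `--fp16 True` is passed. The pure-Python sketch below is not part of the repository; it uses the `struct` module's half-precision `'e'` format to illustrate the two float16 problems these measures address. A small update added to a weight of 1.0 is rounded away in float16, and a tiny gradient flushes to zero unless it is first multiplied by a loss scale. The update and gradient values are arbitrary illustrative numbers; 128 matches the initial loss scale passed to `ExponentialUpdateLossScaleManager` in `train.py`.

```python
import struct


def to_fp16(x):
    """Round a Python float to IEEE 754 half precision."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


def to_fp32(x):
    """Round a Python float to IEEE 754 single precision."""
    return struct.unpack('<f', struct.pack('<f', x))[0]


def apply_updates(round_fn, weight, update, steps):
    """Add `update` to `weight` `steps` times, rounding to the target precision after every step."""
    w = round_fn(weight)
    for _ in range(steps):
        w = round_fn(w + round_fn(update))
    return w


if __name__ == '__main__':
    update = 1e-4  # illustrative learning_rate * gradient
    print(apply_updates(to_fp16, 1.0, update, 100))  # stays at 1.0
    print(apply_updates(to_fp32, 1.0, update, 100))  # close to 1.01

    grad = 1e-8       # illustrative tiny gradient
    loss_scale = 128  # initial loss scale used in train.py
    print(to_fp16(grad))                            # 0.0: underflows in float16
    print(to_fp16(grad * loss_scale) / loss_scale)  # nonzero, roughly 1e-8
```

Storing float32 master copies means updates accumulate at float32 resolution even though the forward and backward passes run in float16, while loss scaling keeps small gradients representable until they are unscaled.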