├── README.md
├── args.py
├── custom_layers.py
├── eval.py
├── model_def.py
├── model_profiling.py
├── models.py
└── teaser.png

/README.md:
--------------------------------------------------------------------------------
1 | # Neural Architecture Search for Lightweight Non-Local Networks
2 | 
3 | This repository contains the code for the CVPR 2020 paper [Neural Architecture Search for Lightweight Non-Local Networks](https://cs.jhu.edu/~alanlab/Pubs20/li2020neural.pdf).
4 | The paper presents a lightweight non-local block and automatically searched, state-of-the-art non-local networks for mobile vision.
5 | We also provide a PyTorch implementation [here](https://github.com/meijieru/yet_another_mobilenet_series).
6 | 
7 | 
8 | ![teaser](teaser.png)
9 | 
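
## Lightweight Non-Local Block (sketch)

The core operator of the paper is implemented in `non_local_op` inside `model_def.py`. The snippet below is a simplified, stride-1 sketch of that operator for readers who want the idea at a glance; the standalone function name and shapes are illustrative (the repository applies the op inside `MnasBlock` and also handles the strided/reduced and softmax variants), and it assumes a statically known NHWC shape.

```python
import tensorflow as tf


def lightweight_nl_sketch(l, nl_ratio=0.5):
  """Simplified stride-1 sketch of non_local_op in model_def.py."""
  H, W, C = l.shape.as_list()[1:]
  k = int(nl_ratio * C)
  # theta/phi use only the first k channels; g keeps all C channels.
  theta, phi, g = l[:, :, :, :k], l[:, :, :, :k], l
  # Matrix-product associativity: (theta phi^T) g vs theta (phi^T g).
  # The first order costs roughly (HW)^2 * (C + k) MACs, the second 2 * HW * C * k,
  # so pick whichever is cheaper for the current feature-map size.
  if (H * W) * (C + k) < 2 * C * k:
    f = tf.einsum('nabi,ncdi->nabcd', theta, phi)  # HW x HW spatial affinity
    f = tf.einsum('nabcd,ncdi->nabi', f, g)
  else:
    f = tf.einsum('nhwi,nhwj->nij', phi, g)        # k x C channel affinity
    f = tf.einsum('nij,nhwi->nhwj', f, theta)
  return f / tf.cast(H * W, f.dtype)               # mean normalization, no softmax
```

Because matrix multiplication is associative, the op always takes the cheaper of the two orderings, which is what keeps the block lightweight across feature maps of very different resolutions.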
10 | 11 | If you use the code, please cite: 12 | 13 | @inproceedings{li2020neural, 14 | title={Neural Architecture Search for Lightweight Non-local Networks}, 15 | author={Li, Yingwei and Jin, Xiaojie and Mei, Jieru and Lian, Xiaochen and Yang, Linjie and Xie, Cihang and Yu, Qihang and Zhou, Yuyin and Bai, Song and Yuille, Alan}, 16 | booktitle={CVPR}, 17 | year={2020} 18 | } 19 | 20 | ## Requirements 21 | TensorFlow 1.14.0 22 | 23 | tensorpack 0.9.8 (for dataset loading) 24 | 25 | ## Model Preparation 26 | Download the [AutoNL-L-77.7.zip](https://livejohnshopkins-my.sharepoint.com/:u:/g/personal/yli286_jh_edu/EcfjxufrZTNLkxQG_929cPABhwmfBupJreOQSMlIm18Tvg?e=ZOWJIm) and [AutoNL-S-76.5.zip](https://livejohnshopkins-my.sharepoint.com/:u:/g/personal/yli286_jh_edu/ES89oOHhIeBBpRCO76vaspAB1hmFytENyJGHSOwI__3aWw?e=VghMRF) pretrained models. 27 | Unzip and place them at the root directory of the source code. 28 | 29 | ## Usage 30 | Download and place the ImageNet validation set at ```$PATH_TO_IMAGENET/val```. 31 | ```bash 32 | python eval.py --model_dir=AutoNL-S-76.5 --valdir=$PATH_TO_IMAGENET/val --arch=AutoNL-S-76.5/arch.txt 33 | python eval.py --model_dir=AutoNL-L-77.7 --valdir=$PATH_TO_IMAGENET/val --arch=AutoNL-L-77.7/arch.txt 34 | ``` 35 | The last printed line should read: 36 | ``` 37 | Test: [50000/50000] Prec@1 77.7 Prec@5 93.7 38 | ``` 39 | for AutoNL-L, and 40 | ``` 41 | Test: [50000/50000] Prec@1 76.5 Prec@5 93.1 42 | ``` 43 | for AutoNL-S. 44 | 45 | ## Acknowledgements 46 | Part of code comes from [single-path-nas](https://github.com/dstamoulis/single-path-nas), [mnasnet](https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet) 47 | and [ImageNet-Adversarial-Training](https://github.com/facebookresearch/ImageNet-Adversarial-Training). 48 | 49 | If you encounter any problems or have any inquiries, please contact us at yingwei.li@jhu.edu. 50 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | from absl import flags 2 | 3 | FLAGS = flags.FLAGS 4 | flags.DEFINE_string('model_dir', default=None, 5 | help='The directory where the model and training/evaluation summaries are stored.') 6 | flags.DEFINE_integer('input_image_size', default=224, help='Input image size.') 7 | flags.DEFINE_float('moving_average_decay', default=0.9999, help='Moving average decay rate.') 8 | flags.DEFINE_string('arch', default=None, help='The directory where the output of SP-NAS search is stored.') 9 | flags.DEFINE_bool('use_keras', default=True, help='whether use keras') 10 | flags.DEFINE_bool('nl_zero_init', default=True, help='whether use zero initializer to initialize non local bn') 11 | flags.DEFINE_string('valdir', default='/mnt/cephfs_hl/uslabcv/yingwei/datasets/imagenet/val', 12 | help='validation dir') -------------------------------------------------------------------------------- /custom_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Custom layers."""
16 | 
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | 
21 | import tensorflow as tf
22 | 
23 | 
24 | def _get_conv2d(filters, kernel_size, use_keras, **kwargs):
25 |   """A helper function to create a Conv2D layer."""
26 |   if use_keras:
27 |     return tf.keras.layers.Conv2D(
28 |         filters=filters, kernel_size=kernel_size, **kwargs)
29 |   else:
30 |     return tf.layers.Conv2D(filters=filters, kernel_size=kernel_size, **kwargs)
31 | 
32 | 
33 | def _split_channels(total_filters, num_groups):
34 |   split = [total_filters // num_groups for _ in range(num_groups)]
35 |   split[0] += total_filters - sum(split)
36 |   return split
37 | 
38 | 
39 | class GroupedConv2D(object):
40 |   """Grouped convolution.
41 | 
42 |   Currently tf.keras and tf.layers don't support group convolution, so here we
43 |   use split/concat to implement this op. It reuses kernel_size for the group
44 |   definition, where len(kernel_size) is the number of groups. Notably, it allows
45 |   each group to have a different kernel size.
46 |   """
47 | 
48 |   def __init__(self, filters, kernel_size, use_keras, **kwargs):
49 |     """Initialize the layer.
50 | 
51 |     Args:
52 |       filters: Integer, the dimensionality of the output space.
53 |       kernel_size: An integer or a list. If it is a single integer, then it is
54 |         the same as the original Conv2D. If it is a list, then we split the channels
55 |         and apply a different kernel to each group.
56 |       use_keras: A boolean value, whether to use a keras layer.
57 |       **kwargs: other parameters passed to the original conv2d layer.
58 |     """
59 |     self._groups = len(kernel_size)
60 |     self._channel_axis = -1
61 | 
62 |     self._convs = []
63 |     splits = _split_channels(filters, self._groups)
64 |     for i in range(self._groups):
65 |       self._convs.append(
66 |           _get_conv2d(splits[i], kernel_size[i], use_keras, **kwargs))
67 | 
68 |   def __call__(self, inputs):
69 |     if len(self._convs) == 1:
70 |       return self._convs[0](inputs)
71 | 
72 |     filters = inputs.shape[self._channel_axis].value
73 |     splits = _split_channels(filters, len(self._convs))
74 |     x_splits = tf.split(inputs, splits, self._channel_axis)
75 |     x_outputs = [c(x) for x, c in zip(x_splits, self._convs)]
76 |     x = tf.concat(x_outputs, self._channel_axis)
77 |     return x
78 | 
79 | 
80 | class MDConv(object):
81 |   """MDConv with mixed depthwise convolutional kernels.
82 | 
83 |   MDConv is an improved depthwise convolution that mixes multiple kernels (e.g.
84 |   3x3, 5x5, etc.). Right now, we use a naive implementation that splits the channels
85 |   into multiple groups and applies a different kernel to each group.
86 | 
87 |   See the MixNet paper for more details.
88 |   """
89 | 
90 |   def __init__(self, kernel_size, strides, dilated=False, **kwargs):
91 |     """Initialize the layer.
92 | 
93 |     Most of the args are the same as tf.keras.layers.DepthwiseConv2D, except it has
94 |     an extra parameter "dilated" to indicate whether to use dilated conv to
95 |     simulate large kernel size.
If dilated=True, then dilation_rate is ignored. 96 | 97 | Args: 98 | kernel_size: An integer or a list. If it is a single integer, then it is 99 | same as the original tf.keras.layers.DepthwiseConv2D. If it is a list, 100 | then we split the channels and perform different kernel for each group. 101 | strides: An integer or tuple/list of 2 integers, specifying the strides of 102 | the convolution along the height and width. 103 | dilated: Bool. indicate whether to use dilated conv to simulate large 104 | kernel size. 105 | **kwargs: other parameters passed to the original depthwise_conv layer. 106 | """ 107 | self._channel_axis = -1 108 | self._dilated = dilated 109 | 110 | self._convs = [] 111 | for s in kernel_size: 112 | d = 1 113 | if strides[0] == 1 and self._dilated: 114 | # Only apply dilated conv for stride 1 if needed. 115 | d, s = (s - 1) // 2, 3 116 | tf.logging.info('Use dilated conv with dilation rate = {}'.format(d)) 117 | self._convs.append( 118 | tf.keras.layers.DepthwiseConv2D( 119 | s, strides=strides, dilation_rate=d, **kwargs)) 120 | 121 | def __call__(self, inputs): 122 | if len(self._convs) == 1: 123 | return self._convs[0](inputs) 124 | 125 | filters = inputs.shape[self._channel_axis].value 126 | splits = _split_channels(filters, len(self._convs)) 127 | x_splits = tf.split(inputs, splits, self._channel_axis) 128 | x_outputs = [c(x) for x, c in zip(x_splits, self._convs)] 129 | x = tf.concat(x_outputs, self._channel_axis) 130 | return x 131 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from absl import app 8 | from tensorpack import dataset, DataFromList, MultiThreadMapData, BatchData 9 | 10 | import models 11 | from args import FLAGS 12 | 13 | CROP_PADDING = 32 14 | MEAN_RGB = [0.485 * 255, 0.456 * 255, 0.406 * 255] 15 | STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] 16 | 17 | 18 | def get_val_dataflow(datadir, parallel=1): 19 | assert datadir is not None 20 | ds = dataset.ILSVRC12Files(datadir, 'val', shuffle=False) 21 | 22 | def mapf(dp): 23 | fname, cls = dp 24 | with open(fname, "rb") as f: 25 | im_bytes = f.read() 26 | return im_bytes, cls 27 | 28 | ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=min(2000, ds.size()), strict=True) 29 | return ds 30 | 31 | 32 | def _decode_and_center_crop(image_bytes, image_size): 33 | """Crops to center of image with padding then scales image_size.""" 34 | shape = tf.image.extract_jpeg_shape(image_bytes) 35 | image_height = shape[0] 36 | image_width = shape[1] 37 | 38 | padded_center_crop_size = tf.cast( 39 | ((image_size / (image_size + CROP_PADDING)) * 40 | tf.cast(tf.minimum(image_height, image_width), tf.float32)), 41 | tf.int32) 42 | 43 | offset_height = ((image_height - padded_center_crop_size) + 1) // 2 44 | offset_width = ((image_width - padded_center_crop_size) + 1) // 2 45 | crop_window = tf.stack([offset_height, offset_width, 46 | padded_center_crop_size, padded_center_crop_size]) 47 | image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) 48 | image = tf.image.resize_bicubic([image], [image_size, image_size])[0] 49 | 50 | return image 51 | 52 | 53 | def preprocess_for_eval(image_bytes, image_size, scope=None): 54 | with tf.name_scope(scope, 'eval_image', [image_bytes, image_size, image_size]): 55 | image = _decode_and_center_crop(image_bytes, image_size) 56 
|         image = tf.reshape(image, [image_size, image_size, 3])
57 |         image = tf.image.convert_image_dtype(image, dtype=tf.float32)
58 | 
59 |         return image
60 | 
61 | 
62 | def main(unused_argv):
63 |     # get input
64 |     image_ph = tf.placeholder(tf.string)
65 |     image_proc = preprocess_for_eval(image_ph, FLAGS.input_image_size)
66 |     images = tf.expand_dims(image_proc, 0)
67 |     images -= tf.constant(MEAN_RGB, shape=[1, 1, 3], dtype=images.dtype)
68 |     images /= tf.constant(STDDEV_RGB, shape=[1, 1, 3], dtype=images.dtype)
69 | 
70 |     override_params = {'data_format': 'channels_last', 'num_classes': 1000}
71 |     logits, _, _ = models.build_model(
72 |         images, training=False,
73 |         override_params=override_params,
74 |         arch=FLAGS.arch)
75 | 
76 |     config = tf.ConfigProto()
77 |     config.gpu_options.allow_growth = True
78 |     sess = tf.Session(config=config)
79 |     ckpt_path = os.path.join(FLAGS.model_dir, "bestmodel.ckpt")
80 |     if not os.path.exists(ckpt_path + ".data-00000-of-00001"):
81 |         ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
82 |     global_step = tf.train.get_global_step()
83 |     ema = tf.train.ExponentialMovingAverage(
84 |         decay=FLAGS.moving_average_decay, num_updates=global_step)
85 |     ema_vars = tf.trainable_variables() + tf.get_collection('moving_vars')
86 |     for v in tf.global_variables():
87 |         # We maintain moving averages for batch norm moving mean and variance as well.
88 |         if 'moving_mean' in v.name or 'moving_variance' in v.name:
89 |             ema_vars.append(v)
90 |     ema_vars = list(set(ema_vars))
91 |     restore_vars_dict = ema.variables_to_restore(ema_vars)
92 |     ckpt_restorer = tf.train.Saver(restore_vars_dict)
93 |     ckpt_restorer.restore(sess, ckpt_path)
94 | 
95 |     c1, c5 = 0, 0
96 |     ds = get_val_dataflow(os.path.join(FLAGS.valdir, ".."))
97 |     ds.reset_state()
98 |     preds = []
99 |     labs = []
100 |     for i, (image, label) in enumerate(ds):
101 |         # image, label = images[0], labels[0]
102 |         logits_val = sess.run(logits, feed_dict={image_ph: image})
103 |         top5 = logits_val.squeeze().argsort()[::-1][:5]
104 |         top1 = top5[0]
105 |         if label == top1:
106 |             c1 += 1
107 |         if label in top5:
108 |             c5 += 1
109 |         preds.append(top1)
110 |         labs.append(label)
111 |         if (i + 1) % 1000 == 0:
112 |             print('Test: [{0}/{1}]\t'
113 |                   'Prec@1 {2:.1f}\t'
114 |                   'Prec@5 {3:.1f}\t'.format(
115 |                       i + 1, len(ds), c1 / (i + 1.) * 100, c5 / (i + 1.) * 100))
116 | 
117 | 
118 | if __name__ == '__main__':
119 |     app.run(main)
120 | 
--------------------------------------------------------------------------------
/model_def.py:
--------------------------------------------------------------------------------
1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains definitions for the MnasNet model.
16 | 
17 | [1] Mingxing Tan, Bo Chen, Ruoming Pang, Vijay Vasudevan, Quoc V. Le
18 |   MnasNet: Platform-Aware Neural Architecture Search for Mobile.
19 | arXiv:1807.11626 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import collections 27 | import numpy as np 28 | import six 29 | from six.moves import xrange # pylint: disable=redefined-builtin 30 | import tensorflow as tf 31 | 32 | import custom_layers 33 | from args import FLAGS 34 | from model_profiling import module_profiling 35 | from tensorflow.python.keras import backend as K 36 | 37 | GlobalParams = collections.namedtuple('GlobalParams', [ 38 | 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'data_format', 39 | 'num_classes', 'depth_multiplier', 'depth_divisor', 'min_depth', 40 | ]) 41 | GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) 42 | 43 | # TODO(hongkuny): Consider rewrite an argument class with encoding/decoding. 44 | BlockArgs = collections.namedtuple('BlockArgs', [ 45 | 'dw_ksize', 'num_repeat', 'input_filters', 'output_filters', 46 | 'expand_ratio', 'id_skip', 'strides', 'se_ratio', 'non_local', 47 | 'expand_ksize', 'project_ksize', 'swish', 48 | ]) 49 | # defaults will be a public argument for namedtuple in Python 3.7 50 | # https://docs.python.org/3/library/collections.html#collections.namedtuple 51 | BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) 52 | 53 | 54 | def conv_kernel_initializer(shape, dtype=None, partition_info=None): 55 | """Initialization for convolutional kernels. 56 | 57 | The main difference with tf.variance_scaling_initializer is that 58 | tf.variance_scaling_initializer uses a truncated normal with an uncorrected 59 | standard deviation, whereas here we use a normal distribution. Similarly, 60 | tf.contrib.layers.variance_scaling_initializer uses a truncated normal with 61 | a corrected standard deviation. 62 | 63 | Args: 64 | shape: shape of variable 65 | dtype: dtype of variable 66 | partition_info: unused 67 | 68 | Returns: 69 | an initialization for the variable 70 | """ 71 | del partition_info 72 | kernel_height, kernel_width, _, out_filters = shape 73 | fan_out = int(kernel_height * kernel_width * out_filters) 74 | return tf.random_normal( 75 | shape, mean=0.0, stddev=np.sqrt(2.0 / fan_out), dtype=dtype) 76 | 77 | 78 | def dense_kernel_initializer(shape, dtype=None, partition_info=None): 79 | """Initialization for dense kernels. 80 | 81 | This initialization is equal to 82 | tf.variance_scaling_initializer(scale=1.0/3.0, mode='fan_out', 83 | distribution='uniform'). 84 | It is written out explicitly here for clarity. 85 | 86 | Args: 87 | shape: shape of variable 88 | dtype: dtype of variable 89 | partition_info: unused 90 | 91 | Returns: 92 | an initialization for the variable 93 | """ 94 | del partition_info 95 | init_range = 1.0 / np.sqrt(shape[1]) 96 | return tf.random_uniform(shape, -init_range, init_range, dtype=dtype) 97 | 98 | 99 | def round_filters(filters, global_params): 100 | """Round number of filters based on depth multiplier.""" 101 | multiplier = global_params.depth_multiplier 102 | divisor = global_params.depth_divisor 103 | min_depth = global_params.min_depth 104 | if not multiplier: 105 | return filters 106 | 107 | filters *= multiplier 108 | min_depth = min_depth or divisor 109 | new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) 110 | # Make sure that round down does not go down by more than 10%. 
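  # Worked example (illustrative numbers only, not from a released config):
  # with filters=32, depth_multiplier=1.4, depth_divisor=8 and min_depth=None,
  # filters * multiplier = 44.8 and int(44.8 + 4) // 8 * 8 = 48; since
  # 48 >= 0.9 * 44.8, the check below leaves new_filters at 48.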
111 | if new_filters < 0.9 * filters: 112 | new_filters += divisor 113 | return new_filters 114 | 115 | 116 | class MnasBlock(object): 117 | """A class of MnasNet Inveretd Residual Bottleneck. 118 | 119 | Attributes: 120 | has_se: boolean. Whether the block contains a Squeeze and Excitation layer 121 | inside. 122 | endpoints: dict. A list of internal tensors. 123 | """ 124 | 125 | def __init__(self, block_args, global_params): 126 | """Initializes a MnasNet block. 127 | 128 | Args: 129 | block_args: BlockArgs, arguments to create a MnasBlock. 130 | global_params: GlobalParams, a set of global parameters. 131 | """ 132 | self._block_args = block_args 133 | self._batch_norm_momentum = global_params.batch_norm_momentum 134 | self._batch_norm_epsilon = global_params.batch_norm_epsilon 135 | self.data_format = global_params.data_format 136 | assert global_params.data_format == 'channels_last' 137 | self._channel_axis = -1 138 | self._spatial_dims = [1, 2] 139 | self.has_se = (self._block_args.se_ratio is not None) and ( 140 | self._block_args.se_ratio > 0) and (self._block_args.se_ratio <= 1) 141 | self._relu_fn = tf.nn.swish 142 | self.endpoints = None 143 | 144 | # Builds the block accordings to arguments. 145 | self._build() 146 | 147 | def _build(self): 148 | """Builds MnasNet block according to the arguments.""" 149 | filters = self._block_args.input_filters * self._block_args.expand_ratio 150 | if self._block_args.expand_ratio != 1: 151 | # Expansion phase: 152 | self._expand_conv = custom_layers.GroupedConv2D( 153 | filters, 154 | kernel_size=self._block_args.expand_ksize, 155 | strides=[1, 1], 156 | kernel_initializer=conv_kernel_initializer, 157 | padding='same', 158 | use_bias=False, data_format=self.data_format, 159 | use_keras=FLAGS.use_keras 160 | ) 161 | self._bn0 = tf.layers.BatchNormalization( 162 | axis=self._channel_axis, 163 | momentum=self._batch_norm_momentum, 164 | epsilon=self._batch_norm_epsilon, 165 | fused=True) 166 | 167 | kernel_size = self._block_args.dw_ksize 168 | # Depth-wise convolution phase: 169 | self._depthwise_conv = custom_layers.MDConv( 170 | kernel_size, 171 | strides=self._block_args.strides, 172 | depthwise_initializer=conv_kernel_initializer, 173 | padding='same', 174 | data_format=self.data_format, 175 | use_bias=False, dilated=False) 176 | self._bn1 = tf.layers.BatchNormalization( 177 | axis=self._channel_axis, 178 | momentum=self._batch_norm_momentum, 179 | epsilon=self._batch_norm_epsilon, 180 | fused=True) 181 | 182 | if self.has_se: 183 | num_reduced_filters = max( 184 | 1, int(self._block_args.input_filters * self._block_args.se_ratio)) 185 | # Squeeze and Excitation layer. 
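      # Channel bookkeeping (illustrative numbers, not from a released arch):
      # with input_filters=24, se_ratio=0.25 and expand_ratio=6, the expanded
      # tensor has filters = 24 * 6 = 144 channels; _se_reduce below squeezes it
      # to max(1, int(24 * 0.25)) = 6 channels, _se_expand maps back to 144, and
      # the sigmoid gate in _call_se rescales the depthwise output channel-wise.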
186 | self._se_reduce = tf.keras.layers.Conv2D( 187 | num_reduced_filters, 188 | kernel_size=[1, 1], 189 | strides=[1, 1], 190 | kernel_initializer=conv_kernel_initializer, 191 | padding='same', 192 | use_bias=True, data_format=self.data_format) 193 | self._se_expand = tf.keras.layers.Conv2D( 194 | filters, 195 | kernel_size=[1, 1], 196 | strides=[1, 1], 197 | kernel_initializer=conv_kernel_initializer, 198 | padding='same', 199 | use_bias=True, data_format=self.data_format) 200 | 201 | # Output phase: 202 | filters = self._block_args.output_filters 203 | self._project_conv = custom_layers.GroupedConv2D( 204 | filters, 205 | kernel_size=self._block_args.project_ksize, 206 | strides=[1, 1], 207 | kernel_initializer=conv_kernel_initializer, 208 | padding='same', 209 | use_bias=False, data_format=self.data_format, 210 | use_keras=FLAGS.use_keras) 211 | self._bn2 = tf.layers.BatchNormalization( 212 | axis=self._channel_axis, 213 | momentum=self._batch_norm_momentum, 214 | epsilon=self._batch_norm_epsilon, 215 | fused=True) 216 | if self._block_args.non_local: # this line still work, even if this value becomes a float, or a list 217 | self._non_local_conv = tf.keras.layers.DepthwiseConv2D( 218 | kernel_size=[3, 3], 219 | strides=[1, 1], 220 | kernel_initializer=conv_kernel_initializer, 221 | padding='same', 222 | use_bias=False, 223 | data_format=self.data_format) 224 | 225 | self._non_local_bn = tf.layers.BatchNormalization( 226 | axis=self._channel_axis, 227 | momentum=self._batch_norm_momentum, 228 | epsilon=self._batch_norm_epsilon, 229 | gamma_initializer=tf.zeros_initializer() if FLAGS.nl_zero_init else tf.ones_initializer, # this line is correct 230 | fused=True) 231 | 232 | def _call_se(self, input_tensor): 233 | """Call Squeeze and Excitation layer. 234 | 235 | Args: 236 | input_tensor: Tensor, a single input tensor for Squeeze/Excitation layer. 237 | 238 | Returns: 239 | A output tensor, which should have the same shape as input. 240 | """ 241 | macs = 0. 242 | se_tensor = tf.reduce_mean(input_tensor, self._spatial_dims, keepdims=True) 243 | macs += module_profiling(tf.keras.layers.GlobalAveragePooling2D(), input_tensor, se_tensor, False) 244 | s_tensor = self._relu_fn(self._se_reduce(se_tensor)) 245 | macs += module_profiling(self._se_reduce, se_tensor, s_tensor, False) 246 | e_tensor = self._se_expand(s_tensor) 247 | macs += module_profiling(self._se_expand, s_tensor, e_tensor, False) 248 | tf.logging.info('Built Squeeze and Excitation with tensor shape: %s' % e_tensor.shape) 249 | return tf.sigmoid(e_tensor) * input_tensor, macs 250 | 251 | def _call_non_local(self, l, training=True, nl_ratio=1.0, nl_stride=1): 252 | def reduce_func(l, nl_stride): 253 | return l[:, ::nl_stride, ::nl_stride, :], 0 254 | 255 | total_macs = 0. 256 | tf.logging.info('Block input: %s shape: %s' % (l.name, l.shape)) 257 | f, macs = non_local_op(l, embed=False, softmax=False, nl_ratio=nl_ratio, nl_stride=nl_stride, 258 | reduce_func=reduce_func) 259 | total_macs += macs 260 | f_output = self._non_local_conv(f) 261 | macs = module_profiling(self._non_local_conv, f, f_output, False) 262 | total_macs += macs 263 | f = self._non_local_bn(f_output, training=training) 264 | l = l + f 265 | tf.logging.info('Non-local: %s shape: %s' % (l.name, l.shape)) 266 | return l, total_macs 267 | 268 | def call(self, inputs, training=True): 269 | """Implementation of MnasBlock call(). 270 | 271 | Args: 272 | inputs: the inputs tensor. 273 | training: boolean, whether the model is constructed for training. 
274 | 275 | Returns: 276 | A output tensor. 277 | """ 278 | total_macs = 0. 279 | tf.logging.info('Block input: %s shape: %s' % (inputs.name, inputs.shape)) 280 | if self._block_args.expand_ratio != 1: 281 | outputs_expand_conv = self._expand_conv(inputs) 282 | total_macs += module_profiling(self._expand_conv, inputs, outputs_expand_conv, False) # compute macs 283 | x = self._relu_fn(self._bn0(outputs_expand_conv, training=training)) 284 | else: 285 | x = inputs 286 | tf.logging.info('Expand: %s shape: %s' % (x.name, x.shape)) 287 | 288 | outputs_depthwise_conv = self._depthwise_conv(x) 289 | total_macs += module_profiling(self._depthwise_conv, x, outputs_depthwise_conv, False) # compute macs 290 | x = self._relu_fn(self._bn1(outputs_depthwise_conv, training=training)) 291 | tf.logging.info('DWConv: %s shape: %s' % (x.name, x.shape)) 292 | 293 | if self.has_se: 294 | with tf.variable_scope('se'): 295 | x, macs = self._call_se(x) 296 | total_macs += macs 297 | # raise NotImplementedError 298 | 299 | self.endpoints = {'expansion_output': x} 300 | 301 | outputs_project_conv = self._project_conv(x) 302 | total_macs += module_profiling(self._project_conv, x, outputs_project_conv, False) # compute macs 303 | x = self._bn2(outputs_project_conv, training=training) 304 | 305 | if self._block_args.non_local: 306 | with tf.variable_scope('nl'): 307 | x, macs = self._call_non_local(x, training=training, nl_ratio=self._block_args.non_local[0], 308 | nl_stride=self._block_args.non_local[1]) 309 | total_macs += macs 310 | 311 | if self._block_args.id_skip: 312 | if all( 313 | s == 1 for s in self._block_args.strides 314 | ) and self._block_args.input_filters == self._block_args.output_filters: 315 | x = tf.add(x, inputs) 316 | total_macs += module_profiling(tf.add, x, inputs, False) # compute macs 317 | tf.logging.info('Project: %s shape: %s' % (x.name, x.shape)) 318 | 319 | return x, total_macs 320 | 321 | 322 | def non_local_op(l, embed, softmax, nl_ratio=1.0, nl_stride=1, reduce_func=None): 323 | H, W, n_in = l.shape.as_list()[1:] 324 | reduced_HW = (H // nl_stride) * (W // nl_stride) 325 | if embed: 326 | raise NotImplementedError 327 | else: 328 | if nl_stride == 1: 329 | l_reduced = l 330 | reduce_macs = 0 331 | else: 332 | assert reduce_func is not None 333 | l_reduced, reduce_macs = reduce_func(l, nl_stride) 334 | 335 | theta, phi, g = l[:, :, :, :int(nl_ratio * n_in)], l_reduced[:, :, :, :int(nl_ratio * n_in)], l_reduced 336 | 337 | if (H * W) * reduced_HW * n_in * (1 + nl_ratio) < ( 338 | H * W) * n_in ** 2 * nl_ratio + reduced_HW * n_in ** 2 * nl_ratio or softmax: 339 | f = tf.einsum('nabi,ncdi->nabcd', theta, phi) 340 | if softmax: 341 | raise NotImplementedError 342 | f = tf.einsum('nabcd,ncdi->nabi', f, g) 343 | # macs = (H * W) ** 2 * n_in * nl_ratio + (H * W) ** 2 * n_in 344 | macs = (H * W) * reduced_HW * n_in * (1 + nl_ratio) 345 | else: 346 | f = tf.einsum('nhwi,nhwj->nij', phi, g) 347 | f = tf.einsum('nij,nhwi->nhwj', f, theta) 348 | # macs = (H * W) * n_in ** 2 * 2 * nl_ratio 349 | macs = (H * W) * n_in ** 2 * nl_ratio + reduced_HW * n_in ** 2 * nl_ratio 350 | if not softmax: 351 | f = f / tf.cast(H * W, f.dtype) 352 | return tf.reshape(f, tf.shape(l)), macs + reduce_macs 353 | 354 | 355 | class MnasNetModel(tf.keras.Model): 356 | """A class implements tf.keras.Model for MnesNet model. 357 | 358 | Reference: https://arxiv.org/abs/1807.11626 359 | """ 360 | 361 | def __init__(self, blocks_args=None, global_params=None): 362 | """Initializes an `MnasNetModel` instance. 
363 | 364 | Args: 365 | blocks_args: A list of BlockArgs to construct MnasNet block modules. 366 | global_params: GlobalParams, a set of global parameters. 367 | 368 | Raises: 369 | ValueError: when blocks_args is not specified as a list. 370 | """ 371 | super(MnasNetModel, self).__init__() 372 | if not isinstance(blocks_args, list): 373 | raise ValueError('blocks_args should be a list.') 374 | self._global_params = global_params 375 | self._blocks_args = blocks_args 376 | # Use relu in default for head and stem. 377 | self._relu_fn = tf.nn.swish 378 | self.endpoints = None 379 | self._build() 380 | 381 | def _build(self): 382 | """Builds a MnasNet model.""" 383 | self._blocks = [] 384 | # Builds blocks. 385 | for block_args in self._blocks_args: 386 | assert block_args.num_repeat > 0 387 | # Update block input and output filters based on depth multiplier. 388 | block_args = block_args._replace( 389 | input_filters=round_filters(block_args.input_filters, 390 | self._global_params), 391 | output_filters=round_filters(block_args.output_filters, 392 | self._global_params)) 393 | 394 | # The first block needs to take care of stride and filter size increase. 395 | self._blocks.append(MnasBlock(block_args, self._global_params)) # removed kernel mask here 396 | if block_args.num_repeat > 1: 397 | # pylint: disable=protected-access 398 | block_args = block_args._replace( 399 | input_filters=block_args.output_filters, strides=[1, 1]) 400 | # pylint: enable=protected-access 401 | for _ in xrange(block_args.num_repeat - 1): 402 | self._blocks.append(MnasBlock(block_args, self._global_params)) # removed kernel mask here 403 | 404 | batch_norm_momentum = self._global_params.batch_norm_momentum 405 | batch_norm_epsilon = self._global_params.batch_norm_epsilon 406 | if self._global_params.data_format == 'channels_first': 407 | channel_axis = 1 408 | else: 409 | channel_axis = -1 410 | 411 | # Stem part. 412 | stem_size = 32 413 | self._conv_stem = tf.keras.layers.Conv2D( 414 | filters=round_filters(stem_size, self._global_params), 415 | kernel_size=[3, 3], 416 | strides=[2, 2], 417 | kernel_initializer=conv_kernel_initializer, 418 | padding='same', 419 | use_bias=False, data_format=self._global_params.data_format 420 | ) 421 | self._bn0 = tf.layers.BatchNormalization( 422 | axis=channel_axis, 423 | momentum=batch_norm_momentum, 424 | epsilon=batch_norm_epsilon, 425 | fused=True) 426 | 427 | # Head part. 428 | self._conv_head = tf.keras.layers.Conv2D( 429 | filters=1280, 430 | kernel_size=[1, 1], 431 | strides=[1, 1], 432 | kernel_initializer=conv_kernel_initializer, 433 | padding='same', 434 | use_bias=False, data_format=self._global_params.data_format) 435 | self._bn1 = tf.layers.BatchNormalization( 436 | axis=channel_axis, 437 | momentum=batch_norm_momentum, 438 | epsilon=batch_norm_epsilon, 439 | fused=True) 440 | 441 | self._avg_pooling = tf.keras.layers.GlobalAveragePooling2D( 442 | data_format=self._global_params.data_format) 443 | self._fc = tf.keras.layers.Dense( 444 | self._global_params.num_classes, 445 | kernel_initializer=dense_kernel_initializer) 446 | 447 | if self._global_params.dropout_rate > 0: 448 | self._dropout = tf.keras.layers.Dropout(self._global_params.dropout_rate) 449 | else: 450 | self._dropout = None 451 | 452 | def call(self, inputs, training=True): 453 | """Implementation of MnasNetModel call(). 454 | 455 | Args: 456 | inputs: input tensors. 457 | training: boolean, whether the model is constructed for training. 458 | 459 | Returns: 460 | output tensors. 
461 | """ 462 | outputs = None 463 | self.endpoints = {} 464 | total_macs = 0. 465 | # Calls Stem layers 466 | with tf.variable_scope('mnas_stem'): 467 | outputs_conv_stem = self._conv_stem(inputs) 468 | total_macs += module_profiling(self._conv_stem, inputs, outputs_conv_stem, False) # compute macs 469 | outputs = self._relu_fn(self._bn0(outputs_conv_stem, training=training)) 470 | 471 | tf.logging.info('Built stem layers with output shape: %s' % outputs.shape) 472 | self.endpoints['stem'] = outputs 473 | # Calls blocks. 474 | for idx, block in enumerate(self._blocks): 475 | with tf.variable_scope('mnas_blocks_%s' % idx): 476 | outputs, n_macs = block.call(outputs, training=training) 477 | total_macs += n_macs 478 | self.endpoints['block_%s' % idx] = outputs 479 | if block.endpoints: 480 | for k, v in six.iteritems(block.endpoints): 481 | self.endpoints['block_%s/%s' % (idx, k)] = v 482 | # Calls final layers and returns logits. 483 | with tf.variable_scope('mnas_head'): 484 | outputs_conv_head = self._conv_head(outputs) 485 | total_macs += module_profiling(self._conv_head, outputs, outputs_conv_head, False) # compute macs 486 | outputs = self._relu_fn(self._bn1(outputs_conv_head, training=training)) 487 | self.endpoints['cam_feature'] = outputs 488 | 489 | outputs_avg_pooling = self._avg_pooling(outputs) 490 | total_macs += module_profiling(self._avg_pooling, outputs, outputs_avg_pooling, False) # compute macs 491 | outputs = outputs_avg_pooling 492 | 493 | if self._dropout: 494 | outputs = self._dropout(outputs, training=training) 495 | 496 | outputs_fc = self._fc(outputs) 497 | total_macs += module_profiling(self._fc, outputs, outputs_fc, False) # compute macs 498 | outputs = outputs_fc 499 | self.endpoints['head'] = outputs 500 | self.endpoints['_fc'] = self._fc 501 | return outputs, total_macs 502 | -------------------------------------------------------------------------------- /model_profiling.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import custom_layers 4 | import model_def 5 | 6 | 7 | def module_profiling(self, input, output, verbose): 8 | """ 9 | only support NHWC data format 10 | :param self: 11 | :param input: 12 | :param output: 13 | :param verbose: 14 | :return: 15 | """ 16 | in_size = input.shape.as_list() 17 | out_size = output.shape.as_list() 18 | if isinstance(self, custom_layers.GroupedConv2D): 19 | n_macs = 0 20 | for _conv in self._convs: 21 | kernel_size = _conv.kernel.shape.as_list() 22 | n_macs += kernel_size[0] * kernel_size[1] * kernel_size[2] * kernel_size[3] * out_size[1] * out_size[2] 23 | elif isinstance(self, custom_layers.MDConv): 24 | n_macs = 0 25 | for _conv in self._convs: 26 | kernel_size = _conv.depthwise_kernel.shape.as_list() 27 | n_macs += kernel_size[0] * kernel_size[1] * kernel_size[2] * kernel_size[3] * out_size[1] * out_size[2] 28 | elif isinstance(self, tf.keras.layers.DepthwiseConv2D): 29 | kernel_size = self.depthwise_kernel.shape.as_list() 30 | n_macs = kernel_size[0] * kernel_size[1] * kernel_size[2] * kernel_size[3] * out_size[1] * out_size[2] 31 | elif isinstance(self, tf.keras.layers.Conv2D): 32 | kernel_size = self.kernel.shape.as_list() 33 | n_macs = kernel_size[0] * kernel_size[1] * kernel_size[2] * kernel_size[3] * out_size[1] * out_size[2] 34 | elif isinstance(self, tf.keras.layers.GlobalAveragePooling2D): 35 | assert in_size[-1] == out_size[-1] 36 | n_macs = in_size[1] * in_size[2] * in_size[3] 37 | elif isinstance(self, tf.keras.layers.Dense): 38 | 
n_macs = in_size[1] * out_size[1] 39 | elif self == tf.add: 40 | # n_macs = in_size[1] * in_size[2] * in_size[3] 41 | n_macs = 0 # other people don't take skip connections into consideration 42 | else: 43 | raise NotImplementedError 44 | return n_macs 45 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | # author: dstamoulis 2 | # 3 | # This code extends codebase from the "MNasNet on TPU" GitHub repo: 4 | # https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet 5 | # 6 | # This project incorporates material from the project listed above, and it 7 | # is accessible under their original license terms (Apache License 2.0) 8 | # ============================================================================== 9 | """Creates the ConvNet found model by parsing the NAS-decision values 10 | from the provided NAS-search output dir.""" 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import os 17 | import re 18 | import tensorflow as tf 19 | import numpy as np 20 | 21 | import model_def 22 | from args import FLAGS 23 | 24 | 25 | class MnasNetDecoder(object): 26 | """A class of MnasNet decoder to get model configuration.""" 27 | 28 | def _decode_block_string(self, block_string): 29 | """Gets a MNasNet block through a string notation of arguments. 30 | 31 | E.g. r2_k3_s2_e1_i32_o16_se0.25_noskip: r - number of repeat blocks, 32 | k - kernel size, s - strides (1-9), e - expansion ratio, i - input filters, 33 | o - output filters, se - squeeze/excitation ratio 34 | 35 | Args: 36 | block_string: a string, a string representation of block arguments. 37 | 38 | Returns: 39 | A BlockArgs instance. 40 | Raises: 41 | ValueError: if the strides option is not correctly specified. 42 | """ 43 | assert isinstance(block_string, str) 44 | ops = block_string.split('_') 45 | options = {} 46 | for op in ops: 47 | if op == 'nonlocal': 48 | op = 'nonlocal1.0' 49 | splits = re.split(r'(\d.*)', op) 50 | if len(splits) >= 2: 51 | key, value = splits[:2] 52 | options[key] = value 53 | 54 | if 's' not in options or len(options['s']) != 2: 55 | raise ValueError('Strides options should be a pair of integers.') 56 | 57 | def _parse_ksize(ss): 58 | return [int(k) for k in ss.split('.')] 59 | 60 | def _parse_nonlocal(ss): 61 | ss = ss.split("-") 62 | if len(ss) == 2: 63 | return [float(ss[0]), int(ss[1])] 64 | else: 65 | assert len(ss) == 1 66 | return [float(ss[0]), 1] 67 | 68 | BlockArgs = model_def.BlockArgs 69 | 70 | return BlockArgs( 71 | expand_ksize=_parse_ksize(options['a']) if 'a' in options else [1], 72 | dw_ksize=_parse_ksize(options['k']), 73 | project_ksize=_parse_ksize(options['p']) if 'p' in options else [1], 74 | num_repeat=int(options['r']), 75 | input_filters=int(options['i']), 76 | output_filters=int(options['o']), 77 | expand_ratio=int(options['e']), 78 | id_skip=('noskip' not in block_string), 79 | se_ratio=float(options['se']) if 'se' in options else None, 80 | strides=[int(options['s'][0]), int(options['s'][1])], 81 | swish=('sw' in block_string), 82 | non_local=_parse_nonlocal(options['nonlocal']) if 'nonlocal' in options else 0.0 83 | ) 84 | 85 | def decode(self, string_list): 86 | """Decodes a list of string notations to specify blocks inside the network. 87 | 88 | Args: 89 | string_list: a list of strings, each string is a notation of MnasNet 90 | block. 
91 | 
92 |     Returns:
93 |       A list of namedtuples representing the MnasNet block arguments.
94 |     """
95 |     assert isinstance(string_list, list)
96 |     blocks_args = []
97 |     for block_string in string_list:
98 |       blocks_args.append(self._decode_block_string(block_string))
99 |     return blocks_args
100 | 
101 | 
102 | def parse_netarch_string(blocks_args, depth_multiplier=None):
103 |   decoder = MnasNetDecoder()
104 |   global_params = model_def.GlobalParams(
105 |       batch_norm_momentum=0.99,
106 |       batch_norm_epsilon=1e-3,
107 |       dropout_rate=0.2,
108 |       data_format='channels_last',
109 |       num_classes=1000,
110 |       depth_multiplier=depth_multiplier,
111 |       depth_divisor=8,
112 |       min_depth=None)
113 |   return decoder.decode(blocks_args), global_params
114 | 
115 | 
116 | def build_model(images, training, override_params=None, arch=None):
117 |   """A helper function that creates a ConvNet model and returns the predicted logits.
118 | 
119 |   Args:
120 |     images: input images tensor.
121 |     training: boolean, whether the model is constructed for training.
122 |     override_params: A dictionary of params for overriding. Fields must exist in
123 |       model_def.GlobalParams.
124 | 
125 |   Returns:
126 |     logits: the logits tensor of classes.
127 |     endpoints: the endpoints for each layer.
128 |   Raises:
129 |     ValueError: when override_params has invalid fields.
130 |   """
131 |   assert isinstance(images, tf.Tensor)
132 |   assert os.path.isfile(arch)
133 |   with open(arch, 'r') as f:
134 |     lines = f.readlines()
135 |   lines = [line.strip() for line in lines]
136 |   blocks_args, global_params = parse_netarch_string(lines)
137 | 
138 |   if override_params:
139 |     global_params = global_params._replace(**override_params)
140 | 
141 |   with tf.variable_scope('single-path'):
142 |     model = model_def.MnasNetModel(blocks_args, global_params)
143 |     logits, macs = model(images, training=training)
144 |     macs /= 1e6  # macs to M
145 | 
146 |   logits = tf.identity(logits, 'logits')
147 | 
148 |   return logits, model.endpoints, macs
149 | 
--------------------------------------------------------------------------------
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiYingwei/AutoNL/6fec21431d4adb2bea86f36b0d3106465f881570/teaser.png
--------------------------------------------------------------------------------
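
For reference, the arch.txt block notation consumed by MnasNetDecoder in models.py can be sanity-checked with a made-up block string; the string and its values below are illustrative and not taken from the released architectures:

```python
from models import MnasNetDecoder

# Hypothetical block string in the notation documented in _decode_block_string:
# r1 repeat once, a1/p1 use 1x1 expand/project kernels, k3.5 mixes 3x3 and 5x5
# depthwise kernels, s11 stride 1, e6 expansion 6, i32/o64 input/output filters,
# se0.25 squeeze-excitation ratio, nonlocal0.25-2 -> nl_ratio=0.25, nl_stride=2.
args = MnasNetDecoder()._decode_block_string(
    'r1_a1_k3.5_p1_s11_e6_i32_o64_se0.25_nonlocal0.25-2')
print(args.dw_ksize)   # [3, 5]
print(args.non_local)  # [0.25, 2]
print(args.strides)    # [1, 1]
```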