├── .gitignore
├── .idea
│   └── vcs.xml
├── LICENSE
├── README.md
├── images
│   ├── dense_vs_sparse.png
│   └── sparse_connectivity.PNG
├── sparsenet.py
└── train_cifar10.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Somshubra Majumdar

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Sparse Networks in Keras
Keras implementation of Sparse Networks from the paper [Sparsely Connected Convolutional Networks](https://arxiv.org/abs/1801.05895).

Code derived from the official repository - https://github.com/Lyken17/SparseNet

# Sparse Networks
SparseNet is a variant of DenseNets. In a DenseNet, each layer of a dense block receives skip connections from every preceding layer in that block. In a SparseNet, each layer only receives skip connections from the layers at offsets 1, 2, 4, ..., 2^N behind it (exponential offsets rather than a static linear offset). DenseNets therefore possess *O(n^2)* skip connections in every dense block, whereas SparseNets have only *O(n log n)* skip connections in each sparse block (*O(log n)* per layer).

This yields models that are **much less memory intensive** and have fewer parameters, while matching or even surpassing the performance of DenseNets.

# Sparse Connectivity

The above image from the paper shows that each layer only requires about *log2 n* input connections.

# Difference between DenseNets and SparseNets

This image from the paper shows the major difference between the connectivity patterns of SparseNets and ResNets/DenseNets.

# Caveats
There is a small discrepancy in the number of parameters between the paper and this repo.

- SparseNet-40-24 (Keras = 0.74 M, paper = 0.76 M)
- SparseNet-100-24 (Keras = 2.50 M, paper = 2.52 M)

If anyone can figure out the cause of this discrepancy, I'd be grateful.
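
The exponential connectivity rule described above can be illustrated with a short pure-Python sketch. It mirrors the index selection performed by `_exponential_index_fetch` in `sparsenet.py` (a layer reads the outputs 1, 2, 4, ... positions back from the most recent one) and counts how many earlier outputs the layers of a single block read in total under sparse and dense connectivity. The helper names `sparse_inputs` and `count_connections` are illustrative only and are not part of this repo.

```python
def sparse_inputs(num_available):
    """Indices of the earlier outputs a new layer reads when num_available
    outputs exist (index 0 is the block input). Offsets 1, 2, 4, ... are
    counted back from the most recent output, as in _exponential_index_fetch."""
    indices = []
    offset = 1
    while offset <= num_available:
        indices.append(num_available - offset)
        offset *= 2
    return indices


def count_connections(n_layers):
    """Total number of inputs read by the n_layers layers of one block,
    returned as (sparse, dense)."""
    # The k-th layer (1-based) can see k earlier outputs: the block input
    # plus the k - 1 layers before it.
    sparse = sum(len(sparse_inputs(k)) for k in range(1, n_layers + 1))
    dense = sum(range(1, n_layers + 1))  # dense: every earlier output is read
    return sparse, dense


print(sparse_inputs(8))
print(count_connections(12))
```

With 12 layers per block (the per-block depth of SparseNet-40 without bottleneck), the sparse block reads 37 inputs in total versus 78 for the dense one; the sparse count grows roughly as *n log2 n*, while the dense count grows as *n(n + 1)/2*.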

# Requirements

- Keras 2.1.3
- TensorFlow / Theano / CNTK (since all of these backends support ResNets, they should presumably support this model as well without any modification)

--------------------------------------------------------------------------------
/images/dense_vs_sparse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-SparseNet/e07358c50017bd566745b375bc192880ff649b1e/images/dense_vs_sparse.png
--------------------------------------------------------------------------------
/images/sparse_connectivity.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-SparseNet/e07358c50017bd566745b375bc192880ff649b1e/images/sparse_connectivity.PNG
--------------------------------------------------------------------------------
/sparsenet.py:
--------------------------------------------------------------------------------
'''SparseNet models for Keras.
# Reference
- [Sparsely Connected Convolutional Networks](https://arxiv.org/abs/1801.05895)
- [GitHub](https://github.com/lyken17/sparsenet)
'''
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division

import numpy as np
import warnings

from keras.models import Model
from keras.layers.core import Dense, Dropout, Activation, Reshape
from keras.layers.convolutional import Conv2D, Conv2DTranspose, UpSampling2D
from keras.layers.pooling import AveragePooling2D, MaxPooling2D
from keras.layers.pooling import GlobalAveragePooling2D
from keras.layers import Input
from keras.layers.merge import concatenate
from keras.layers.normalization import BatchNormalization
from keras.regularizers import l2
from keras.utils.layer_utils import convert_all_kernels_in_model, convert_dense_weights_data_format
from keras.utils.data_utils import get_file
from keras.engine.topology import get_source_inputs
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.applications.imagenet_utils import decode_predictions
import keras.backend as K


def preprocess_input(x, data_format=None):
    """Preprocesses a tensor encoding a batch of images.

    Converts RGB to BGR, zero-centers each channel by the ImageNet mean pixel
    and scales the result by 0.017. Note that a NumPy input is modified in place.

    # Arguments
        x: input Numpy tensor, 3D (single image) or 4D (batch).
        data_format: data format of the image tensor.

    # Returns
        Preprocessed tensor.
    """
    if data_format is None:
        data_format = K.image_data_format()
    assert data_format in {'channels_last', 'channels_first'}

    if data_format == 'channels_first':
        if x.ndim == 3:
            # 'RGB'->'BGR'
            x = x[::-1, ...]
            # Zero-center by mean pixel
            x[0, :, :] -= 103.939
            x[1, :, :] -= 116.779
            x[2, :, :] -= 123.68
        else:
            x = x[:, ::-1, ...]
            x[:, 0, :, :] -= 103.939
            x[:, 1, :, :] -= 116.779
            x[:, 2, :, :] -= 123.68
    else:
        # 'RGB'->'BGR'
        x = x[..., ::-1]
        # Zero-center by mean pixel
        x[..., 0] -= 103.939
        x[..., 1] -= 116.779
        x[..., 2] -= 123.68

    x *= 0.017  # scale values

    return x


def SparseNet(input_shape=None, depth=40, nb_dense_block=3, growth_rate=12, nb_filter=-1, nb_layers_per_block=-1,
              bottleneck=False, reduction=0.0, dropout_rate=0.0, weight_decay=1e-4, subsample_initial_block=False,
              include_top=True, weights=None, input_tensor=None,
              classes=10, activation='softmax'):
    '''Instantiate the SparseNet architecture.
        No pre-trained weights are currently available (see `weights` below).
        Note that when using TensorFlow,
        for best performance you should set
        `image_data_format='channels_last'` in your Keras config
        at ~/.keras/keras.json.
        The model is compatible with both
        TensorFlow and Theano. The dimension ordering
        convention used by the model is the one
        specified in your Keras config file.
        # Arguments
            input_shape: optional shape tuple, only to be specified
                if `include_top` is False (otherwise the input shape
                has to be `(32, 32, 3)` (with `channels_last` dim ordering)
                or `(3, 32, 32)` (with `channels_first` dim ordering).
                It should have exactly 3 input channels,
                and width and height should be no smaller than 8.
                E.g. `(200, 200, 3)` would be one valid value.
            depth: number of layers in the SparseNet
            nb_dense_block: number of dense blocks to add to end (generally = 3)
            growth_rate: number of filters added by each conv block within a dense block.
                Can be a single integer number or a list of numbers.
                If it is a list, length of list must match the length of
                `nb_layers_per_block`
            nb_filter: initial number of filters. -1 indicates that the initial
                number of filters is equal to the (first) growth rate
            nb_layers_per_block: number of layers in each dense block.
                Can be a -1, positive integer or a list.
                If -1, calculates nb_layer_per_block from the network depth.
                If positive integer, a set number of layers per dense block.
                If list, nb_layer is used as provided. Note that list size must
                be equal to nb_dense_block
            bottleneck: flag to add a 1x1 bottleneck convolution inside each conv block
            reduction: reduction factor of transition blocks.
                Note: compression is computed as 1 - reduction.
            dropout_rate: dropout rate
            weight_decay: weight decay rate
            subsample_initial_block: Set to True to subsample the initial convolution and
                add a MaxPool2D before the dense blocks are added.
            include_top: whether to include the fully-connected
                layer at the top of the network.
            weights: one of `None` (random initialization) or
                'imagenet' (pre-training on ImageNet). Since no pre-trained
                weights are available yet, 'imagenet' currently leaves the
                model randomly initialized.
            input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
                to use as image input for the model.
            classes: optional number of classes to classify images
                into, only to be specified if `include_top` is True, and
                if no `weights` argument is specified.
            activation: Type of activation at the top layer. Can be one of 'softmax' or 'sigmoid'.
                Note that if sigmoid is used, classes must be 1.
        # Returns
            A Keras model instance.
    '''

    if weights not in {'imagenet', None}:
        raise ValueError('The `weights` argument should be either '
                         '`None` (random initialization) or `imagenet` '
                         '(pre-training on ImageNet).')

    if weights == 'imagenet' and include_top and classes != 1000:
        raise ValueError('If using `weights` as ImageNet with `include_top`'
                         ' as true, `classes` should be 1000')

    if activation not in ['softmax', 'sigmoid']:
        raise ValueError('activation must be one of "softmax" or "sigmoid"')

    if activation == 'sigmoid' and classes != 1:
        raise ValueError('sigmoid activation can only be used when classes = 1')

    # Determine proper input shape
    input_shape = _obtain_input_shape(input_shape,
                                      default_size=32,
                                      min_size=8,
                                      data_format=K.image_data_format(),
                                      require_flatten=include_top)

    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        if not K.is_keras_tensor(input_tensor):
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

    x = _create_dense_net(classes, img_input, include_top, depth, nb_dense_block,
                          growth_rate, nb_filter, nb_layers_per_block, bottleneck, reduction,
                          dropout_rate, weight_decay, subsample_initial_block, activation)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = get_source_inputs(input_tensor)
    else:
        inputs = img_input
    # Create model.
    model = Model(inputs, x, name='sparsenet')

    # load weights
    if weights == 'imagenet':
        # No pre-trained weights are available yet, so nothing is loaded here.
        weights_loaded = False

        if weights_loaded:
            if K.backend() == 'theano':
                convert_all_kernels_in_model(model)

            if K.image_data_format() == 'channels_first' and K.backend() == 'tensorflow':
                warnings.warn('You are using the TensorFlow backend, yet you '
                              'are using the Theano '
                              'image data format convention '
                              '(`image_data_format="channels_first"`). '
                              'For best performance, set '
                              '`image_data_format="channels_last"` in '
                              'your Keras config '
                              'at ~/.keras/keras.json.')

            print("Weights for the model were loaded successfully")

    return model


def SparseNetImageNet121(input_shape=None,
                         bottleneck=True,
                         reduction=0.5,
                         dropout_rate=0.0,
                         weight_decay=1e-4,
                         include_top=True,
                         weights=None,
                         input_tensor=None,
                         classes=1000,
                         activation='softmax'):
    return SparseNet(input_shape, depth=121, nb_dense_block=4, growth_rate=32, nb_filter=64,
                     nb_layers_per_block=[6, 12, 24, 16], bottleneck=bottleneck, reduction=reduction,
                     dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True,
                     include_top=include_top, weights=weights, input_tensor=input_tensor,
                     classes=classes, activation=activation)


def SparseNetImageNet169(input_shape=None,
                         bottleneck=True,
                         reduction=0.5,
                         dropout_rate=0.0,
                         weight_decay=1e-4,
                         include_top=True,
                         weights=None,
                         input_tensor=None,
                         classes=1000,
                         activation='softmax'):
    return SparseNet(input_shape, depth=169, nb_dense_block=4, growth_rate=32, nb_filter=64,
                     nb_layers_per_block=[6, 12, 32, 32], bottleneck=bottleneck, reduction=reduction,
                     dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True,
                     include_top=include_top,
                     weights=weights, input_tensor=input_tensor,
                     classes=classes, activation=activation)


def SparseNetImageNet201(input_shape=None,
                         bottleneck=True,
                         reduction=0.5,
                         dropout_rate=0.0,
                         weight_decay=1e-4,
                         include_top=True,
                         weights=None,
                         input_tensor=None,
                         classes=1000,
                         activation='softmax'):
    return SparseNet(input_shape, depth=201, nb_dense_block=4, growth_rate=32, nb_filter=64,
                     nb_layers_per_block=[6, 12, 48, 32], bottleneck=bottleneck, reduction=reduction,
                     dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True,
                     include_top=include_top, weights=weights, input_tensor=input_tensor,
                     classes=classes, activation=activation)


def SparseNetImageNet264(input_shape=None,
                         bottleneck=True,
                         reduction=0.5,
                         dropout_rate=0.0,
                         weight_decay=1e-4,
                         include_top=True,
                         weights=None,
                         input_tensor=None,
                         classes=1000,
                         activation='softmax'):
    return SparseNet(input_shape, depth=264, nb_dense_block=4, growth_rate=32, nb_filter=64,
                     nb_layers_per_block=[6, 12, 64, 48], bottleneck=bottleneck, reduction=reduction,
                     dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True,
                     include_top=include_top, weights=weights, input_tensor=input_tensor,
                     classes=classes, activation=activation)


def SparseNetImageNet161(input_shape=None,
                         bottleneck=True,
                         reduction=0.5,
                         dropout_rate=0.0,
                         weight_decay=1e-4,
                         include_top=True,
                         weights=None,
                         input_tensor=None,
                         classes=1000,
                         activation='softmax'):
    return SparseNet(input_shape, depth=161, nb_dense_block=4, growth_rate=48, nb_filter=96,
                     nb_layers_per_block=[6, 12, 36, 24], bottleneck=bottleneck, reduction=reduction,
                     dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True,
                     include_top=include_top,
                     weights=weights, input_tensor=input_tensor,
                     classes=classes, activation=activation)


def _exponential_index_fetch(x_list):
    ''' Return the entries of x_list at offsets 1, 2, 4, 8, ... from the end,
    most recent entry first.
    '''
    count = len(x_list)
    i = 1
    inputs = []
    while i <= count:
        inputs.append(x_list[count - i])
        i *= 2
    return inputs


def _conv_block(ip, nb_filter, bottleneck=False, dropout_rate=None, weight_decay=1e-4):
    ''' Apply BatchNorm, ReLU, an optional 1x1 bottleneck Conv2D, a 3x3 Conv2D and optional dropout
    Args:
        ip: Input keras tensor
        nb_filter: number of filters
        bottleneck: add a 1x1 bottleneck convolution with 4 * nb_filter channels
        dropout_rate: dropout rate
        weight_decay: weight decay factor
    Returns: keras tensor with batch_norm, relu and convolution2d added (optional bottleneck)
    '''
    concat_axis = 1 if K.image_data_format() == 'channels_first' else -1

    with K.name_scope('conv_block'):
        x = BatchNormalization(axis=concat_axis, momentum=0.1, epsilon=1e-5)(ip)
        x = Activation('relu')(x)

        if bottleneck:
            inter_channel = nb_filter * 4  # Obtained from https://github.com/liuzhuang13/DenseNet/blob/master/densenet.lua

            x = Conv2D(inter_channel, (1, 1), kernel_initializer='he_normal', padding='same', use_bias=False,
                       kernel_regularizer=l2(weight_decay))(x)
            x = BatchNormalization(axis=concat_axis, epsilon=1e-5, momentum=0.1)(x)
            x = Activation('relu')(x)

        x = Conv2D(nb_filter, (3, 3), kernel_initializer='he_normal', padding='same', use_bias=False)(x)
        if dropout_rate:
            x = Dropout(dropout_rate)(x)

    return x


def _dense_block(x, nb_layers, nb_filter, growth_rate, bottleneck=False, dropout_rate=None, weight_decay=1e-4,
                 grow_nb_filters=True, return_concat_list=False):
    ''' Build a sparse dense_block where each conv_block receives the concatenation of
    the earlier outputs at exponential offsets (1, 2, 4, ...) from the most recent one
    Args:
        x: keras tensor
        nb_layers: the number of layers of conv_block to append to the model.
        nb_filter: number of filters
        growth_rate: growth rate
        bottleneck: bottleneck block
        dropout_rate: dropout rate
        weight_decay: weight decay factor
        grow_nb_filters: flag to decide to allow number of filters to grow
        return_concat_list: return the list of feature maps along with the actual output
    Returns: keras tensor with nb_layers of conv_block appended, and its number of channels
    '''
    concat_axis = 1 if K.image_data_format() == 'channels_first' else -1

    x_list = [x]  # block input followed by the output of every conv_block
    channel_list = [nb_filter]  # channel count of each entry in x_list

    for i in range(nb_layers):
        #nb_channels = sum(_exponential_index_fetch(channel_list))

        x = _conv_block(x, growth_rate, bottleneck, dropout_rate, weight_decay)
        x_list.append(x)

        # The next layer (and the block output) only sees the outputs at offsets 1, 2, 4, ... from the end
        fetch_outputs = _exponential_index_fetch(x_list)
        x = concatenate(fetch_outputs, axis=concat_axis)

        channel_list.append(growth_rate)

        if grow_nb_filters:
            # Number of channels of the concatenation above
            nb_filter = sum(_exponential_index_fetch(channel_list))

    if return_concat_list:
        return x, nb_filter, x_list
    else:
        return x, nb_filter


def _transition_block(ip, nb_filter, compression=1.0, weight_decay=1e-4):
    ''' Apply BatchNorm, ReLU, a 1x1 Conv2D with optional compression, and 2x2 AveragePooling2D
    Args:
        ip: keras tensor
        nb_filter: number of filters
        compression: calculated as 1 - reduction. Reduces the number of feature maps
            in the transition block.
        weight_decay: weight decay factor
    Returns: keras tensor, after applying batch_norm, relu-conv and average pooling
    '''
    concat_axis = 1 if K.image_data_format() == 'channels_first' else -1

    with K.name_scope('transition_block'):
        x = BatchNormalization(axis=concat_axis, epsilon=1e-5, momentum=0.1)(ip)
        x = Activation('relu')(x)
        x = Conv2D(int(nb_filter * compression), (1, 1), kernel_initializer='he_normal', padding='same', use_bias=False,
                   kernel_regularizer=l2(weight_decay))(x)
        x = AveragePooling2D((2, 2), strides=(2, 2))(x)

    return x


def _create_dense_net(nb_classes, img_input, include_top, depth=40, nb_dense_block=3, growth_rate=12, nb_filter=-1,
                      nb_layers_per_block=-1, bottleneck=False, reduction=0.0, dropout_rate=None, weight_decay=1e-4,
                      subsample_initial_block=False, activation='softmax'):
    ''' Build the SparseNet model
    Args:
        nb_classes: number of classes
        img_input: input keras tensor of shape (channels, rows, columns) or (rows, columns, channels)
        include_top: flag to include the final Dense layer
        depth: number of layers
        nb_dense_block: number of dense blocks to add to end (generally = 3)
        growth_rate: number of filters added by each conv block within a dense block
        nb_filter: initial number of filters. Default -1 indicates that the initial number of filters
            is equal to the (first) growth rate
        nb_layers_per_block: number of layers in each dense block.
            Can be a -1, positive integer or a list.
            If -1, calculates nb_layer_per_block from the depth of the network.
            If positive integer, a set number of layers per dense block.
            If list, nb_layer is used as provided. Note that list size must
            be equal to nb_dense_block
        bottleneck: add bottleneck blocks
        reduction: reduction factor of transition blocks.
            Note: compression is computed as 1 - reduction
        dropout_rate: dropout rate
        weight_decay: weight decay rate
        subsample_initial_block: Set to True to subsample the initial convolution and
            add a MaxPool2D before the dense blocks are added.
        activation: Type of activation at the top layer. Can be one of 'softmax' or 'sigmoid'.
            Note that if sigmoid is used, classes must be 1.
    Returns: output keras tensor of the network
    '''

    concat_axis = 1 if K.image_data_format() == 'channels_first' else -1

    if reduction != 0.0:
        assert reduction <= 1.0 and reduction > 0.0, 'reduction value must lie between 0.0 and 1.0'

    # layers in each dense block
    if type(nb_layers_per_block) is list or type(nb_layers_per_block) is tuple:
        nb_layers = list(nb_layers_per_block)  # Convert tuple to list

        assert len(nb_layers) == (nb_dense_block), 'If list, nb_layer is used as provided. ' \
                                                   'Note that list size must be (nb_dense_block)'
        final_nb_layer = nb_layers[-1]
        nb_layers = nb_layers[:-1]
    else:
        if nb_layers_per_block == -1:
            assert (depth - 4) % 3 == 0, 'Depth must be 3 N + 4 if nb_layers_per_block == -1'
            count = int((depth - 4) / 3)

            if bottleneck:
                count = count // 2

            nb_layers = [count for _ in range(nb_dense_block)]
            final_nb_layer = count
        else:
            final_nb_layer = nb_layers_per_block
            nb_layers = [nb_layers_per_block] * nb_dense_block

    if type(growth_rate) is list or type(growth_rate) is tuple:
        growth_rate = list(growth_rate)
        assert len(growth_rate) == len(nb_layers)
    else:
        growth_rate = [growth_rate for _ in range(len(nb_layers))]

    # compute initial nb_filter if -1 (falls back to the first growth rate), else accept users initial nb_filter
    if nb_filter <= 0:
        nb_filter = growth_rate[0]

    # compute compression factor
    compression = 1.0 - reduction

    # Initial convolution
    if subsample_initial_block:
        initial_kernel = (7, 7)
        initial_strides = (2, 2)
    else:
        initial_kernel = (3, 3)
        initial_strides = (1, 1)

    x = Conv2D(nb_filter, initial_kernel, kernel_initializer='he_normal', padding='same',
               strides=initial_strides, use_bias=False, kernel_regularizer=l2(weight_decay))(img_input)

    if subsample_initial_block:
        x = BatchNormalization(axis=concat_axis, epsilon=1e-5, momentum=0.1)(x)
        x = Activation('relu')(x)
        x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

    # Add dense blocks
    for block_idx in range(nb_dense_block - 1):
        x, nb_filter = _dense_block(x, nb_layers[block_idx], nb_filter, growth_rate[block_idx], bottleneck=bottleneck,
                                    dropout_rate=dropout_rate, weight_decay=weight_decay)
        # add transition_block
        x = _transition_block(x, nb_filter, compression=compression, weight_decay=weight_decay)
        nb_filter = int(nb_filter * compression)

    # The last dense_block does not have a transition_block
    x, nb_filter = _dense_block(x, final_nb_layer, nb_filter, growth_rate[-1], bottleneck=bottleneck,
                                dropout_rate=dropout_rate, weight_decay=weight_decay)

    x = BatchNormalization(axis=concat_axis, epsilon=1e-5, momentum=0.1)(x)
    x = Activation('relu')(x)
    x = GlobalAveragePooling2D()(x)

    if include_top:
        x = Dense(nb_classes, activation=activation)(x)

    return x


if __name__ == '__main__':
    # from keras.utils.vis_utils import plot_model
    # import tensorflow as tf
    # from keras import backend as K
    # sess = tf.Session()
    # K.set_session(sess)

    model = SparseNet((32, 32, 3), depth=40, nb_dense_block=3,
                      growth_rate=24, bottleneck=False, reduction=0.0, weights=None)
    model.summary()

    #writer = tf.summary.FileWriter('logs/', graph=sess.graph)
    #writer.close()

    #plot_model(model, 'sparse.png', show_shapes=True)

--------------------------------------------------------------------------------
/train_cifar10.py:
--------------------------------------------------------------------------------
from __future__ import print_function

import os.path

import sparsenet
import numpy as np
import sklearn.metrics as metrics

from keras.datasets import cifar10
from keras.utils import np_utils
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from keras import backend as K

batch_size = 100
nb_classes = 10
nb_epoch = 100

img_rows, img_cols = 32, 32
img_channels = 3

img_dim = (img_channels, img_rows, img_cols) if K.image_data_format() == 'channels_first' else (img_rows, img_cols, img_channels)
depth = 40
nb_dense_block = 3
growth_rate = 24
nb_filter = -1
dropout_rate = 0.0  # 0.0 for data augmentation

model = sparsenet.SparseNet(img_dim, classes=nb_classes, depth=depth, nb_dense_block=nb_dense_block,
                            growth_rate=growth_rate, nb_filter=nb_filter, dropout_rate=dropout_rate, weights=None)
print("Model created")

model.summary()
optimizer = Adam(lr=1e-3, amsgrad=True)  # Using Adam instead of SGD to speed up training
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=["accuracy"])
print("Finished compiling")
print("Building model...")

(trainX, trainY), (testX, testY) = cifar10.load_data()

trainX = trainX.astype('float32')
testX = testX.astype('float32')

# trainX = sparsenet.preprocess_input(trainX)
# testX = sparsenet.preprocess_input(testX)

# Per-channel standardization with training-set statistics (assumes channels_last data)
cifar_mean = trainX.mean(axis=(0, 1, 2), keepdims=True)
cifar_std = trainX.std(axis=(0, 1, 2), keepdims=True)

trainX = (trainX - cifar_mean) / (cifar_std + 1e-8)
testX = (testX - cifar_mean) / (cifar_std + 1e-8)

Y_train = np_utils.to_categorical(trainY, nb_classes)
Y_test = np_utils.to_categorical(testY, nb_classes)

generator = ImageDataGenerator(width_shift_range=5. / 32,
                               height_shift_range=5. / 32,
                               horizontal_flip=True)

generator.fit(trainX, seed=0)

# Load model
weights_file = "weights/SparseNet-40-24-CIFAR10.h5"
if os.path.exists(weights_file):
    model.load_weights(weights_file)
    print("Model loaded.")

out_dir = "weights/"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)  # make sure the checkpoint directory exists before the first save

lr_reducer = ReduceLROnPlateau(monitor='val_acc', factor=np.sqrt(0.1),
                               cooldown=0, patience=5, min_lr=1e-5)
model_checkpoint = ModelCheckpoint(weights_file, monitor="val_acc", save_best_only=True,
                                   save_weights_only=True, verbose=1)

callbacks = [lr_reducer, model_checkpoint]

model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size),
                    steps_per_epoch=len(trainX) // batch_size, epochs=nb_epoch,
                    callbacks=callbacks,
                    validation_data=(testX, Y_test),
                    validation_steps=testX.shape[0] // batch_size, verbose=1)

yPreds = model.predict(testX)
yPred = np.argmax(yPreds, axis=1)
yTrue = testY.ravel()  # CIFAR-10 labels have shape (N, 1)

accuracy = metrics.accuracy_score(yTrue, yPred) * 100
error = 100 - accuracy
print("Accuracy : ", accuracy)
print("Error : ", error)

--------------------------------------------------------------------------------
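
As a rough cross-check of the channel bookkeeping in `_dense_block` and `_create_dense_net` (the `nb_filter = sum(_exponential_index_fetch(channel_list))` update), the pure-Python sketch below recomputes the number of feature maps after each layer. It assumes `nb_layers_per_block=-1`, `bottleneck=False` and `nb_filter=-1`, i.e. the configuration used by `train_cifar10.py`; the helper names `exponential_fetch`, `block_widths` and `final_width` are hypothetical. Passing `sparse=False` gives the DenseNet-style count for comparison.

```python
def exponential_fetch(items):
    """Entries of `items` at offsets 1, 2, 4, ... from the end
    (same selection rule as _exponential_index_fetch in sparsenet.py)."""
    picked = []
    offset = 1
    while offset <= len(items):
        picked.append(items[len(items) - offset])
        offset *= 2
    return picked


def block_widths(nb_layers, nb_filter, growth_rate, sparse=True):
    """Number of channels of the block output after each conv layer."""
    channels = [nb_filter]  # channel count of the block input
    widths = []
    for _ in range(nb_layers):
        channels.append(growth_rate)  # each conv layer emits growth_rate maps
        widths.append(sum(exponential_fetch(channels)) if sparse else sum(channels))
    return widths


def final_width(depth, growth_rate, nb_dense_block=3, reduction=0.0, sparse=True):
    """Channels entering global average pooling (assumes bottleneck=False,
    nb_layers_per_block=-1 and nb_filter=-1)."""
    assert (depth - 4) % 3 == 0, 'depth must be 3N + 4'
    nb_layers = (depth - 4) // 3
    compression = 1.0 - reduction
    nb_filter = growth_rate  # nb_filter=-1 falls back to the growth rate
    for block_idx in range(nb_dense_block):
        nb_filter = block_widths(nb_layers, nb_filter, growth_rate, sparse)[-1]
        if block_idx < nb_dense_block - 1:  # transition block (none after the last)
            nb_filter = int(nb_filter * compression)
    return nb_filter


print(block_widths(12, 24, 24))
print(final_width(40, 24), final_width(40, 24, sparse=False))
```

For SparseNet-40-24 this gives a block output of 96 channels at the end of every block, whereas the DenseNet-style count with the same settings reaches 888 channels before global pooling; this illustrates one source of the memory savings claimed in the README.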