├── .gitignore
├── .idea
│   └── vcs.xml
├── LICENSE
├── README.md
├── images
│   ├── dense_vs_sparse.png
│   └── sparse_connectivity.PNG
├── sparsenet.py
└── train_cifar10.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Somshubra Majumdar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sparse Networks in Keras
2 | Keras Implementation of Sparse Networks from the paper [Sparsely Connected Convolutional Networks](https://arxiv.org/abs/1801.05895).
3 |
4 | Code derived from the official repository: https://github.com/Lyken17/SparseNet
5 |
6 | # Sparse Networks
7 | SparseNet is a variant of DenseNets. While DenseNets connect each layer to every preceding layer inside a dense block, SparseNets connect each layer only to the preceding layers at exponential offsets (1, 2, 4, ..., 2^N) rather than at every linear offset. A dense block with *n* layers therefore has *O(n^2)* skip connections, whereas a sparse block has only *O(n log n)*, because each layer receives just *O(log n)* inputs.
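To make the asymptotic claim concrete, here is a minimal sketch that counts the tensors aggregated across one block. It assumes the connectivity implemented by `_exponential_index_fetch` in `sparsenet.py`: layer *j* (1-indexed) sees *j* tensors in a dense block but only floor(log2 j) + 1 in a sparse one. The function names are hypothetical helpers, not part of this repo.

```python
def dense_block_connections(n):
    """Total inputs over an n-layer dense block: layer j (1-indexed) sees the
    block input plus all j - 1 earlier outputs, i.e. j tensors."""
    return sum(j for j in range(1, n + 1))


def sparse_block_connections(n):
    """Total inputs over an n-layer sparse block: layer j only sees the tensors
    at offsets 1, 2, 4, ... <= j, i.e. floor(log2(j)) + 1 tensors."""
    # For j >= 1, int.bit_length() equals floor(log2(j)) + 1 with no float error.
    return sum(j.bit_length() for j in range(1, n + 1))


# Example block sizes: 12 and 32 layers per block.
for n in (12, 32):
    print(n, dense_block_connections(n), sparse_block_connections(n))
```

This prints `12 78 37` and `32 528 135`: the dense count grows quadratically, the sparse one roughly as *n log n*.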
8 |
9 | This makes the models **much less memory intensive**, while they match or even surpass DenseNets in accuracy with fewer parameters.
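One way to see where the savings come from is to compare how many channels each layer's 3x3 convolution reads. The sketch below assumes a block input of `k0` channels and `k` new channels per layer (the growth rate), following the same bookkeeping as `channel_list` in `_dense_block`; `input_channels` is a hypothetical helper.

```python
def input_channels(j, k0, k, sparse=True):
    """Channels fed into layer j (1-indexed) of a block.

    Tensor 0 is the block input (k0 channels); tensor i >= 1 is the output of
    layer i (k channels).
    """
    if not sparse:
        # Dense: the block input plus all j - 1 earlier outputs.
        return k0 + (j - 1) * k
    channels = 0
    offset = 1
    while offset <= j:  # offsets 1, 2, 4, ... as in _exponential_index_fetch
        channels += k0 if j - offset == 0 else k
        offset *= 2
    return channels


# Without bottleneck, the bias-free 3x3 kernel of layer j has
# 9 * input_channels * k weights, so narrower inputs mean fewer parameters.
for j in (4, 8, 12):
    print(j, input_channels(j, 24, 24, sparse=False), input_channels(j, 24, 24))
```

With `k0 = k = 24`, the dense widths grow linearly (96, 192, 288) while the sparse widths stay small (72, 96, 96).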
10 |
11 | # Sparse Connectivity
12 | ![Sparse connectivity](images/sparse_connectivity.PNG)
13 |
14 | The above image from the paper shows that each layer at the end of a block requires only about *log2 n* input connections.
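The selection rule itself is small. This sketch restates the indexing of `_exponential_index_fetch` as a standalone function (the name `sparse_input_indices` is hypothetical) and returns which earlier tensors layer *j* concatenates, where index 0 is the block input and index *i* is the output of layer *i*.

```python
def sparse_input_indices(j):
    """Indices of the tensors concatenated as input to layer j (1-indexed).

    Layer j looks back by offsets 1, 2, 4, ..., as long as the offset does not
    go past the block input (index 0).
    """
    indices = []
    offset = 1
    while offset <= j:
        indices.append(j - offset)
        offset *= 2
    return indices


print(sparse_input_indices(8))   # [7, 6, 4, 0]
print(sparse_input_indices(12))  # [11, 10, 8, 4]
```

The list length is floor(log2 j) + 1, which is where the *log2 n* figure comes from.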
15 |
16 | # Difference between DenseNets and SparseNets
17 | ![DenseNet vs SparseNet](images/dense_vs_sparse.png)
18 |
19 | This image from the paper shows the key difference between the connectivity patterns of SparseNets and ResNets/DenseNets.
20 |
21 | # Caveats
22 | There is a small discrepancy in the number of parameters between the paper and this repo.
23 |
24 | - SparseNet-40-24 (Keras = 0.74 M, paper = 0.76 M)
25 | - SparseNet-100-24 (Keras = 2.50 M, paper = 2.52 M)
26 |
27 | If anyone can figure out the cause of this discrepancy, I'd be grateful.
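The model names above appear to follow the `SparseNet-<depth>-<growth_rate>` pattern used by `train_cifar10.py` (depth 40, growth rate 24). When `nb_layers_per_block=-1`, `_create_dense_net` derives the number of conv blocks per dense block from the depth; the hypothetical helper below restates that rule so the per-block sizes can be checked by hand.

```python
def layers_per_block(depth, nb_dense_block=3, bottleneck=False):
    """Restates the nb_layers_per_block == -1 rule from _create_dense_net."""
    if (depth - 4) % 3 != 0:
        raise ValueError('Depth must be 3 N + 4 if nb_layers_per_block == -1')
    count = (depth - 4) // 3
    if bottleneck:
        # Each bottleneck conv block holds two convolutions (1x1 and 3x3).
        count //= 2
    return [count] * nb_dense_block


print(layers_per_block(40))                   # depth 40
print(layers_per_block(100))                  # depth 100, no bottleneck
print(layers_per_block(100, bottleneck=True)) # depth 100, with bottleneck
```

This prints `[12, 12, 12]`, `[32, 32, 32]` and `[16, 16, 16]`.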
28 |
29 | # Requirements
30 |
31 | - Keras 2.1.3
32 | - Tensorflow / Theano / CNTK (I assume that, since all of these backends support ResNets, they can also run this model without modification)
33 |
--------------------------------------------------------------------------------
/images/dense_vs_sparse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-SparseNet/e07358c50017bd566745b375bc192880ff649b1e/images/dense_vs_sparse.png
--------------------------------------------------------------------------------
/images/sparse_connectivity.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/titu1994/keras-SparseNet/e07358c50017bd566745b375bc192880ff649b1e/images/sparse_connectivity.PNG
--------------------------------------------------------------------------------
/sparsenet.py:
--------------------------------------------------------------------------------
1 | '''SparseNet models for Keras.
2 | # Reference
3 | - [Sparsely Connected Convolutional Networks](https://arxiv.org/abs/1801.05895)
4 | - [Github](https://github.com/lyken17/sparsenet)
5 | '''
6 | from __future__ import print_function
7 | from __future__ import absolute_import
8 | from __future__ import division
9 |
10 | import numpy as np
11 | import warnings
12 |
13 | from keras.models import Model
14 | from keras.layers.core import Dense, Dropout, Activation, Reshape
15 | from keras.layers.convolutional import Conv2D, Conv2DTranspose, UpSampling2D
16 | from keras.layers.pooling import AveragePooling2D, MaxPooling2D
17 | from keras.layers.pooling import GlobalAveragePooling2D
18 | from keras.layers import Input
19 | from keras.layers.merge import concatenate
20 | from keras.layers.normalization import BatchNormalization
21 | from keras.regularizers import l2
22 | from keras.utils.layer_utils import convert_all_kernels_in_model, convert_dense_weights_data_format
23 | from keras.utils.data_utils import get_file
24 | from keras.engine.topology import get_source_inputs
25 | from keras.applications.imagenet_utils import _obtain_input_shape
26 | from keras.applications.imagenet_utils import decode_predictions
27 | import keras.backend as K
28 |
29 |
30 | def preprocess_input(x, data_format=None):
31 | """Preprocesses a tensor encoding a batch of images.
32 |
33 | # Arguments
34 | x: input Numpy tensor, 4D.
35 | data_format: data format of the image tensor.
36 |
37 | # Returns
38 | Preprocessed tensor.
39 | """
40 | if data_format is None:
41 | data_format = K.image_data_format()
42 | assert data_format in {'channels_last', 'channels_first'}
43 |
44 | if data_format == 'channels_first':
45 | if x.ndim == 3:
46 | # 'RGB'->'BGR'
47 | x = x[::-1, ...]
48 | # Zero-center by mean pixel
49 | x[0, :, :] -= 103.939
50 | x[1, :, :] -= 116.779
51 | x[2, :, :] -= 123.68
52 | else:
53 | x = x[:, ::-1, ...]
54 | x[:, 0, :, :] -= 103.939
55 | x[:, 1, :, :] -= 116.779
56 | x[:, 2, :, :] -= 123.68
57 | else:
58 | # 'RGB'->'BGR'
59 | x = x[..., ::-1]
60 | # Zero-center by mean pixel
61 | x[..., 0] -= 103.939
62 | x[..., 1] -= 116.779
63 | x[..., 2] -= 123.68
64 |
65 | x *= 0.017 # scale values
66 |
67 | return x
68 |
69 |
70 | def SparseNet(input_shape=None, depth=40, nb_dense_block=3, growth_rate=12, nb_filter=-1, nb_layers_per_block=-1,
71 | bottleneck=False, reduction=0.0, dropout_rate=0.0, weight_decay=1e-4, subsample_initial_block=False,
72 | include_top=True, weights=None, input_tensor=None,
73 | classes=10, activation='softmax'):
74 | '''Instantiate the SparseNet architecture.
75 | No pre-trained weights are currently available, so the
76 | model is randomly initialized. Note that when using TensorFlow,
77 | for best performance you should set
78 | `image_data_format='channels_last'` in your Keras config
79 | at ~/.keras/keras.json.
80 | The model is compatible with both
81 | TensorFlow and Theano. The dimension ordering
82 | convention used by the model is the one
83 | specified in your Keras config file.
84 | # Arguments
85 | input_shape: optional shape tuple, only to be specified
86 | if `include_top` is False (otherwise the input shape
87 | has to be `(32, 32, 3)` (with `channels_last` dim ordering)
88 | or `(3, 32, 32)` (with `channels_first` dim ordering).
89 | It should have exactly 3 input channels,
90 | and width and height should be no smaller than 8.
91 | E.g. `(200, 200, 3)` would be one valid value.
92 | depth: number of layers in the SparseNet
93 | nb_dense_block: number of dense blocks to add to end (generally = 3)
94 | growth_rate: number of filters added by each conv block. Can be
95 | a single integer number or a list of numbers.
96 | If it is a list, length of list must match the length of
97 | `nb_layers_per_block`
98 | nb_filter: initial number of filters. -1 indicates initial
99 | number of filters equals the first growth_rate
100 | nb_layers_per_block: number of layers in each dense block.
101 | Can be a -1, positive integer or a list.
102 | If -1, calculates nb_layer_per_block from the network depth.
103 | If positive integer, a set number of layers per dense block.
104 | If list, nb_layer is used as provided. Note that list size must
105 | be nb_dense_block
106 | bottleneck: flag to add a 1x1 bottleneck convolution inside each conv block
107 | reduction: reduction factor of transition blocks.
108 | Compression is computed as 1 - reduction.
109 | dropout_rate: dropout rate
110 | weight_decay: weight decay rate
111 | subsample_initial_block: Set to True to subsample the initial convolution and
112 | add a MaxPool2D before the dense blocks are added.
113 | include_top: whether to include the fully-connected
114 | layer at the top of the network.
115 | weights: one of `None` (random initialization) or
116 | 'imagenet' (reserved for ImageNet pre-training; no weights are available yet).
117 | input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
118 | to use as image input for the model.
119 | classes: optional number of classes to classify images
120 | into, only to be specified if `include_top` is True, and
121 | if no `weights` argument is specified.
122 | activation: Type of activation at the top layer. Can be one of 'softmax' or 'sigmoid'.
123 | Note that if sigmoid is used, classes must be 1.
124 | # Returns
125 | A Keras model instance.
126 | '''
127 |
128 | if weights not in {'imagenet', None}:
129 | raise ValueError('The `weights` argument should be either '
130 | '`None` (random initialization) or `imagenet` '
131 | '(pre-training on ImageNet).')
132 |
133 | if weights == 'imagenet' and include_top and classes != 1000:
134 | raise ValueError('If using `weights` as ImageNet with `include_top`'
135 | ' as true, `classes` should be 1000')
136 |
137 | if activation not in ['softmax', 'sigmoid']:
138 | raise ValueError('activation must be one of "softmax" or "sigmoid"')
139 |
140 | if activation == 'sigmoid' and classes != 1:
141 | raise ValueError('sigmoid activation can only be used when classes = 1')
142 |
143 | # Determine proper input shape
144 | input_shape = _obtain_input_shape(input_shape,
145 | default_size=32,
146 | min_size=8,
147 | data_format=K.image_data_format(),
148 | require_flatten=include_top)
149 |
150 | if input_tensor is None:
151 | img_input = Input(shape=input_shape)
152 | else:
153 | if not K.is_keras_tensor(input_tensor):
154 | img_input = Input(tensor=input_tensor, shape=input_shape)
155 | else:
156 | img_input = input_tensor
157 |
158 | x = _create_dense_net(classes, img_input, include_top, depth, nb_dense_block,
159 | growth_rate, nb_filter, nb_layers_per_block, bottleneck, reduction,
160 | dropout_rate, weight_decay, subsample_initial_block, activation)
161 |
162 | # Ensure that the model takes into account
163 | # any potential predecessors of `input_tensor`.
164 | if input_tensor is not None:
165 | inputs = get_source_inputs(input_tensor)
166 | else:
167 | inputs = img_input
168 | # Create model.
169 | model = Model(inputs, x, name='sparsenet')
170 |
171 | # load weights
172 | if weights == 'imagenet':
173 | weights_loaded = False # no pre-trained weights are available yet, so nothing is loaded
174 |
175 | if weights_loaded:
176 | if K.backend() == 'theano':
177 | convert_all_kernels_in_model(model)
178 |
179 | if K.image_data_format() == 'channels_first' and K.backend() == 'tensorflow':
180 | warnings.warn('You are using the TensorFlow backend, yet you '
181 | 'are using the Theano '
182 | 'image data format convention '
183 | '(`image_data_format="channels_first"`). '
184 | 'For best performance, set '
185 | '`image_data_format="channels_last"` in '
186 | 'your Keras config '
187 | 'at ~/.keras/keras.json.')
188 |
189 | print("Weights for the model were loaded successfully")
190 |
191 | return model
192 |
193 |
194 | def SparseNetImageNet121(input_shape=None,
195 | bottleneck=True,
196 | reduction=0.5,
197 | dropout_rate=0.0,
198 | weight_decay=1e-4,
199 | include_top=True,
200 | weights=None,
201 | input_tensor=None,
202 | classes=1000,
203 | activation='softmax'):
204 | return SparseNet(input_shape, depth=121, nb_dense_block=4, growth_rate=32, nb_filter=64,
205 | nb_layers_per_block=[6, 12, 24, 16], bottleneck=bottleneck, reduction=reduction,
206 | dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True,
207 | include_top=include_top, weights=weights, input_tensor=input_tensor,
208 | classes=classes, activation=activation)
209 |
210 |
211 | def SparseNetImageNet169(input_shape=None,
212 | bottleneck=True,
213 | reduction=0.5,
214 | dropout_rate=0.0,
215 | weight_decay=1e-4,
216 | include_top=True,
217 | weights=None,
218 | input_tensor=None,
219 | classes=1000,
220 | activation='softmax'):
221 | return SparseNet(input_shape, depth=169, nb_dense_block=4, growth_rate=32, nb_filter=64,
222 | nb_layers_per_block=[6, 12, 32, 32], bottleneck=bottleneck, reduction=reduction,
223 | dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True,
224 | include_top=include_top, weights=weights, input_tensor=input_tensor,
225 | classes=classes, activation=activation)
226 |
227 |
228 | def SparseNetImageNet201(input_shape=None,
229 | bottleneck=True,
230 | reduction=0.5,
231 | dropout_rate=0.0,
232 | weight_decay=1e-4,
233 | include_top=True,
234 | weights=None,
235 | input_tensor=None,
236 | classes=1000,
237 | activation='softmax'):
238 | return SparseNet(input_shape, depth=201, nb_dense_block=4, growth_rate=32, nb_filter=64,
239 | nb_layers_per_block=[6, 12, 48, 32], bottleneck=bottleneck, reduction=reduction,
240 | dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True,
241 | include_top=include_top, weights=weights, input_tensor=input_tensor,
242 | classes=classes, activation=activation)
243 |
244 |
245 | def SparseNetImageNet264(input_shape=None,
246 | bottleneck=True,
247 | reduction=0.5,
248 | dropout_rate=0.0,
249 | weight_decay=1e-4,
250 | include_top=True,
251 | weights=None,
252 | input_tensor=None,
253 | classes=1000,
254 | activation='softmax'):
255 | return SparseNet(input_shape, depth=264, nb_dense_block=4, growth_rate=32, nb_filter=64,
256 | nb_layers_per_block=[6, 12, 64, 48], bottleneck=bottleneck, reduction=reduction,
257 | dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True,
258 | include_top=include_top, weights=weights, input_tensor=input_tensor,
259 | classes=classes, activation=activation)
260 |
261 |
262 | def SparseNetImageNet161(input_shape=None,
263 | bottleneck=True,
264 | reduction=0.5,
265 | dropout_rate=0.0,
266 | weight_decay=1e-4,
267 | include_top=True,
268 | weights=None,
269 | input_tensor=None,
270 | classes=1000,
271 | activation='softmax'):
272 | return SparseNet(input_shape, depth=161, nb_dense_block=4, growth_rate=48, nb_filter=96,
273 | nb_layers_per_block=[6, 12, 36, 24], bottleneck=bottleneck, reduction=reduction,
274 | dropout_rate=dropout_rate, weight_decay=weight_decay, subsample_initial_block=True,
275 | include_top=include_top, weights=weights, input_tensor=input_tensor,
276 | classes=classes, activation=activation)
277 |
278 |
279 | def _exponential_index_fetch(x_list):
280 | count = len(x_list)
281 | i = 1
282 | inputs = []
283 | while i <= count:
284 | inputs.append(x_list[count - i])
285 | i *= 2
286 | return inputs
287 |
288 |
289 | def _conv_block(ip, nb_filter, bottleneck=False, dropout_rate=None, weight_decay=1e-4):
290 | ''' Apply BatchNorm, Relu, 3x3 Conv2D, optional bottleneck block and dropout
291 | Args:
292 | ip: Input keras tensor
293 | nb_filter: number of filters
294 | bottleneck: add bottleneck block
295 | dropout_rate: dropout rate
296 | weight_decay: weight decay factor
297 | Returns: keras tensor with batch_norm, relu and convolution2d added (optional bottleneck)
298 | '''
299 | concat_axis = 1 if K.image_data_format() == 'channels_first' else -1
300 |
301 | with K.name_scope('conv_block'):
302 | x = BatchNormalization(axis=concat_axis, momentum=0.1, epsilon=1e-5)(ip)
303 | x = Activation('relu')(x)
304 |
305 | if bottleneck:
306 | inter_channel = nb_filter * 4 # Obtained from https://github.com/liuzhuang13/DenseNet/blob/master/densenet.lua
307 |
308 | x = Conv2D(inter_channel, (1, 1), kernel_initializer='he_normal', padding='same', use_bias=False,
309 | kernel_regularizer=l2(weight_decay))(x)
310 | x = BatchNormalization(axis=concat_axis, epsilon=1e-5, momentum=0.1)(x)
311 | x = Activation('relu')(x)
312 |
313 | x = Conv2D(nb_filter, (3, 3), kernel_initializer='he_normal', padding='same', use_bias=False, kernel_regularizer=l2(weight_decay))(x)
314 | if dropout_rate:
315 | x = Dropout(dropout_rate)(x)
316 |
317 | return x
318 |
319 |
320 | def _dense_block(x, nb_layers, nb_filter, growth_rate, bottleneck=False, dropout_rate=None, weight_decay=1e-4,
321 | grow_nb_filters=True, return_concat_list=False):
322 | ''' Build a sparse block where each conv_block receives the earlier outputs at exponential offsets (1, 2, 4, ...)
323 | Args:
324 | x: keras tensor
325 | nb_layers: the number of layers of conv_block to append to the model.
326 | nb_filter: number of channels of the input tensor x
327 | growth_rate: growth rate
328 | bottleneck: bottleneck block
329 | dropout_rate: dropout rate
330 | weight_decay: weight decay factor
331 | grow_nb_filters: flag to decide to allow number of filters to grow
332 | return_concat_list: return the list of feature maps along with the actual output
333 | Returns: output keras tensor and its number of channels (plus the list of feature maps if return_concat_list)
334 | '''
335 | concat_axis = 1 if K.image_data_format() == 'channels_first' else -1
336 |
337 | x_list = [x]
338 | channel_list = [nb_filter]
339 |
340 | for i in range(nb_layers):
341 | #nb_channels = sum(_exponential_index_fetch(channel_list))
342 |
343 | x = _conv_block(x, growth_rate, bottleneck, dropout_rate, weight_decay)
344 | x_list.append(x)
345 |
346 | fetch_outputs = _exponential_index_fetch(x_list)
347 | x = concatenate(fetch_outputs, axis=concat_axis)
348 |
349 | channel_list.append(growth_rate)
350 |
351 | if grow_nb_filters:
352 | nb_filter = sum(_exponential_index_fetch(channel_list))
353 |
354 | if return_concat_list:
355 | return x, nb_filter, x_list
356 | else:
357 | return x, nb_filter
358 |
359 |
360 | def _transition_block(ip, nb_filter, compression=1.0, weight_decay=1e-4):
361 | ''' Apply BatchNorm, ReLU, a 1x1 Conv2D with optional compression, and 2x2 AveragePooling2D
362 | Args:
363 | ip: keras tensor
364 | nb_filter: number of input channels
365 | compression: calculated as 1 - reduction. Reduces the number of feature maps
366 | in the transition block.
367 | weight_decay: weight decay factor
368 | Returns: keras tensor, after applying batch_norm, relu, 1x1 conv and
369 | average pooling
370 | '''
371 | concat_axis = 1 if K.image_data_format() == 'channels_first' else -1
372 |
373 | with K.name_scope('transition_block'):
374 | x = BatchNormalization(axis=concat_axis, epsilon=1e-5, momentum=0.1)(ip)
375 | x = Activation('relu')(x)
376 | x = Conv2D(int(nb_filter * compression), (1, 1), kernel_initializer='he_normal', padding='same', use_bias=False,
377 | kernel_regularizer=l2(weight_decay))(x)
378 | x = AveragePooling2D((2, 2), strides=(2, 2))(x)
379 |
380 | return x
381 |
382 |
383 | def _create_dense_net(nb_classes, img_input, include_top, depth=40, nb_dense_block=3, growth_rate=12, nb_filter=-1,
384 | nb_layers_per_block=-1, bottleneck=False, reduction=0.0, dropout_rate=None, weight_decay=1e-4,
385 | subsample_initial_block=False, activation='softmax'):
386 | ''' Build the SparseNet model
387 | Args:
388 | nb_classes: number of classes
389 | img_input: input keras tensor of shape (channels, rows, columns) or (rows, columns, channels)
390 | include_top: flag to include the final Dense layer
391 | depth: number or layers
392 | nb_dense_block: number of dense blocks to add to end (generally = 3)
393 | growth_rate: number of filters added by each conv block
394 | nb_filter: initial number of filters. Default -1 indicates the initial number of filters equals the first growth_rate
395 | nb_layers_per_block: number of layers in each dense block.
396 | Can be a -1, positive integer or a list.
397 | If -1, calculates nb_layer_per_block from the depth of the network.
398 | If positive integer, a set number of layers per dense block.
399 | If list, nb_layer is used as provided. Note that list size must
400 | be nb_dense_block
401 | bottleneck: add bottleneck blocks
402 | reduction: reduction factor of transition blocks. Compression is computed as 1 - reduction
403 | dropout_rate: dropout rate
404 | weight_decay: weight decay rate
405 | subsample_initial_block: Set to True to subsample the initial convolution and
406 | add a MaxPool2D before the dense
407 | blocks are added.
408 | activation: Type of activation at the top layer. Can be one of 'softmax' or 'sigmoid'.
409 | Note that if sigmoid is used, classes must be 1.
410 | Returns: output keras tensor of the network
411 | '''
412 |
413 | concat_axis = 1 if K.image_data_format() == 'channels_first' else -1
414 |
415 | if reduction != 0.0:
416 | assert reduction <= 1.0 and reduction > 0.0, 'reduction value must lie between 0.0 and 1.0'
417 |
418 | # layers in each dense block
419 | if type(nb_layers_per_block) is list or type(nb_layers_per_block) is tuple:
420 | nb_layers = list(nb_layers_per_block) # Convert tuple to list
421 |
422 | assert len(nb_layers) == (nb_dense_block), 'If list, nb_layer is used as provided. ' \
423 | 'Note that list size must be (nb_dense_block)'
424 | final_nb_layer = nb_layers[-1]
425 | nb_layers = nb_layers[:-1]
426 | else:
427 | if nb_layers_per_block == -1:
428 | assert (depth - 4) % 3 == 0, 'Depth must be 3 N + 4 if nb_layers_per_block == -1'
429 | count = int((depth - 4) / 3)
430 |
431 | if bottleneck:
432 | count = count // 2
433 |
434 | nb_layers = [count for _ in range(nb_dense_block)]
435 | final_nb_layer = count
436 | else:
437 | final_nb_layer = nb_layers_per_block
438 | nb_layers = [nb_layers_per_block] * nb_dense_block
439 |
440 | if type(growth_rate) is list or type(growth_rate) is tuple:
441 | growth_rate = list(growth_rate)
442 | assert len(growth_rate) == nb_dense_block, 'If list, growth_rate must have nb_dense_block entries'
443 | else:
444 | growth_rate = [growth_rate for _ in range(nb_dense_block)]
445 |
446 | # compute initial nb_filter if -1, else accept users initial nb_filter
447 | if nb_filter <= 0:
448 | nb_filter = growth_rate[0]
449 |
450 | # compute compression factor
451 | compression = 1.0 - reduction
452 |
453 | # Initial convolution
454 | if subsample_initial_block:
455 | initial_kernel = (7, 7)
456 | initial_strides = (2, 2)
457 | else:
458 | initial_kernel = (3, 3)
459 | initial_strides = (1, 1)
460 |
461 | x = Conv2D(nb_filter, initial_kernel, kernel_initializer='he_normal', padding='same',
462 | strides=initial_strides, use_bias=False, kernel_regularizer=l2(weight_decay))(img_input)
463 |
464 | if subsample_initial_block:
465 | x = BatchNormalization(axis=concat_axis, epsilon=1e-5, momentum=0.1)(x)
466 | x = Activation('relu')(x)
467 | x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
468 |
469 | # Add dense blocks
470 | for block_idx in range(nb_dense_block - 1):
471 | x, nb_filter = _dense_block(x, nb_layers[block_idx], nb_filter, growth_rate[block_idx], bottleneck=bottleneck,
472 | dropout_rate=dropout_rate, weight_decay=weight_decay)
473 | # add transition_block
474 | x = _transition_block(x, nb_filter, compression=compression, weight_decay=weight_decay)
475 | nb_filter = int(nb_filter * compression)
476 |
477 | # The last dense_block does not have a transition_block
478 | x, nb_filter = _dense_block(x, final_nb_layer, nb_filter, growth_rate[-1], bottleneck=bottleneck,
479 | dropout_rate=dropout_rate, weight_decay=weight_decay)
480 |
481 | x = BatchNormalization(axis=concat_axis, epsilon=1e-5, momentum=0.1)(x)
482 | x = Activation('relu')(x)
483 | x = GlobalAveragePooling2D()(x)
484 |
485 | if include_top:
486 | x = Dense(nb_classes, activation=activation)(x)
487 |
488 | return x
489 |
490 |
491 | if __name__ == '__main__':
492 | # from keras.utils.vis_utils import plot_model
493 | # import tensorflow as tf
494 | # from keras import backend as K
495 | # sess = tf.Session()
496 | # K.set_session(sess)
497 |
498 | model = SparseNet((32, 32, 3), depth=40, nb_dense_block=3,
499 | growth_rate=24, bottleneck=False, reduction=0.0, weights=None)
500 | model.summary()
501 |
502 | #writer = tf.summary.FileWriter('logs/', graph=sess.graph)
503 | #writer.close()
504 |
505 | #plot_model(model, 'sparse.png', show_shapes=True)
506 |
507 |
--------------------------------------------------------------------------------
/train_cifar10.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os.path
4 |
5 | import sparsenet
6 | import numpy as np
7 | import sklearn.metrics as metrics
8 |
9 | from keras.datasets import cifar10
10 | from keras.utils import np_utils
11 | from keras.preprocessing.image import ImageDataGenerator
12 | from keras.optimizers import Adam
13 | from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
14 | from keras import backend as K
15 |
16 | batch_size = 100
17 | nb_classes = 10
18 | nb_epoch = 100
19 |
20 | img_rows, img_cols = 32, 32
21 | img_channels = 3
22 |
23 | img_dim = (img_channels, img_rows, img_cols) if K.image_data_format() == 'channels_first' else (img_rows, img_cols, img_channels)
24 | depth = 40
25 | nb_dense_block = 3
26 | growth_rate = 24
27 | nb_filter = -1
28 | dropout_rate = 0.0 # 0.0 for data augmentation
29 |
30 | model = sparsenet.SparseNet(img_dim, classes=nb_classes, depth=depth, nb_dense_block=nb_dense_block,
31 | growth_rate=growth_rate, nb_filter=nb_filter, dropout_rate=dropout_rate, weights=None)
32 | print("Model created")
33 |
34 | model.summary()
35 | optimizer = Adam(lr=1e-3, amsgrad=True) # Using Adam instead of SGD to speed up training
36 | model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=["accuracy"])
37 | print("Finished compiling")
38 | print("Loading data...")
39 |
40 | (trainX, trainY), (testX, testY) = cifar10.load_data()
41 |
42 | trainX = trainX.astype('float32')
43 | testX = testX.astype('float32')
44 |
45 | # trainX = sparsenet.preprocess_input(trainX)
46 | # testX = sparsenet.preprocess_input(testX)
47 |
48 | cifar_mean = trainX.mean(axis=(0, 1, 2), keepdims=True)
49 | cifar_std = trainX.std(axis=(0, 1, 2), keepdims=True)
50 |
51 | trainX = (trainX - cifar_mean) / (cifar_std + 1e-8)
52 | testX = (testX - cifar_mean) / (cifar_std + 1e-8)
53 |
54 | Y_train = np_utils.to_categorical(trainY, nb_classes)
55 | Y_test = np_utils.to_categorical(testY, nb_classes)
56 |
57 | generator = ImageDataGenerator(width_shift_range=5. / 32,
58 | height_shift_range=5. / 32,
59 | horizontal_flip=True)
60 |
61 | generator.fit(trainX, seed=0)
62 |
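63 | out_dir = "weights/"
64 | if not os.path.exists(out_dir):
65 |     os.makedirs(out_dir)  # ModelCheckpoint expects the target directory to exist
66 | # Load previously saved weights, if any
67 | weights_file = os.path.join(out_dir, "SparseNet-40-24-CIFAR10.h5")
68 | if os.path.exists(weights_file):
69 |     model.load_weights(weights_file)
70 |     print("Model loaded.")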
71 | lr_reducer = ReduceLROnPlateau(monitor='val_acc', factor=np.sqrt(0.1),
72 | cooldown=0, patience=5, min_lr=1e-5)
73 | model_checkpoint = ModelCheckpoint(weights_file, monitor="val_acc", save_best_only=True,
74 | save_weights_only=True, verbose=1)
75 |
76 | callbacks = [lr_reducer, model_checkpoint]
77 |
78 | model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size),
79 | steps_per_epoch=len(trainX) // batch_size, epochs=nb_epoch,
80 | callbacks=callbacks,
81 | validation_data=(testX, Y_test),
82 | validation_steps=testX.shape[0] // batch_size, verbose=1)
83 |
84 | yPreds = model.predict(testX)
85 | yPred = np.argmax(yPreds, axis=1)
86 | yTrue = testY
87 |
88 | accuracy = metrics.accuracy_score(yTrue, yPred) * 100
89 | error = 100 - accuracy
90 | print("Accuracy : ", accuracy)
91 | print("Error : ", error)
92 |
--------------------------------------------------------------------------------