├── .github └── FUNDING.yml ├── LICENSE ├── README.md ├── img └── conv_bn_fusion.png ├── kito ├── __init__.py ├── custom_model_test_bench.py └── test_bench.py └── setup.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | custom: https://paypal.me/zfturbo 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Keras inference time optimizer (KITO) 2 | 3 | This code takes a trained Keras model as input and optimizes its layer structure and weights so that the model becomes 4 | faster (~10-30%) while producing the same outputs as the initial model, up to floating-point error. It can be extremely useful when you need to process a large 5 | number of images with a trained model. The reduce operation was tested on all models from the Keras model zoo. See the 6 | comparison table below. 7 | 8 | ## Installation 9 | 10 | ``` 11 | pip install kito 12 | ``` 13 | 14 | ## How it works? 15 | 16 | The current version applies only a single type of optimization: it reduces a Conv2D + BatchNormalization pair of layers to 17 | a single Conv2D layer. Since Conv2D + BatchNormalization is a very common pair of layers, the optimization works well 18 | on almost all modern CNNs for image processing. 19 | 20 | Also supported: 21 | * DepthwiseConv2D + BatchNormalization => DepthwiseConv2D 22 | * SeparableConv2D + BatchNormalization => SeparableConv2D 23 | * Conv2DTranspose + BatchNormalization => Conv2DTranspose 24 | * Conv3D + BatchNormalization => Conv3D 25 | * Conv1D + BatchNormalization => Conv1D 26 | 27 | ## How to use 28 | 29 | Typical code: 30 | 31 | ``` 32 | model.fit(...) 33 | ... 34 | model.predict(...) 35 | ``` 36 | 37 | must be replaced with the following block: 38 | 39 | ``` 40 | from kito import reduce_keras_model 41 | model.fit(...) 42 | ... 43 | model_reduced = reduce_keras_model(model) 44 | model_reduced.predict(...) 45 | ``` 46 | 47 | So basically you need to insert 2 lines in your code to speed up inference. But note that converting the model takes 48 | some time. 
You can see usage example in [test_bench.py](https://github.com/ZFTurbo/Keras-inference-time-optimizer/blob/master/kito/test_bench.py) 49 | 50 | ## Comparison table 51 | 52 | | Neural net | Input shape | Number of layers (Init) | Number of layers (Reduced) | Number of params (Init) | Number of params (Reduced) | Time to process 10000 images (Init) | Time to process 10000 images (Reduced) | Conversion Time (sec) | Maximum diff on final layer | Average difference on final layer | 53 | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | 54 | | MobileNet (1.0) | (224, 224, 3) | 102 | 75 | 4,253,864| 4,221,032| **32.38** | **22.13** | 12.45 | 2.80e-06 | 4.41e-09 | 55 | | MobileNetV2 (1.4) | (224, 224, 3) | 152 | 100 | 6,156,712| 6,084,808| **52.53** | **37.71** | 87.00 | 3.99e-06 | 6.88e-09 | 56 | | ResNet50 | (224, 224, 3) | 177 | 124 | 25,636,712 | 25,530,472 | **58.87** | **35.81** | 45.28 | 5.06e-07 | 1.24e-09 | 57 | | Inception_v3 | (299, 299, 3) | 313 | 219 | 23,851,784 | 23,817,352 | **79.15** | **59.55** | 126.02 | 7.74e-07 | 1.26e-09 | 58 | | Inception_Resnet_v2 | (299, 299, 3) | 782 | 578 | 55,873,736 | 55,813,192 | **131.16** | **102.38** | 766.14 | 8.04e-07 | 9.26e-10 | 59 | | Xception | (299, 299, 3) | 134 | 94 | 22,910,480 | 22,828,688 | **115.56** | **76.17** | 28.15 | 3.65e-07 | 9.69e-10 | 60 | | DenseNet121 | (224, 224, 3) | 428 | 369 | 8,062,504 | 8,040,040 | **68.25** | **57.57** | 392.24 | 4.61e-07 | 8.69e-09 | 61 | | DenseNet169 | (224, 224, 3) | 596 | 513 | 14,307,880 | 14,276,200 | **80.56** | **68.74** | 772.54 | 2.14e-06 | 1.79e-09 | 62 | | DenseNet201 | (224, 224, 3) | 708 | 609 | 20,242,984 | 20,205,160 | **98.99** | **87.04** | 1120.88 | 7.00e-07 | 1.27e-09 | 63 | | NasNetMobile | (224, 224, 3) | 751 | 563 | 5,326,716 | 5,272,599 | **46.05** | **31.76** | 728.96 | 1.10e-06 | 1.60e-09 | 64 | | NasNetLarge | (331, 331, 3) | 1021 | 761 | 88,949,818 | 88,658,596 | **445.58** | **328.16** | 1402.61 | 1.43e-07 | 5.88e-10 | 65 | | 
[ZF_UNET_224](https://github.com/ZFTurbo/ZF_UNET_224_Pretrained_Model) | (224, 224, 3) | 85 | 63 | 31,466,753 | 31,442,689 | **96.76** | **69.17** | 9.93 | 4.72e-05 | 7.54e-09 | 66 | | [DeepLabV3+](https://github.com/bonlime/keras-deeplab-v3-plus) (mobile) | (512, 512, 3) | 162 | 108 | 2,146,645 | 2,097,013 | **583.63** | **432.71** | 48.00 | 4.72e-05 | 1.00e-05 | 67 | | [DeepLabV3+](https://github.com/bonlime/keras-deeplab-v3-plus) (xception) | (512, 512, 3) | 409 | 263 | 41,258,213 | 40,954,013 | **1000.36** | **699.24** | 333.1 | 8.63e-05 | 5.22e-06 | 68 | | [ResNet152](https://github.com/broadinstitute/keras-resnet) | (224, 224, 3) | 566 | 411 | 60,344,232 | 60,117,096 | **107.92** | **68.53** | 357.65 | 8.94e-07 | 1.27e-09 | 69 | 70 | **Config**: Single NVIDIA GTX 1080 8 GB. Timings were obtained on the TensorFlow 1.4 (+ CUDA 8.0) backend. 71 | 72 | ## Notes 73 | 74 | * Conversion feels very slow for no good reason; it should be much faster, since all 75 | manipulations with layers and weights are very fast. I probably use some very slow Keras operations in the process. 76 | Feel free to give advice on how to change the code to make it faster. 77 | * You can check that both models produce the same results with the function from test_bench.py: compare_two_models_results(model, model_reduced, 10000) 78 | * The non-zero difference on the final layer accumulates over a large number of floating-point operations, which are not exact 79 | * Some non-standard layers or parameters (those not used in any keras.applications CNN) can produce wrong results. 80 | Most likely the code will simply fail in these cases, and the Python error message will show the layer that caused it. 
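The folding behind these notes can be sketched in plain NumPy. This is a minimal illustration, not the package's code: it uses the same scale and bias formulas as `optimize_conv_batchnorm_block` in `kito/__init__.py`, but the helper names `fold_batchnorm` and `conv1x1`, the 1x1 kernel, the random test data and the `eps` value are all assumptions chosen to keep the example short.

```python
import numpy as np


def fold_batchnorm(kernel, bias, gamma, beta, mean, var, eps):
    """Fold inference-mode batch normalization into the preceding convolution.

    Hypothetical helper: `kernel` keeps output channels on its last axis
    (the Keras Conv2D layout), so the per-channel scale broadcasts over it.
    """
    scale = gamma / np.sqrt(var + eps)         # "A" in kito/__init__.py
    fused_kernel = kernel * scale              # rescale every output channel
    fused_bias = beta + (bias - mean) * scale  # "B" in kito/__init__.py
    return fused_kernel, fused_bias


def conv1x1(x, kernel, bias):
    # x: (batch, h, w, c_in), kernel: (1, 1, c_in, c_out); a 1x1 convolution
    # is a per-pixel matrix product, which keeps the sketch dependency-free.
    return np.einsum('bhwi,io->bhwo', x, kernel[0, 0]) + bias


rng = np.random.default_rng(0)
c_in, c_out, eps = 3, 8, 1e-3  # illustrative sizes and epsilon
x = rng.uniform(0.0, 1.0, (2, 5, 5, c_in)).astype(np.float32)
kernel = rng.normal(size=(1, 1, c_in, c_out)).astype(np.float32)
bias = rng.normal(size=c_out).astype(np.float32)
gamma = rng.uniform(0.5, 1.5, c_out).astype(np.float32)
beta = rng.normal(size=c_out).astype(np.float32)
mean = rng.normal(size=c_out).astype(np.float32)
var = rng.uniform(0.5, 2.0, c_out).astype(np.float32)

# Reference path: convolution followed by batch normalization (inference mode)
y_ref = gamma * (conv1x1(x, kernel, bias) - mean) / np.sqrt(var + eps) + beta

# Fused path: one convolution with the rescaled kernel and the new bias
fk, fb = fold_batchnorm(kernel, bias, gamma, beta, mean, var, eps)
y_fused = conv1x1(x, fk, fb)

max_diff = float(np.abs(y_ref - y_fused).max())
print('max abs difference:', max_diff)
```

Both paths compute the same function, so any difference printed here comes only from float32 rounding, which is the same effect reported in the last two columns of the comparison table.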
81 | 82 | ## Requirements 83 | 84 | * The code was tested on Keras 2.1.6 (TensorFlow 1.4 backend) and on Keras 2.2.0 (TensorFlow 1.8.0 backend) 85 | 86 | ## Formulas 87 | 88 | ![Base formulas](https://raw.githubusercontent.com/ZFTurbo/Keras-inference-time-optimizer/master/img/conv_bn_fusion.png) 89 | 90 | ## Other implementations 91 | 92 | [PyTorch BN Fusion](https://github.com/MIPT-Oulu/pytorch_bn_fusion) - with support for VGG, ResNet, SeNet. -------------------------------------------------------------------------------- /img/conv_bn_fusion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZFTurbo/Keras-inference-time-optimizer/cad25d9285a5bf9c1079c931cafac1966c4e48a8/img/conv_bn_fusion.png -------------------------------------------------------------------------------- /kito/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reduce neural net structure (Conv + BN -> Conv) 3 | Also works: 4 | DepthwiseConv2D + BN -> DepthwiseConv2D 5 | SeparableConv2D + BN -> SeparableConv2D 6 | 7 | This code takes a trained Keras model as input and optimizes its layer structure and weights in such a way 8 | that the model becomes faster (~10-30%) while producing the same outputs as the initial model, up to floating-point error. It can be extremely 9 | useful when you need to process a large number of images with a trained model. The reduce operation was 10 | tested on all models from the Keras model zoo. 
See comparison table and full description by link: 11 | https://github.com/ZFTurbo/Keras-inference-time-optimizer 12 | Author: Roman Solovyev (ZFTurbo) 13 | """ 14 | 15 | from __future__ import absolute_import 16 | from __future__ import division 17 | from __future__ import print_function 18 | 19 | 20 | import numpy as np 21 | 22 | 23 | __version__ = '1.0.5' 24 | 25 | 26 | def get_keras_sub_version(): 27 | try: 28 | from keras import __version__ 29 | type = int(__version__.split('.')[1]) 30 | except: 31 | type = 2 32 | return type 33 | 34 | 35 | def get_input_layers_ids(model, layer, verbose=False): 36 | res = dict() 37 | for i, l in enumerate(model.layers): 38 | layer_id = str(id(l)) 39 | res[layer_id] = i 40 | 41 | inbound_layers = [] 42 | layer_id = str(id(layer)) 43 | for i, node in enumerate(layer._inbound_nodes): 44 | node_key = layer.name + '_ib-' + str(i) 45 | if get_keras_sub_version() == 1: 46 | network_nodes = model._container_nodes 47 | else: 48 | network_nodes = model._network_nodes 49 | if node_key in network_nodes: 50 | node_inbound_layers = node.inbound_layers 51 | if type(node_inbound_layers) is not list: 52 | node_inbound_layers = [node_inbound_layers] 53 | for inbound_layer in node_inbound_layers: 54 | inbound_layer_id = str(id(inbound_layer)) 55 | inbound_layers.append(res[inbound_layer_id]) 56 | return inbound_layers 57 | 58 | 59 | def get_output_layers_ids(model, layer, verbose=False): 60 | res = dict() 61 | for i, l in enumerate(model.layers): 62 | layer_id = str(id(l)) 63 | res[layer_id] = i 64 | 65 | outbound_layers = [] 66 | 67 | for i, node in enumerate(layer._outbound_nodes): 68 | node_key = layer.name + '_ib-' + str(i) 69 | if get_keras_sub_version() == 1: 70 | network_nodes = model._container_nodes 71 | else: 72 | network_nodes = model._network_nodes 73 | if node_key in network_nodes: 74 | outbound_layer_id = str(id(node.outbound_layer)) 75 | if outbound_layer_id in res: 76 | outbound_layers.append(res[outbound_layer_id]) 77 | else: 78 
| print('Warning, some problem with outbound node on layer {}!'.format(layer.name)) 79 | return outbound_layers 80 | 81 | 82 | def get_copy_of_layer(layer, verbose=False): 83 | try: 84 | from keras.layers import Activation 85 | from keras import layers 86 | except: 87 | from tensorflow.keras.layers import Activation 88 | from tensorflow.keras import layers 89 | config = layer.get_config() 90 | 91 | # Non-standard relu6 layer (from MobileNet) 92 | if layer.__class__.__name__ == 'Activation': 93 | if config['activation'] == 'relu6': 94 | if get_keras_sub_version() == 1: 95 | from keras.applications.mobilenet import relu6 96 | else: 97 | from keras_applications.mobilenet import relu6 98 | layer_copy = Activation(relu6, name=layer.name) 99 | return layer_copy 100 | 101 | # DeepLabV3+ non-standard layer 102 | if layer.__class__.__name__ == 'BilinearUpsampling': 103 | from neural_nets.deeplab_v3_plus_model import BilinearUpsampling 104 | layer_copy = BilinearUpsampling(upsampling=config['upsampling'], output_size=config['output_size'], name=layer.name) 105 | return layer_copy 106 | 107 | # RetinaNet non-standard layer 108 | if layer.__class__.__name__ == 'UpsampleLike': 109 | from keras_retinanet.layers import UpsampleLike 110 | layer_copy = UpsampleLike(name=layer.name) 111 | return layer_copy 112 | 113 | # RetinaNet non-standard layer 114 | if layer.__class__.__name__ == 'Anchors': 115 | from keras_retinanet.layers import Anchors 116 | layer_copy = Anchors(name=layer.name, size=config['size'], stride=config['stride'], 117 | ratios=config['ratios'], scales=config['scales']) 118 | return layer_copy 119 | 120 | # RetinaNet non-standard layer 121 | if layer.__class__.__name__ == 'RegressBoxes': 122 | from keras_retinanet.layers import RegressBoxes 123 | layer_copy = RegressBoxes(name=layer.name, mean=config['mean'], std=config['std']) 124 | return layer_copy 125 | 126 | # RetinaNet non-standard layer 127 | if layer.__class__.__name__ == 'PriorProbability': 128 | from 
keras_retinanet.layers import PriorProbability 129 | layer_copy = PriorProbability(name=layer.name, mean=config['mean'], std=config['std']) 130 | return layer_copy 131 | 132 | # RetinaNet non-standard layer 133 | if layer.__class__.__name__ == 'ClipBoxes': 134 | from keras_retinanet.layers import ClipBoxes 135 | layer_copy = ClipBoxes(name=layer.name) 136 | return layer_copy 137 | 138 | # RetinaNet non-standard layer 139 | if layer.__class__.__name__ == 'FilterDetections': 140 | from keras_retinanet.layers import FilterDetections 141 | layer_copy = FilterDetections(name=layer.name, max_detections=config['max_detections'], 142 | nms_threshold=config['nms_threshold'], 143 | score_threshold=config['score_threshold'], 144 | nms=config['nms'], class_specific_filter=config['class_specific_filter'], 145 | trainable=config['trainable'], 146 | parallel_iterations=config['parallel_iterations']) 147 | return layer_copy 148 | 149 | layer_copy = layers.deserialize({'class_name': layer.__class__.__name__, 'config': config}) 150 | try: 151 | layer_copy.name = layer.name 152 | except: 153 | layer_copy._name = layer._name 154 | return layer_copy 155 | 156 | 157 | def get_layers_without_output(model, verbose=False): 158 | output_tensor = list(model.outputs) 159 | output_names = list(model.output_names) 160 | if verbose: 161 | print('Outputs [{}]: {}'.format(len(output_tensor), output_names)) 162 | return output_tensor, output_names 163 | 164 | 165 | def optimize_conv_batchnorm_block(m, initial_model, input_layers, conv, bn, verbose=False): 166 | try: 167 | from keras import layers 168 | from keras.models import Model 169 | except: 170 | from tensorflow.keras import layers 171 | from tensorflow.keras.models import Model 172 | 173 | conv_layer_type = conv.__class__.__name__ 174 | conv_config = conv.get_config() 175 | conv_config['use_bias'] = True 176 | bn_config = bn.get_config() 177 | if conv_config['activation'] != 'linear': 178 | print('Only linear activation supported for conv + 
bn optimization!') 179 | exit() 180 | 181 | # Copy Conv layer 182 | layer_copy = layers.deserialize({'class_name': conv.__class__.__name__, 'config': conv_config}) 183 | # We use batch norm name here to find it later 184 | try: 185 | layer_copy.name = bn.name 186 | except: 187 | layer_copy._name = bn._name 188 | 189 | # Create new model to initialize layer. We need to store other output tensors as well 190 | output_tensor, output_names = get_layers_without_output(m, verbose) 191 | input_layer_name = initial_model.layers[input_layers[0]].name 192 | prev_layer = m.get_layer(name=input_layer_name) 193 | x = layer_copy(prev_layer.output) 194 | 195 | output_tensor_to_use = [x] 196 | for i in range(len(output_names)): 197 | if output_names[i] != input_layer_name: 198 | output_tensor_to_use.append(output_tensor[i]) 199 | 200 | if len(output_tensor_to_use) == 1: 201 | output_tensor_to_use = output_tensor_to_use[0] 202 | 203 | tmp_model = Model(inputs=m.input, outputs=output_tensor_to_use) 204 | 205 | if conv.get_config()['use_bias']: 206 | (conv_weights, conv_bias) = conv.get_weights() 207 | else: 208 | (conv_weights,) = conv.get_weights() 209 | 210 | if bn_config['scale']: 211 | gamma, beta, run_mean, run_std = bn.get_weights() 212 | else: 213 | gamma = 1.0 214 | beta, run_mean, run_std = bn.get_weights() 215 | 216 | eps = bn_config['epsilon'] 217 | A = gamma / np.sqrt(run_std + eps) 218 | 219 | if conv.get_config()['use_bias']: 220 | B = beta + (gamma * (conv_bias - run_mean) / np.sqrt(run_std + eps)) 221 | else: 222 | B = beta - ((gamma * run_mean) / np.sqrt(run_std + eps)) 223 | 224 | if conv_layer_type == 'Conv2D': 225 | for i in range(conv_weights.shape[-1]): 226 | conv_weights[:, :, :, i] *= A[i] 227 | elif conv_layer_type == 'Conv2DTranspose': 228 | for i in range(conv_weights.shape[-2]): 229 | conv_weights[:, :, i, :] *= A[i] 230 | elif conv_layer_type == 'DepthwiseConv2D': 231 | for i in range(conv_weights.shape[-2]): 232 | conv_weights[:, :, i, :] *= A[i] 233 | 
elif conv_layer_type == 'Conv3D': 234 | for i in range(conv_weights.shape[-1]): 235 | conv_weights[:, :, :, :, i] *= A[i] 236 | elif conv_layer_type == 'Conv1D': 237 | for i in range(conv_weights.shape[-1]): 238 | conv_weights[:, :, i] *= A[i] 239 | 240 | tmp_model.get_layer(layer_copy.name).set_weights((conv_weights, B)) 241 | return tmp_model 242 | 243 | 244 | def optimize_separableconv2d_batchnorm_block(m, initial_model, input_layers, conv, bn, verbose=False): 245 | try: 246 | from keras import layers 247 | from keras.models import Model 248 | except: 249 | from tensorflow.keras import layers 250 | from tensorflow.keras.models import Model 251 | 252 | conv_config = conv.get_config() 253 | conv_config['use_bias'] = True 254 | bn_config = bn.get_config() 255 | if conv_config['activation'] != 'linear': 256 | print('Only linear activation supported for conv + bn optimization!') 257 | exit() 258 | 259 | layer_copy = layers.deserialize({'class_name': conv.__class__.__name__, 'config': conv_config}) 260 | # We use batch norm name here to find it later 261 | try: 262 | layer_copy.name = bn.name 263 | except: 264 | layer_copy._name = bn._name 265 | 266 | # Create new model to initialize layer. 
We need to store other output tensors as well 267 | output_tensor, output_names = get_layers_without_output(m, verbose) 268 | input_layer_name = initial_model.layers[input_layers[0]].name 269 | prev_layer = m.get_layer(name=input_layer_name) 270 | x = layer_copy(prev_layer.output) 271 | 272 | output_tensor_to_use = [x] 273 | for i in range(len(output_names)): 274 | if output_names[i] != input_layer_name: 275 | output_tensor_to_use.append(output_tensor[i]) 276 | 277 | if len(output_tensor_to_use) == 1: 278 | output_tensor_to_use = output_tensor_to_use[0] 279 | 280 | tmp_model = Model(inputs=m.input, outputs=output_tensor_to_use) 281 | 282 | if conv.get_config()['use_bias']: 283 | (conv_weights_3, conv_weights_1, conv_bias) = conv.get_weights() 284 | else: 285 | (conv_weights_3, conv_weights_1) = conv.get_weights() 286 | 287 | if bn_config['scale']: 288 | gamma, beta, run_mean, run_std = bn.get_weights() 289 | else: 290 | gamma = 1.0 291 | beta, run_mean, run_std = bn.get_weights() 292 | 293 | eps = bn_config['epsilon'] 294 | A = gamma / np.sqrt(run_std + eps) 295 | 296 | if conv.get_config()['use_bias']: 297 | B = beta + (gamma * (conv_bias - run_mean) / np.sqrt(run_std + eps)) 298 | else: 299 | B = beta - ((gamma * run_mean) / np.sqrt(run_std + eps)) 300 | 301 | for i in range(conv_weights_1.shape[-1]): 302 | conv_weights_1[:, :, :, i] *= A[i] 303 | 304 | # print(conv_weights_3.shape, conv_weights_1.shape, A.shape) 305 | 306 | tmp_model.get_layer(layer_copy.name).set_weights((conv_weights_3, conv_weights_1, B)) 307 | return tmp_model 308 | 309 | 310 | def reduce_keras_model(model, verbose=False): 311 | try: 312 | from keras.models import Model 313 | from keras.models import clone_model 314 | except: 315 | from tensorflow.keras.models import Model 316 | from tensorflow.keras.models import clone_model 317 | 318 | x = [] 319 | input = [] 320 | skip_layers = [] 321 | keras_sub_version = get_keras_sub_version() 322 | if verbose: 323 | print('Keras sub version: 
{}'.format(keras_sub_version)) 324 | 325 | # Find all inputs 326 | for level_id in range(len(model.layers)): 327 | layer = model.layers[level_id] 328 | layer_type = layer.__class__.__name__ 329 | if layer_type == 'InputLayer': 330 | inp1 = get_copy_of_layer(layer, verbose) 331 | x.append(inp1) 332 | input.append(inp1.output) 333 | tmp_model = Model(inputs=input, outputs=input) 334 | 335 | for level_id in range(len(model.layers)): 336 | layer = model.layers[level_id] 337 | layer_type = layer.__class__.__name__ 338 | 339 | # Skip input layers 340 | if layer_type == 'InputLayer': 341 | continue 342 | 343 | input_layers = get_input_layers_ids(model, layer, verbose) 344 | output_layers = get_output_layers_ids(model, layer, verbose) 345 | if verbose: 346 | print('Go for {}: {} ({}). Input layers: {} Output layers: {}'.format(level_id, layer_type, layer.name, input_layers, output_layers)) 347 | 348 | if level_id in skip_layers: 349 | if verbose: 350 | print('Skip layer because it was removed during optimization!') 351 | continue 352 | 353 | # Special cases for reducing 354 | if len(output_layers) == 1: 355 | next_layer = model.layers[output_layers[0]] 356 | next_layer_type = next_layer.__class__.__name__ 357 | if layer_type in ['Conv2D', 'DepthwiseConv2D', 'Conv2DTranspose', 'Conv3D', 'Conv1D'] and next_layer_type == 'BatchNormalization': 358 | tmp_model = optimize_conv_batchnorm_block(tmp_model, model, input_layers, layer, next_layer, verbose) 359 | x = tmp_model.layers[-1].output 360 | skip_layers.append(output_layers[0]) 361 | continue 362 | 363 | if layer_type in ['SeparableConv2D'] and next_layer_type == 'BatchNormalization': 364 | tmp_model = optimize_separableconv2d_batchnorm_block(tmp_model, model, input_layers, layer, next_layer, verbose) 365 | x = tmp_model.layers[-1].output 366 | skip_layers.append(output_layers[0]) 367 | continue 368 | 369 | if layer_type == 'Model': 370 | new_layer = clone_model(layer) 371 | new_layer.set_weights(layer.get_weights()) 372 | if 
verbose is True: 373 | print('Try to recursively reduce internal model {}!'.format(new_layer.name)) 374 | new_layer = reduce_keras_model(new_layer, verbose=verbose) 375 | else: 376 | new_layer = get_copy_of_layer(layer, verbose) 377 | 378 | prev_layer = [] 379 | for i in range(len(set(input_layers))): 380 | search_layer = tmp_model.get_layer(name=model.layers[input_layers[i]].name) 381 | try: 382 | tens = search_layer.output 383 | prev_layer.append(tens) 384 | except: 385 | # Ugly workaround; needs to be checked for correctness 386 | for node in search_layer._inbound_nodes: 387 | for i in range(len(node.inbound_layers)): 388 | outbound_tensor_index = node.tensor_indices[i] 389 | prev_layer.append(node.output_tensors[outbound_tensor_index]) 390 | 391 | output_tensor, output_names = get_layers_without_output(tmp_model, verbose) 392 | 393 | if len(prev_layer) == 1: 394 | prev_layer = prev_layer[0] 395 | 396 | if layer_type == 'Model': 397 | if type(prev_layer) is list: 398 | for f in prev_layer: 399 | x = new_layer(f) 400 | if f in output_tensor: 401 | output_tensor.remove(f) 402 | output_tensor.append(x) 403 | else: 404 | x = new_layer(prev_layer) 405 | if prev_layer in output_tensor: 406 | output_tensor.remove(prev_layer) 407 | output_tensor.append(x) 408 | else: 409 | # print('!!!!!!!', prev_layer) 410 | x = new_layer(prev_layer) 411 | if type(prev_layer) is list: 412 | for f in prev_layer: 413 | remove_positions = [] 414 | for j in range(len(output_tensor)): 415 | if f.name == output_tensor[j].name: 416 | remove_positions.append(j) 417 | # Remove in reverse order 418 | for j in remove_positions[::-1]: 419 | output_tensor.pop(j) 420 | else: 421 | remove_positions = [] 422 | for j in range(len(output_tensor)): 423 | if prev_layer.name == output_tensor[j].name: 424 | remove_positions.append(j) 425 | # Remove in reverse order 426 | for j in remove_positions[::-1]: 427 | output_tensor.pop(j) 428 | 429 | if type(x) is list: 430 | output_tensor += x 431 | else: 432 | 
output_tensor.append(x) 433 | 434 | tmp_model = Model(inputs=input, outputs=output_tensor) 435 | if layer_type != 'Model': 436 | tmp_model.get_layer(name=layer.name).set_weights(layer.get_weights()) 437 | 438 | output_tensor, output_names = get_layers_without_output(tmp_model, verbose) 439 | if verbose: 440 | print('Output names: {}'.format(output_names)) 441 | model_reduced = Model(inputs=input, outputs=output_tensor, name=model.name) 442 | if verbose: 443 | print('Initial number of layers: {}'.format(len(model.layers))) 444 | print('Reduced number of layers: {}'.format(len(model_reduced.layers))) 445 | return model_reduced 446 | 447 | -------------------------------------------------------------------------------- /kito/custom_model_test_bench.py: -------------------------------------------------------------------------------- 1 | """ 2 | Regression tests for KITO 3 | Author: Roman Solovyev (ZFTurbo), IPPM RAS: https://github.com/ZFTurbo/ 4 | """ 5 | 6 | import argparse 7 | 8 | try: 9 | from keras.models import load_model 10 | except: 11 | from tensorflow.keras.models import load_model 12 | 13 | from kito import * 14 | import time 15 | 16 | 17 | def compare_two_models_results(m1, m2, test_number=10000, max_batch=10000): 18 | input_shape1 = m1.input_shape 19 | input_shape2 = m2.input_shape 20 | if tuple(input_shape1) != tuple(input_shape2): 21 | print('Different input shapes for models {} vs {}'.format(input_shape1, input_shape2)) 22 | output_shape1 = m1.output_shape 23 | output_shape2 = m2.output_shape 24 | if tuple(output_shape1) != tuple(output_shape2): 25 | print('Different output shapes for models {} vs {}'.format(output_shape1, output_shape2)) 26 | print(input_shape1, input_shape2, output_shape1, output_shape2) 27 | 28 | t1 = 0 29 | t2 = 0 30 | max_error = 0 31 | avg_error = 0 32 | count = 0 33 | for i in range(0, test_number, max_batch): 34 | tst = min(test_number - i, max_batch) 35 | print('Generate random images {}...'.format(tst)) 36 | 37 | if 
type(input_shape1) is list: 38 | matrix = [] 39 | for i1 in input_shape1: 40 | matrix.append(np.random.uniform(0.0, 1.0, (tst,) + i1[1:])) 41 | else: 42 | # None shape fix 43 | inp_shape_fix = list(input_shape1) 44 | for i in range(1, len(inp_shape_fix)): 45 | if inp_shape_fix[i] is None: 46 | inp_shape_fix[i] = 224 47 | matrix = np.random.uniform(0.0, 1.0, (tst,) + tuple(inp_shape_fix[1:])) 48 | 49 | start_time = time.time() 50 | res1 = m1.predict(matrix) 51 | t1 += time.time() - start_time 52 | 53 | start_time = time.time() 54 | res2 = m2.predict(matrix) 55 | t2 += time.time() - start_time 56 | 57 | if type(res1) is list: 58 | for i1 in range(len(res1)): 59 | abs_diff = np.abs(res1[i1] - res2[i1]) 60 | max_error = max(max_error, abs_diff.max()) 61 | avg_error += abs_diff.sum() 62 | count += abs_diff.size 63 | else: 64 | abs_diff = np.abs(res1 - res2) 65 | max_error = max(max_error, abs_diff.max()) 66 | avg_error += abs_diff.sum() 67 | count += abs_diff.size 68 | 69 | print("Initial model prediction time for {} random images: {:.2f} seconds".format(test_number, t1)) 70 | print("Reduced model prediction time for {} same random images: {:.2f} seconds".format(test_number, t2)) 71 | print('Models raw max difference: {} (Avg difference: {})'.format(max_error, avg_error/count)) 72 | return max_error 73 | 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument('-m', dest='model_filepath', required=True) 78 | parser.add_argument('--verbose', dest='verbose', action='store_true', default=False, required=False) 79 | args = parser.parse_args() 80 | 81 | model = load_model(args.model_filepath) 82 | verbose = args.verbose 83 | 84 | if verbose: 85 | print(model.summary()) 86 | start_time = time.time() 87 | model_reduced = reduce_keras_model(model, verbose=verbose) 88 | print("Reduction time: {:.2f} seconds".format(time.time() - start_time)) 89 | if verbose: 90 | print(model_reduced.summary()) 91 | print('Initial model number layers: 
{}'.format(len(model.layers))) 92 | print('Reduced model number layers: {}'.format(len(model_reduced.layers))) 93 | print('Compare models...') 94 | max_error = compare_two_models_results(model, model_reduced, test_number=1000, max_batch=1000) 95 | if max_error > 1e-04: 96 | print('A possible error just occurred! Max error value: {}'.format(max_error)) 97 | -------------------------------------------------------------------------------- /kito/test_bench.py: -------------------------------------------------------------------------------- 1 | """ 2 | Regression tests for KITO 3 | Author: Roman Solovyev (ZFTurbo), IPPM RAS: https://github.com/ZFTurbo/ 4 | """ 5 | 6 | from kito import * 7 | import time 8 | try: 9 | from keras.layers import Input, Conv2D, BatchNormalization, Activation, Concatenate, GlobalAveragePooling2D, Dense, \ 10 | Conv2DTranspose, Conv3D, Conv1D 11 | from keras.models import Model 12 | from keras.applications.mobilenet import MobileNet 13 | import keras.backend as K 14 | from keras.utils import custom_object_scope 15 | except: 16 | from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation, Concatenate, \ 17 | GlobalAveragePooling2D, Dense, Conv2DTranspose, Conv3D, Conv1D 18 | from tensorflow.keras.models import Model 19 | from tensorflow.keras.applications.mobilenet import MobileNet 20 | import tensorflow.keras.backend as K 21 | from tensorflow.keras.utils import custom_object_scope 22 | 23 | 24 | def compare_two_models_results(m1, m2, test_number=10000, max_batch=10000): 25 | input_shape1 = m1.input_shape 26 | input_shape2 = m2.input_shape 27 | if tuple(input_shape1) != tuple(input_shape2): 28 | print('Different input shapes for models {} vs {}'.format(input_shape1, input_shape2)) 29 | output_shape1 = m1.output_shape 30 | output_shape2 = m2.output_shape 31 | if tuple(output_shape1) != tuple(output_shape2): 32 | print('Different output shapes for models {} vs {}'.format(output_shape1, output_shape2)) 33 | print(input_shape1, 
input_shape2, output_shape1, output_shape2) 34 | 35 | t1 = 0 36 | t2 = 0 37 | max_error = 0 38 | avg_error = 0 39 | count = 0 40 | for i in range(0, test_number, max_batch): 41 | tst = min(test_number - i, max_batch) 42 | print('Generate random images {}...'.format(tst)) 43 | 44 | if type(input_shape1) is list: 45 | matrix = [] 46 | for i1 in input_shape1: 47 | matrix.append(np.random.uniform(0.0, 1.0, (tst,) + i1[1:])) 48 | else: 49 | # None shape fix 50 | inp_shape_fix = list(input_shape1) 51 | for i in range(1, len(inp_shape_fix)): 52 | if inp_shape_fix[i] is None: 53 | inp_shape_fix[i] = 224 54 | matrix = np.random.uniform(0.0, 1.0, (tst,) + tuple(inp_shape_fix[1:])) 55 | 56 | start_time = time.time() 57 | res1 = m1.predict(matrix) 58 | t1 += time.time() - start_time 59 | 60 | start_time = time.time() 61 | res2 = m2.predict(matrix) 62 | t2 += time.time() - start_time 63 | 64 | if type(res1) is list: 65 | for i1 in range(len(res1)): 66 | abs_diff = np.abs(res1[i1] - res2[i1]) 67 | max_error = max(max_error, abs_diff.max()) 68 | avg_error += abs_diff.sum() 69 | count += abs_diff.size 70 | else: 71 | abs_diff = np.abs(res1 - res2) 72 | max_error = max(max_error, abs_diff.max()) 73 | avg_error += abs_diff.sum() 74 | count += abs_diff.size 75 | 76 | print("Initial model prediction time for {} random images: {:.2f} seconds".format(test_number, t1)) 77 | print("Reduced model prediction time for {} same random images: {:.2f} seconds".format(test_number, t2)) 78 | print('Models raw max difference: {} (Avg difference: {})'.format(max_error, avg_error/count)) 79 | return max_error 80 | 81 | 82 | def get_custom_multi_io_model(): 83 | inp1 = Input((224, 224, 3)) 84 | inp2 = Input((224, 224, 3)) 85 | 86 | branch1 = Conv2D(32, (3, 3), kernel_initializer='random_uniform')(inp1) 87 | branch1 = BatchNormalization()(branch1) 88 | branch1 = Activation('relu')(branch1) 89 | 90 | branch2 = Conv2D(32, (3, 3), kernel_initializer='random_uniform')(inp2) 91 | branch2 = 
BatchNormalization()(branch2) 92 | branch2 = Activation('relu')(branch2) 93 | 94 | x = Concatenate(axis=-1, name='concat')([branch1, branch2]) 95 | 96 | branch3 = Conv2D(32, (3, 3), kernel_initializer='random_uniform')(x) 97 | branch3 = BatchNormalization()(branch3) 98 | branch3 = Activation('relu')(branch3) 99 | 100 | out1 = GlobalAveragePooling2D()(branch2) 101 | out1 = Dense(1, activation='sigmoid', name='fc1')(out1) 102 | 103 | out2 = GlobalAveragePooling2D()(branch3) 104 | out2 = Dense(1, activation='sigmoid', name='fc2')(out2) 105 | 106 | custom_model = Model(inputs=[inp1, inp2], outputs=[out1, out2]) 107 | return custom_model 108 | 109 | 110 | def get_simple_submodel(): 111 | inp = Input((28, 28, 4)) 112 | x = Conv2D(8, (3, 3), padding='same', kernel_initializer='random_uniform')(inp) 113 | x = BatchNormalization()(x) 114 | out = Activation('relu')(x) 115 | model = Model(inputs=inp, outputs=out) 116 | return model 117 | 118 | 119 | def get_custom_model_with_other_model_as_layer(): 120 | inp1 = Input((28, 28, 3)) 121 | branch1 = Conv2D(4, (3, 3), padding='same', kernel_initializer='random_uniform')(inp1) 122 | branch1 = BatchNormalization()(branch1) 123 | branch1 = Activation('relu')(branch1) 124 | 125 | branch2 = Conv2D(4, (3, 3), padding='same', kernel_initializer='random_uniform')(inp1) 126 | branch2 = BatchNormalization()(branch2) 127 | branch2 = Activation('relu')(branch2) 128 | m1 = get_simple_submodel() 129 | m2 = get_simple_submodel() 130 | x1 = m1(branch1) 131 | x2 = m2(branch2) 132 | x = Concatenate(axis=-1, name='concat')([x1, x2]) 133 | x = Conv2D(32, (3, 3), padding='same', kernel_initializer='random_uniform')(x) 134 | custom_model = Model(inputs=inp1, outputs=x) 135 | return custom_model 136 | 137 | 138 | def get_small_model_with_other_model_as_layer(): 139 | inp_mask = Input(shape=(128, 128, 3)) 140 | pretrain_model_mask = MobileNet(input_shape=(128, 128, 3), include_top=False, weights='imagenet', pooling='avg') 141 | try: 142 | 
pretrain_model_mask.name = 'mobilenet' 143 | except: 144 | pretrain_model_mask._name = 'mobilenet' 145 | x = pretrain_model_mask(inp_mask) 146 | out = Dense(2, activation='sigmoid')(x) 147 | model = Model(inputs=inp_mask, outputs=[out]) 148 | return model 149 | 150 | 151 | def get_Conv2DTranspose_model(): 152 | inp = Input((28, 28, 4)) 153 | x = Conv2DTranspose(8, (3, 3), padding='same', kernel_initializer='random_uniform')(inp) 154 | x = BatchNormalization()(x) 155 | x = Conv2DTranspose(4, (3, 3), strides=(4, 4), padding='same', kernel_initializer='random_uniform')(x) 156 | x = BatchNormalization()(x) 157 | out = Activation('relu')(x) 158 | model = Model(inputs=inp, outputs=out) 159 | return model 160 | 161 | 162 | def get_RetinaNet_model(): 163 | from keras.utils import custom_object_scope 164 | from keras_resnet.layers import BatchNormalization 165 | from keras_retinanet.layers import UpsampleLike, Anchors, RegressBoxes, ClipBoxes, FilterDetections 166 | from keras_retinanet.initializers import PriorProbability 167 | from keras_retinanet import models 168 | from keras_retinanet.models.retinanet import retinanet_bbox 169 | 170 | custom_objects = { 171 | 'BatchNormalization': BatchNormalization, 172 | 'UpsampleLike': UpsampleLike, 173 | 'Anchors': Anchors, 174 | 'RegressBoxes': RegressBoxes, 175 | 'PriorProbability': PriorProbability, 176 | 'ClipBoxes': ClipBoxes, 177 | 'FilterDetections': FilterDetections, 178 | } 179 | 180 | with custom_object_scope(custom_objects): 181 | backbone = models.backbone('resnet50') 182 | model = backbone.retinanet(500) 183 | prediction_model = retinanet_bbox(model=model) 184 | # prediction_model.load_weights("...your weights here...") 185 | 186 | return prediction_model, custom_objects 187 | 188 | 189 | def get_simple_3d_model(): 190 | inp = Input((28, 28, 28, 4)) 191 | x = Conv3D(32, (3, 3, 3), padding='same', kernel_initializer='random_uniform')(inp) 192 | x = BatchNormalization()(x) 193 | x = Activation('relu')(x) 194 | x = 
Conv3D(32, (3, 3, 3), padding='same', kernel_initializer='random_uniform')(x) 195 | x = BatchNormalization()(x) 196 | out = Activation('relu')(x) 197 | model = Model(inputs=inp, outputs=out) 198 | return model 199 | 200 | 201 | def get_simple_1d_model(): 202 | inp = Input((256, 2)) 203 | x = Conv1D(32, 3, padding='same', kernel_initializer='random_uniform')(inp) 204 | x = BatchNormalization()(x) 205 | x = Activation('relu')(x) 206 | x = Conv1D(32, 3, padding='same', kernel_initializer='random_uniform')(x) 207 | x = BatchNormalization()(x) 208 | out = Activation('relu')(x) 209 | model = Model(inputs=inp, outputs=out) 210 | return model 211 | 212 | 213 | def get_tst_neural_net(type): 214 | model = None 215 | custom_objects = dict() 216 | if type == 'mobilenet_small': 217 | try: 218 | from keras.applications.mobilenet import MobileNet 219 | except: 220 | from tensorflow.keras.applications.mobilenet import MobileNet 221 | model = MobileNet((128, 128, 3), depth_multiplier=1, alpha=0.25, include_top=True, weights='imagenet') 222 | elif type == 'mobilenet': 223 | try: 224 | from keras.applications.mobilenet import MobileNet 225 | except: 226 | from tensorflow.keras.applications.mobilenet import MobileNet 227 | model = MobileNet((224, 224, 3), depth_multiplier=1, alpha=1.0, include_top=True, weights='imagenet') 228 | elif type == 'mobilenet_v2': 229 | try: 230 | from keras.applications.mobilenet_v2 import MobileNetV2 231 | except: 232 | from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 233 | model = MobileNetV2((224, 224, 3), alpha=1.4, include_top=True, weights='imagenet') 234 | elif type == 'resnet50': 235 | try: 236 | from keras.applications.resnet50 import ResNet50 237 | except: 238 | from tensorflow.keras.applications.resnet50 import ResNet50 239 | model = ResNet50(input_shape=(224, 224, 3), include_top=True, weights='imagenet') 240 | elif type == 'inception_v3': 241 | try: 242 | from keras.applications.inception_v3 import InceptionV3 243 | except: 
244 | from tensorflow.keras.applications.inception_v3 import InceptionV3 245 | model = InceptionV3(input_shape=(299, 299, 3), include_top=True, weights='imagenet') 246 | elif type == 'inception_resnet_v2': 247 | try: 248 | from keras.applications.inception_resnet_v2 import InceptionResNetV2 249 | except: 250 | from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2 251 | model = InceptionResNetV2(input_shape=(299, 299, 3), include_top=True, weights='imagenet') 252 | elif type == 'xception': 253 | try: 254 | from keras.applications.xception import Xception 255 | except: 256 | from tensorflow.keras.applications.xception import Xception 257 | model = Xception(input_shape=(299, 299, 3), include_top=True, weights='imagenet') 258 | elif type == 'densenet121': 259 | try: 260 | from keras.applications.densenet import DenseNet121 261 | except: 262 | from tensorflow.keras.applications.densenet import DenseNet121 263 | model = DenseNet121(input_shape=(224, 224, 3), include_top=True, weights='imagenet') 264 | elif type == 'densenet169': 265 | try: 266 | from keras.applications.densenet import DenseNet169 267 | except: 268 | from tensorflow.keras.applications.densenet import DenseNet169 269 | model = DenseNet169(input_shape=(224, 224, 3), include_top=True, weights='imagenet') 270 | elif type == 'densenet201': 271 | try: 272 | from keras.applications.densenet import DenseNet201 273 | except: 274 | from tensorflow.keras.applications.densenet import DenseNet201 275 | model = DenseNet201(input_shape=(224, 224, 3), include_top=True, weights='imagenet') 276 | elif type == 'nasnetmobile': 277 | try: 278 | from keras.applications.nasnet import NASNetMobile 279 | except: 280 | from tensorflow.keras.applications.nasnet import NASNetMobile 281 | model = NASNetMobile(input_shape=(224, 224, 3), include_top=True, weights='imagenet') 282 | elif type == 'nasnetlarge': 283 | try: 284 | from keras.applications.nasnet import NASNetLarge 285 | except: 286 | from 
tensorflow.keras.applications.nasnet import NASNetLarge 287 | model = NASNetLarge(input_shape=(331, 331, 3), include_top=True, weights='imagenet') 288 | elif type == 'vgg16': 289 | try: 290 | from keras.applications.vgg16 import VGG16 291 | except: 292 | from tensorflow.keras.applications.vgg16 import VGG16 293 | model = VGG16(input_shape=(224, 224, 3), include_top=False, pooling='avg', weights='imagenet') 294 | elif type == 'vgg19': 295 | try: 296 | from keras.applications.vgg19 import VGG19 297 | except: 298 | from tensorflow.keras.applications.vgg19 import VGG19 299 | model = VGG19(input_shape=(224, 224, 3), include_top=False, pooling='avg', weights='imagenet') 300 | elif type == 'multi_io': 301 | model = get_custom_multi_io_model() 302 | elif type == 'multi_model_layer_1': 303 | model = get_custom_model_with_other_model_as_layer() 304 | elif type == 'multi_model_layer_2': 305 | model = get_small_model_with_other_model_as_layer() 306 | elif type == 'Conv2DTranspose': 307 | model = get_Conv2DTranspose_model() 308 | elif type == 'RetinaNet': 309 | model, custom_objects = get_RetinaNet_model() 310 | elif type == 'conv3d_model': 311 | model = get_simple_3d_model() 312 | elif type == 'conv1d_model': 313 | model = get_simple_1d_model() 314 | return model, custom_objects 315 | 316 | 317 | if __name__ == '__main__': 318 | models_to_test = ['mobilenet_small', 'mobilenet', 'mobilenet_v2', 'resnet50', 'inception_v3', 319 | 'inception_resnet_v2', 'xception', 'densenet121', 'densenet169', 'densenet201', 320 | 'nasnetmobile', 'nasnetlarge', 'multi_io', 'multi_model_layer_1', 'multi_model_layer_2', 321 | 'Conv2DTranspose', 'RetinaNet', 'conv3d_model', 'conv1d_model'] 322 | # Comment line below for full model testing 323 | models_to_test = ['conv1d_model'] 324 | verbose = True 325 | 326 | for model_name in models_to_test: 327 | print('Go for: {}'.format(model_name)) 328 | model, custom_objects = get_tst_neural_net(model_name) 329 | if verbose: 330 | print(model.summary()) 331 | 
start_time = time.time() 332 | with custom_object_scope(custom_objects): 333 | model_reduced = reduce_keras_model(model, verbose=verbose) 334 | print("Reduction time: {:.2f} seconds".format(time.time() - start_time)) 335 | if verbose: 336 | print(model_reduced.summary()) 337 | print('Initial model number of layers: {}'.format(len(model.layers))) 338 | print('Reduced model number of layers: {}'.format(len(model_reduced.layers))) 339 | print('Compare models...') 340 | if model_name in ['nasnetlarge', 'deeplab_v3plus_mobile', 'deeplab_v3plus_xception']: 341 | max_error = compare_two_models_results(model, model_reduced, test_number=10000, max_batch=128) 342 | elif model_name in ['RetinaNet', 'conv3d_model', 'conv1d_model']: 343 | max_error = compare_two_models_results(model, model_reduced, test_number=1280, max_batch=128) 344 | elif model_name in ['mobilenet_small']: 345 | max_error = compare_two_models_results(model, model_reduced, test_number=1000, max_batch=1000) 346 | else: 347 | max_error = compare_two_models_results(model, model_reduced, test_number=10000, max_batch=10000) 348 | K.clear_session() 349 | if max_error > 1e-04: 350 | print('Possible error detected! Max error value: {}'.format(max_error)) 351 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | try: 2 | from setuptools import setup 3 | except ImportError: 4 | from distutils.core import setup 5 | 6 | setup( 7 | name='kito', 8 | version='1.0.5', 9 | author='Roman Sol (ZFTurbo)', 10 | packages=['kito', ], 11 | url='https://github.com/ZFTurbo/Keras-inference-time-optimizer', 12 | license='MIT License', 13 | description='Keras inference time optimizer', 14 | long_description='This code takes a trained Keras model as input and optimizes its layer structure and weights ' 15 | 'in such a way that the model becomes much faster (~10-30%), but works identically to ' 16 | 'the initial model.
It can be extremely useful when you need to process a large number ' 17 | 'of images with a trained model. The reduce operation was tested on all models in the Keras model zoo. ' 18 | 'More details: https://github.com/ZFTurbo/Keras-inference-time-optimizer', 19 | install_requires=[ 20 | 'keras', 21 | "numpy", 22 | ], 23 | ) 24 | --------------------------------------------------------------------------------
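The test bench in kito/test_bench.py flags a reduction as suspicious when the maximum absolute difference between the two models' predictions exceeds 1e-04. The following is a minimal NumPy sketch of why a layer followed by BatchNormalization can be merged with an error far below that threshold. It is not KITO's actual code: `fold_batchnorm` is a hypothetical helper, a plain matrix product stands in for the convolution, and the standard inference-time batch-normalization formula is assumed (with `eps=1e-3`, the Keras default).

```python
import numpy as np


def fold_batchnorm(kernel, bias, gamma, beta, mean, var, eps=1e-3):
    """Fold inference-time BatchNormalization into the preceding layer.

    Hypothetical helper for illustration only. `kernel` is expected to
    have its output channels on the last axis (Keras layout).
    """
    scale = gamma / np.sqrt(var + eps)
    # Scaling broadcasts over the last (output channel) axis of the kernel.
    return kernel * scale, (bias - mean) * scale + beta


rng = np.random.default_rng(0)
x = rng.uniform(0.0, 1.0, (16, 8))       # batch of 16 inputs, 8 features
kernel = rng.normal(size=(8, 4))          # 4 output channels
bias = rng.normal(size=4)
gamma = rng.uniform(0.5, 1.5, 4)          # BN scale
beta = rng.normal(size=4)                 # BN offset
mean = rng.normal(size=4)                 # BN moving mean
var = rng.uniform(0.5, 1.5, 4)            # BN moving variance
eps = 1e-3

# Original path: linear layer followed by batch normalization.
res1 = gamma * ((x @ kernel + bias) - mean) / np.sqrt(var + eps) + beta

# Reduced path: a single linear layer with folded weights.
fused_kernel, fused_bias = fold_batchnorm(kernel, bias, gamma, beta, mean, var, eps)
res2 = x @ fused_kernel + fused_bias

# Same error measure as the test bench: maximum absolute difference.
max_error = np.abs(res1 - res2).max()
print('Max difference: {}'.format(max_error))
```

Because the folding is an exact algebraic rewrite, the remaining difference comes only from floating-point rounding, which is why a fixed tolerance such as 1e-04 is a reasonable check for float32 models.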