├── .DS_Store ├── assets └── fig1.png ├── .idea ├── misc.xml ├── vcs.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml ├── Eagleeye_Tensorflow.iml └── workspace.xml ├── network ├── __init__.py ├── utils.py ├── vgg.py ├── densenet.py ├── resnet.py └── mobilenetv2.py ├── requirements.txt ├── core ├── pruner │ ├── l1norm.py │ └── graph_wrapper.py ├── utils.py ├── strategy_generation.py └── eagleeye.py ├── main.py ├── train_network.py └── README.md /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da2so/Eagleeye_Tensorflow/HEAD/.DS_Store -------------------------------------------------------------------------------- /assets/fig1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/da2so/Eagleeye_Tensorflow/HEAD/assets/fig1.png -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | from network.mobilnetv2 import mobilenetv2 2 | from network.vgg import vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn 3 | from network.resnet import resnet18, resnet34, resnet50,resnet101 ,resnet152 4 | from network.densenet import densenet40_f12, densenet100_f12, densenet100_f24 5 | 6 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.5 2 | tensorflow-gpu==2.3.0 3 | absl-py==0.11.0 4 | astor==0.8.1 5 | astunparse==1.6.3 6 | chardet==3.0.4 7 | h5py==2.10.0 8 | importlib-metadata==2.0.0 9 | Keras-Applications==1.0.8 10 | Keras-Preprocessing==1.1.2 11 | matplotlib==3.3.2 12 | pandas==1.1.4 13 | Pillow==8.0.1 14 | pydot==1.4.1 15 | tensorflow-estimator==2.3.0 16 | six==1.15.0 17 | argparse==1.4.0 18 | -------------------------------------------------------------------------------- /.idea/Eagleeye_Tensorflow.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | 13 | 15 | -------------------------------------------------------------------------------- /core/pruner/l1norm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from tensorflow.keras import layers 4 | import tensorflow as tf 5 | 6 | from core.pruner.graph_wrapper import GraphWrapper 7 | 8 | 9 | def l1_pruning(model, channel_config): 10 | del_layer_dict={} 11 | 12 | for idx, prune_rate in channel_config.items(): 13 | w_copy=model.layers[idx].get_weights()[0] 14 | 15 | #l1 norm for weights of a layer 16 | w_copy = tf.math.abs(w_copy) 17 | w_copy = tf.math.reduce_sum(w_copy, axis=[0,1,2]) 18 | 19 | num_delete = int(np.shape(w_copy)[0]*(prune_rate)) 20 | del_list = tf.math.top_k(tf.reshape(-w_copy, [-1]), num_delete)[1].numpy() 21 | 22 | del_layer_dict[idx]=del_list 23 | 24 | #Reconstruct the base model uisng graph wrapper 25 | graph_obj=GraphWrapper(model) 26 | pruned_model=graph_obj.build(del_layer_dict) 27 | 28 | return pruned_model 29 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 10 | 11 | 13 | 14 | 15 | 16 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 1616838752505 31 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /network/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import tensorflow as tf 4 | from tensorflow.keras.datasets import cifar10, cifar100 5 | from tensorflow.keras.utils import to_categorical 6 | 7 | import network 8 | 9 | 10 | def load_dataset(dataset_name, batch_size): 11 | 12 | if dataset_name =='cifar10': 13 | (x_train, y_train), (x_test, y_test) = cifar10.load_data() 14 | num_classes=10 15 | img_shape=[32,32,3] 16 | 17 | elif dataset_name == 'cifar100': 18 | (x_train, y_train), (x_test, y_test) = cifar100.load_data() 19 | num_classes=100 20 | img_shape=[32,32,3] 21 | elif dataset_name == 'imagenet': 22 | raise ValueError('Not yet implemented') 23 | else: 24 | raise ValueError('Invalid dataset name : {}'.format(dataset_name)) 25 | 26 | img_shape=[32,32,3] 27 | normalize = [ [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]] 28 | 29 | x_train = x_train.astype('float32') 30 | x_test = x_test.astype('float32') 31 | 32 | x_train /= 255. 33 | x_test /= 255. 34 | 35 | mean= normalize[0] 36 | std= normalize[1] 37 | 38 | for i in range(3): 39 | x_train[:,:,:,i] = (x_train[:,:,:,i] - mean[i]) / std[i] 40 | x_test[:,:,:,i] = (x_test[:,:,:,i] - mean[i]) / std[i] 41 | 42 | 43 | y_train = to_categorical(y_train) 44 | y_test = to_categorical(y_test) 45 | 46 | return x_train ,y_train , x_test, y_test,num_classes,img_shape 47 | 48 | def load_model_arch(model_name,img_shape,num_classes): 49 | 50 | try: 51 | model_func = getattr(network, model_name) 52 | except: 53 | raise ValueError('Invalid model name : {}'.format(model_name)) 54 | 55 | model = model_func(img_shape,num_classes) 56 | return model 57 | def lr_scheduler(epoch): 58 | lr = 1e-3 59 | if epoch >180: 60 | lr*=0.001 61 | elif epoch >120: 62 | lr*=0.01 63 | elif epoch >80: 64 | lr*=0.1 65 | 66 | return lr -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | #from keras_flops import get_flops 2 | 3 | import tensorflow as tf 4 | from tensorflow.keras.applications.resnet50 import ResNet50 5 | from tensorflow.keras.applications.vgg19 import VGG19 6 | from tensorflow.keras.applications.resnet_v2 import ResNet50V2 7 | from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2 8 | from tensorflow.keras.applications.mobilenet import MobileNet 9 | from tensorflow.keras.datasets import cifar10, cifar100 10 | from tensorflow.python.framework.convert_to_constants import ( 11 | convert_variables_to_constants_v2_as_graph, 12 | ) 13 | from tensorflow.keras import Sequential, Model 14 | 15 | def load_network(model_path): 16 | if '.h5' in model_path: 17 | model = tf.keras.models.load_model(model_path) 18 | else: 19 | if 'resnet50v2' in model_path: 20 | model = ResNet50V2(weights='imagenet') 21 | elif 'mobilenet' in model_path: 22 | model = MobileNet(weights='imagenet') 23 | elif 'mobilenetv2' in model_path: 24 | model = MobileNetV2(weights='imagenet') 25 | elif 'vgg19' in model_path: 26 | model = VGG19(weights='imagenet') 27 | else: 28 | raise ValueError('Invalid model name : {}'.format(model_path)) 29 | return model 30 | 31 | 32 | def get_flops(model, batch_size): 33 | 34 | if not isinstance(model, (Sequential, Model)): 35 | raise KeyError( 36 | "model arguments must be tf.keras.Model or tf.keras.Sequential instanse" 37 | ) 38 | 39 | if batch_size is None: 40 | batch_size = 1 41 | 42 | inputs = [ 43 | tf.TensorSpec([batch_size] + inp.shape[1:], inp.dtype) for inp in model.inputs 44 | ] 45 | real_model = tf.function(model).get_concrete_function(inputs) 46 | frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model) 47 | 48 | # Calculate FLOPS with tf.profiler 49 | run_meta = tf.compat.v1.RunMetadata() 50 | opts = (tf.compat.v1.profiler.ProfileOptionBuilder() 51 | .with_empty_output() 52 | .build()) 53 | flops = tf.compat.v1.profiler.profile( 54 | graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts 55 | ) 56 | 57 | return flops.total_float_ops 58 | 59 | def count_flops(model): 60 | return get_flops(model, batch_size=1) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 4 | 5 | import tensorflow as tf 6 | 7 | from core.eagleeye import EagleEye 8 | 9 | def main(): 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['cifar10','cifar100']) 13 | parser.add_argument('--model_path', type=str, default='./saved_models/cifar10_resnet34.h5', help = 'model path') 14 | parser.add_argument('--bs', type=int, default=128, help = 'batch size') 15 | parser.add_argument('--epochs', type=int, default=100, help='epoch while fine-tuning') 16 | parser.add_argument('--lr', type=float, default=0.001 , help='learning rate') 17 | parser.add_argument('--min_rate', type=float, default=0.0, help='define minimum of search space') 18 | parser.add_argument('--max_rate', type=float, default=0.5, help='define maximum of search space') 19 | parser.add_argument('--flops_target', type=float, default=0.5, help='flops constraints for pruning') 20 | parser.add_argument('--num_candidates', type=int, default=15, help='the number of candidates') 21 | parser.add_argument('--data_augmentation', type=bool, default=True, help='do data augmentation?') 22 | parser.add_argument('--result_dir', type=str, default='./result', help='result directory for a prunned model') 23 | 24 | 25 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 26 | os.environ['TF2_BEHAVIOR'] = '1' 27 | os.environ['CUDA_VISIBLE_DEVICES']= '0' 28 | 29 | config = tf.compat.v1.ConfigProto() 30 | config.gpu_options.per_process_gpu_memory_fraction = 0.8 31 | tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config)) 32 | 33 | args = parser.parse_args() 34 | 35 | eagleeye_obj=EagleEye(dataset_name =args.dataset_name, 36 | model_path=args.model_path, 37 | bs=args.bs, 38 | epochs=args.epochs, 39 | lr=args.lr, 40 | min_rate=args.min_rate, 41 | max_rate=args.max_rate, 42 | flops_target=args.flops_target, 43 | num_candidates=args.num_candidates, 44 | result_dir=args.result_dir, 45 | data_augmentation=args.data_augmentation 46 | ) 47 | 48 | eagleeye_obj.build() 49 | 50 | if __name__ == '__main__': 51 | main() -------------------------------------------------------------------------------- /network/vgg.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import backend 3 | from tensorflow.keras import layers 4 | from tensorflow.keras.models import Model 5 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, ReLU, AveragePooling2D ,Add, ZeroPadding2D 6 | from tensorflow.keras.layers import Input, Flatten, BatchNormalization, Activation 7 | 8 | 9 | class VGG(object): 10 | def build(input_shape, num_classes, cfg, batch_norm): 11 | rows = input_shape[0] 12 | cols = input_shape[1] 13 | 14 | img_input = layers.Input(shape=input_shape) 15 | 16 | x=img_input 17 | 18 | for v in cfg: 19 | if v == 'M': 20 | x = layers.MaxPooling2D((2, 2), strides=(2, 2))(x) 21 | else: 22 | x = layers.Conv2D(v, (3, 3), 23 | padding='same')(x) 24 | if batch_norm == True: 25 | x = layers.BatchNormalization(axis=-1)(x) 26 | x = Activation("relu")(x) 27 | 28 | 29 | x = layers.Flatten(name='flatten')(x) 30 | x = layers.Dense(4096, activation='relu', name='fc1')(x) 31 | x = layers.Dense(4096, activation='relu', name='fc2')(x) 32 | x = layers.Dense(num_classes, activation='softmax', name='predictions')(x) 33 | 34 | model = Model(img_input, x) 35 | 36 | return model 37 | 38 | cfgs = { 39 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 40 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 41 | 'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 42 | } 43 | 44 | def vgg11(input_shape, classes): 45 | return VGG.build(input_shape=input_shape, num_classes=classes, cfg=cfgs['A'], batch_norm=False) 46 | 47 | def vgg11_bn(input_shape, classes): 48 | return VGG.build(input_shape=input_shape, num_classes=classes, cfg=cfgs['A'], batch_norm=True) 49 | 50 | def vgg13(input_shape, classes): 51 | return VGG.build(input_shape=input_shape, num_classes=classes, cfg=cfgs['B'], batch_norm=False) 52 | 53 | def vgg13_bn(input_shape, classes): 54 | return VGG.build(input_shape=input_shape, num_classes=classes, cfg=cfgs['B'], batch_norm=True) 55 | 56 | def vgg16(input_shape, classes): 57 | return VGG.build(input_shape=input_shape, num_classes=classes, cfg=cfgs['C'], batch_norm=False) 58 | 59 | def vgg16_bn(input_shape, classes): 60 | return VGG.build(input_shape=input_shape, num_classes=classes, cfg=cfgs['C'], batch_norm=True) 61 | 62 | 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /core/strategy_generation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | from collections import defaultdict 4 | 5 | from tensorflow.python.keras import layers 6 | 7 | def get_strartegy_generation(model, min_rate, max_rate): 8 | 9 | np.random.seed(int(time.time())) 10 | channel_config = defaultdict() 11 | layer_info = defaultdict() 12 | 13 | #Random strategy 14 | for i, layer in enumerate(model.layers): 15 | 16 | layer_info[id(layer.output)] = {'layer_input':layer.input, 'idx':i} 17 | if isinstance(layer, layers.convolutional.Conv2D): 18 | random_rate = (max_rate - min_rate) * (np.random.rand(1)) + min_rate 19 | layer_info[id(layer.output)].update( {'prune_rate':random_rate}) 20 | channel_config[i] = random_rate 21 | 22 | 23 | #In case of some architectures (such as resnet) 24 | for i, layer in enumerate(model.layers): 25 | 26 | if isinstance(layer, layers.Add) or \ 27 | isinstance(layer, layers.Subtract) or \ 28 | isinstance(layer, layers.Multiply) or \ 29 | isinstance(layer, layers.Average) or \ 30 | isinstance(layer, layers.Maximum) or \ 31 | isinstance(layer, layers.Minimum): 32 | 33 | idx_list = list() 34 | prune_list = list() 35 | is_specific_layer = -1 36 | 37 | for idx, in_layer in enumerate(layer.input): 38 | #Find conv layer or specific layers (i.e. Add, Substract, ...) 39 | while 'prune_rate' not in layer_info[id(in_layer)]: 40 | in_layer = layer_info[id(in_layer)]['layer_input'] 41 | 42 | #Find specific layers that mentioned in abvoe comments 43 | if layer_info[id(in_layer)]['idx'] not in channel_config: 44 | is_specific_layer = idx 45 | 46 | idx_list.append(layer_info[id(in_layer)]['idx']) 47 | prune_list.append(layer_info[id(in_layer)]['prune_rate']) 48 | 49 | #if one of the input layers is a specific layer, change the minimum rate of pruning 50 | if is_specific_layer != -1: 51 | min_prune_rate = prune_list[is_specific_layer] 52 | else: 53 | min_prune_rate = min(prune_list) 54 | 55 | #Rearrange the prune rate for conv layers 56 | for idx in idx_list: 57 | if isinstance(model.layers[idx], layers.convolutional.Conv2D): 58 | channel_config[idx] = min_prune_rate 59 | layer_info[id(layer.output)]['prune_rate'] = min_prune_rate 60 | 61 | return channel_config 62 | -------------------------------------------------------------------------------- /network/densenet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | 7 | import tensorflow as tf 8 | from tensorflow.keras import backend 9 | from tensorflow.keras import layers 10 | from tensorflow.keras.models import Model 11 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, ReLU, AveragePooling2D ,Add, ZeroPadding2D 12 | from tensorflow.keras.layers import Input, Flatten, BatchNormalization, Activation 13 | 14 | 15 | def dense_block(x, filter_num, blocks, name): 16 | 17 | for i in range(blocks): 18 | x = conv_block(x, filter_num, name=name + '_block' + str(i + 1)) 19 | return x 20 | 21 | 22 | def transition_block(x, reduction, name): 23 | 24 | bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 25 | x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, 26 | name=name + '_bn')(x) 27 | x = layers.Activation('relu', name=name + '_relu')(x) 28 | x = layers.Conv2D(int(backend.int_shape(x)[bn_axis] * reduction), 1, 29 | use_bias=False, 30 | padding='same', 31 | name=name + '_conv')(x) 32 | x = layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x) 33 | return x 34 | 35 | 36 | def conv_block(x, filter_num, name): 37 | 38 | bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 39 | x1 = layers.BatchNormalization(axis=bn_axis, 40 | epsilon=1.001e-5, 41 | name=name + '_0_bn')(x) 42 | x1 = layers.Activation('relu', name=name + '_0_relu')(x1) 43 | x1 = layers.Conv2D(filter_num, 3, 44 | use_bias=False, 45 | padding='same', 46 | name=name + '_1_conv')(x1) 47 | x = layers.Concatenate(axis=bn_axis, name=name + '_concat')([x, x1]) 48 | return x 49 | 50 | 51 | def DenseNet(blocks, 52 | filter_num=None, 53 | input_shape=None, 54 | reduce=None, 55 | classes=10, 56 | **kwargs): 57 | 58 | 59 | img_input = layers.Input(shape=input_shape) 60 | 61 | 62 | bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 63 | 64 | x = layers.Conv2D(filter_num, 3, use_bias=False, padding='same',name='conv1/conv')(img_input) 65 | 66 | x = dense_block(x, filter_num, blocks[0], name='conv2') 67 | x = transition_block(x, reduce, name='pool2') 68 | x = dense_block(x, filter_num, blocks[1], name='conv3') 69 | x = transition_block(x, reduce, name='pool3') 70 | x = dense_block(x, filter_num, blocks[2], name='conv4') 71 | 72 | 73 | x = layers.BatchNormalization( 74 | axis=bn_axis, epsilon=1.001e-5, name='bn')(x) 75 | x = layers.Activation('relu', name='relu')(x) 76 | 77 | x = layers.GlobalAveragePooling2D(name='avg_pool')(x) 78 | x = layers.Dense(classes, activation='softmax', name='fc1000')(x) 79 | 80 | 81 | # Create model. 82 | if blocks == [12, 12, 12] and filter_num == 12: 83 | model = Model(img_input, x, name='densenet40_f12') 84 | elif blocks == [32, 32, 32] and filter_num == 12 : 85 | model = Model(img_input, x, name='densenet100_f12') 86 | elif blocks == [32, 32, 32] and filter_num == 24: 87 | model = Model(img_input, x, name='densenet100_f24') 88 | else: 89 | model = Model(img_input, x, name='densenet') 90 | 91 | model.summary() 92 | return model 93 | 94 | 95 | def densenet40_f12(input_shape=None,classes=10): 96 | return DenseNet(blocks = [12, 12, 12], 97 | input_shape = input_shape, 98 | filter_num = 12, 99 | reduce = 1.0, 100 | classes = classes 101 | ) 102 | 103 | def densenet100_f12(input_shape=None,classes=10): 104 | return DenseNet(blocks = [32, 32, 32], 105 | input_shape = input_shape, 106 | filter_num = 12, 107 | reduce = 0.5, 108 | classes = classes 109 | ) 110 | def densenet100_f24(input_shape = None,classes=10): 111 | return DenseNet(blocks = [32, 32, 32], 112 | input_shape = input_shape, 113 | filter_num = 24, 114 | reduce = 0.5, 115 | classes = classes 116 | ) 117 | 118 | -------------------------------------------------------------------------------- /network/resnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import six 3 | 4 | import tensorflow as tf 5 | from tensorflow.keras import backend 6 | from tensorflow.keras import layers 7 | from tensorflow.keras.models import Model 8 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, ReLU, AveragePooling2D ,Add, ZeroPadding2D 9 | from tensorflow.keras.layers import Input, Flatten, BatchNormalization, Activation 10 | from tensorflow.keras.regularizers import l2 11 | 12 | global ROW_AXIS 13 | global COL_AXIS 14 | global CHANNEL_AXIS 15 | 16 | ROW_AXIS=1 17 | COL_AXIS=2 18 | CHANNEL_AXIS=3 19 | 20 | def _bn_relu(input): 21 | """ 22 | BN -> relu block 23 | """ 24 | norm = BatchNormalization()(input) 25 | return Activation("relu")(norm) 26 | 27 | def _shortcut(input, residual,init_strides): 28 | """ 29 | Adds a shortcut between input and residual block and merges them with "sum" 30 | """ 31 | 32 | input_shape = np.shape(input) 33 | residual_shape = np.shape(residual) 34 | equal_channels = input_shape[CHANNEL_AXIS] == residual_shape[CHANNEL_AXIS] 35 | 36 | shortcut = input 37 | # 1 X 1 conv if shape is different. Else identity. 38 | if init_strides != 1 or not equal_channels: 39 | shortcut = Conv2D(filters=residual_shape[CHANNEL_AXIS], 40 | kernel_size=1, 41 | strides=init_strides, 42 | padding="same", kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(input) 43 | 44 | 45 | out=Add()([shortcut, residual]) 46 | 47 | return Activation("relu")(out) 48 | 49 | def basic_block(filters, init_strides=1, is_first_block_of_first_layer=False): 50 | """ 51 | Basic 3 X 3 convolution blocks for use on resnets with layers <= 34. 52 | """ 53 | def f(input): 54 | 55 | out = Conv2D(filters=filters, kernel_size=3, strides=init_strides,padding="same",kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(input) 56 | out= _bn_relu(out) 57 | 58 | out = Conv2D(filters=filters, kernel_size=3, strides=1,padding="same", kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(out) 59 | residual = BatchNormalization()(out) 60 | 61 | return _shortcut(input, residual,init_strides) 62 | 63 | return f 64 | 65 | 66 | def bottleneck(filters, init_strides=1, is_first_block_of_first_layer=False): 67 | """ 68 | Bottleneck architecture for > 34 layer resnet. 69 | A final conv layer of filters * 4 70 | """ 71 | def f(input): 72 | 73 | out = Conv2D(filters=filters, kernel_size=1, strides=1,padding="same", kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(input) 74 | out = _bn_relu(out) 75 | 76 | out = Conv2D(filters=filters, kernel_size=3, strides=init_strides,padding="same", kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(out) 77 | out = _bn_relu(out) 78 | 79 | out = Conv2D(filters=filters*4, kernel_size=1, strides=1,padding="same", kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(out) 80 | residual = BatchNormalization()(out) 81 | 82 | return _shortcut(input, residual,init_strides) 83 | 84 | return f 85 | 86 | 87 | def _residual_block(block_function, filters, repetitions, is_first_layer=False): 88 | """ 89 | Builds a residual block with repeating bottleneck blocks. 90 | """ 91 | def f(input): 92 | for i in range(repetitions): 93 | 94 | init_strides = 1 95 | if i == 0 and not is_first_layer: 96 | init_strides = 2 97 | input = block_function(filters=filters, init_strides=init_strides, 98 | is_first_block_of_first_layer=(is_first_layer and i == 0))(input) 99 | return input 100 | 101 | return f 102 | 103 | 104 | def _get_block(identifier): 105 | if isinstance(identifier, six.string_types): 106 | res = globals().get(identifier) 107 | if not res: 108 | raise ValueError('Invalid {}'.format(identifier)) 109 | return res 110 | return identifier 111 | 112 | 113 | class ResNet(object): 114 | def build(input_shape, num_outputs, block_fn, repetitions): 115 | 116 | # Load block function 117 | block_fn = _get_block(block_fn) 118 | 119 | filters_num = 16 120 | 121 | input = Input(shape=input_shape) 122 | conv1 = Conv2D(filters=filters_num, kernel_size=3, strides=1, padding='same', kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(input) 123 | norm1 = _bn_relu(conv1) 124 | 125 | block= norm1 126 | for i, r in enumerate(repetitions): 127 | 128 | block = _residual_block(block_fn, filters=filters_num, repetitions=r, is_first_layer=(i == 0))(block) 129 | filters_num*= 2 130 | 131 | block= _bn_relu(block) 132 | block_shape = np.shape(block) 133 | pool2 = AveragePooling2D(pool_size=block_shape[ROW_AXIS])(block) 134 | flatten1 = Flatten()(pool2) 135 | dense = Dense(units=num_outputs,kernel_initializer='he_normal')(flatten1) 136 | out = Activation("softmax")(dense) 137 | model = Model(inputs=input, outputs=out) 138 | return model 139 | 140 | 141 | 142 | def resnet18(input_shape, num_outputs): 143 | return ResNet.build(input_shape, num_outputs, basic_block, [2, 2, 2, 2]) 144 | 145 | def resnet34(input_shape, num_outputs): 146 | return ResNet.build(input_shape, num_outputs, basic_block, [3, 4, 6, 3]) 147 | 148 | def resnet50(input_shape, num_outputs): 149 | return ResNet.build(input_shape, num_outputs, bottleneck, [3, 4, 6, 3]) 150 | 151 | def resnet101(input_shape, num_outputs): 152 | return ResNet.build(input_shape, num_outputs, bottleneck, [3, 4, 23, 3]) 153 | 154 | def resnet152(input_shape, num_outputs): 155 | return ResNet.build(input_shape, num_outputs, bottleneck, [3, 8, 36, 3]) -------------------------------------------------------------------------------- /network/mobilenetv2.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import backend 3 | from tensorflow.keras import layers 4 | from tensorflow.keras.models import Model 5 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, ReLU, AveragePooling2D ,Add, DepthwiseConv2D 6 | from tensorflow.keras.layers import Input, Flatten, BatchNormalization, Activation, Dropout, GlobalAveragePooling2D 7 | from tensorflow.keras import regularizers 8 | 9 | 10 | 11 | def _make_divisible( v, divisor, min_value=None): 12 | if min_value is None: 13 | min_value = divisor 14 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 15 | 16 | # Make sure that round down does not go down by more than 10%. 17 | if new_v < 0.9 * v: 18 | new_v += divisor 19 | return new_v 20 | 21 | def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id): 22 | prefix = 'block_{}_'.format(block_id) 23 | channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1 24 | 25 | in_channels = backend.int_shape(inputs)[channel_axis] 26 | pointwise_conv_filters = int(filters * alpha) 27 | pointwise_filters = _make_divisible(pointwise_conv_filters, 8) 28 | x = inputs 29 | 30 | # Expand 31 | if block_id: 32 | x = Conv2D(expansion * in_channels, kernel_size=1, strides=1, padding='same', use_bias=False, activation=None, kernel_initializer="he_normal", kernel_regularizer=regularizers.l2(4e-5), name=prefix + 'expand')(x) 33 | x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + 'expand_BN')(x) 34 | x = ReLU(6., name=prefix + 'expand_relu')(x) 35 | else: 36 | prefix = 'expanded_conv_' 37 | 38 | # Depthwise 39 | x = DepthwiseConv2D(kernel_size=3, strides=stride, activation=None, use_bias=False, padding='same', kernel_initializer="he_normal", depthwise_regularizer=regularizers.l2(4e-5), name=prefix + 'depthwise')(x) 40 | x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + 'depthwise_BN')(x) 41 | x = ReLU(6., name=prefix + 'depthwise_relu')(x) 42 | 43 | # Project 44 | x = Conv2D(pointwise_filters, kernel_size=1, strides=1, padding='same', use_bias=False, activation=None, kernel_initializer="he_normal", kernel_regularizer=regularizers.l2(4e-5), name=prefix + 'project')(x) 45 | x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + 'project_BN')(x) 46 | 47 | 48 | if in_channels == pointwise_filters and stride == 1: 49 | return Add(name=prefix + 'add')([inputs, x]) 50 | return x 51 | 52 | class MobileNetV2(object): 53 | def build(input_shape=None,alpha=1.0, classes=10,**kwargs): 54 | rows = input_shape[0] 55 | cols = input_shape[1] 56 | print(input_shape) 57 | img_input = layers.Input(shape=input_shape) 58 | 59 | first_block_filters =_make_divisible(32 * alpha, 8) 60 | 61 | x = Conv2D(first_block_filters, kernel_size=3, strides=1, padding='same', use_bias=False, kernel_initializer="he_normal", kernel_regularizer=regularizers.l2(4e-5), name='Conv1')(img_input) 62 | #x = BatchNormalization(epsilon=1e-3, momentum=0.999, name='bn_Conv1')(x) 63 | #x = ReLU(6., name='Conv1_relu')(x) 64 | 65 | x = _inverted_res_block(x, filters=16, alpha=alpha, stride=1, expansion=1, block_id=0 ) 66 | 67 | x = _inverted_res_block(x, filters=24, alpha=alpha, stride=1, expansion=6, block_id=1 ) 68 | x = _inverted_res_block(x, filters=24, alpha=alpha, stride=1, expansion=6, block_id=2 ) 69 | 70 | x = _inverted_res_block(x, filters=32, alpha=alpha, stride=2, expansion=6, block_id=3 ) 71 | x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1, expansion=6, block_id=4 ) 72 | x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1, expansion=6, block_id=5 ) 73 | 74 | x = _inverted_res_block(x, filters=64, alpha=alpha, stride=2, expansion=6, block_id=6 ) 75 | x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=7 ) 76 | x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=8 ) 77 | x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=9 ) 78 | x = Dropout(rate=0.25)(x) 79 | 80 | x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=10) 81 | x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=11) 82 | x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=12) 83 | x = Dropout(rate=0.25)(x) 84 | 85 | x = _inverted_res_block(x, filters=160, alpha=alpha, stride=2, expansion=6, block_id=13) 86 | x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, expansion=6, block_id=14) 87 | x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, expansion=6, block_id=15) 88 | x = Dropout(rate=0.25)(x) 89 | 90 | x = _inverted_res_block(x, filters=320, alpha=alpha, stride=1, expansion=6, block_id=16) 91 | x = Dropout(rate=0.25)(x) 92 | 93 | if alpha > 1.0: 94 | last_block_filters = _make_divisible(1280 * alpha, 8) 95 | else: 96 | last_block_filters = 1280 97 | 98 | x = layers.Conv2D(last_block_filters, 99 | kernel_size=1, 100 | use_bias=False, 101 | name='Conv_1')(x) 102 | x = layers.BatchNormalization( epsilon=1e-3, 103 | momentum=0.999, 104 | name='Conv_1_bn')(x) 105 | x = layers.ReLU(6., name='out_relu')(x) 106 | 107 | x = layers.GlobalAveragePooling2D()(x) 108 | x = layers.Dense(classes, activation='softmax', 109 | use_bias=True, name='Logits')(x) 110 | # Create model. 111 | model = Model(img_input, x, name='mobilenetv2_%0.2f_%s' % (alpha, rows)) 112 | 113 | 114 | return model 115 | 116 | 117 | 118 | def mobilenetv2(input_shape, classes): 119 | return MobileNetV2.build(input_shape=input_shape, classes=classes) -------------------------------------------------------------------------------- /train_network.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 3 | import numpy as np 4 | import argparse 5 | 6 | import tensorflow as tf 7 | from tensorflow.keras.preprocessing.image import ImageDataGenerator 8 | from tensorflow.keras.optimizers import Adam , SGD 9 | 10 | from network.utils import load_dataset, load_model_arch, lr_scheduler 11 | class TrainTeacher(object): 12 | def __init__(self,dataset_name, model_name, batch_size, epochs, lr, save_dir, data_augmentation, metrics='accuracy'): 13 | self.dataset_name = dataset_name 14 | self.batch_size = batch_size 15 | self.model_name = model_name 16 | self.data_augmentation = data_augmentation 17 | if not os.path.exists(save_dir): 18 | os.makedirs(save_dir) 19 | self.save_path = f'{save_dir}{self.dataset_name}_{self.model_name}.h5' 20 | self.epochs = epochs 21 | self.metrics = metrics 22 | self.optimizer = Adam(learning_rate=lr) 23 | #self.optimizer= SGD(lr=2e-2, momentum=0.9, decay=0.0, nesterov=False) 24 | self.loss = tf.keras.losses.CategoricalCrossentropy() 25 | 26 | 27 | #Load train and validation dataset 28 | self.x_train, self.y_train, self.x_test, self.y_test, self.num_classes , self.img_shape = load_dataset(self.dataset_name, self.batch_size) 29 | 30 | #Load model architecture 31 | self.model = load_model_arch(self.model_name,self.img_shape,self.num_classes) 32 | 33 | 34 | #Define callback function 35 | def get_callback_list(self,early_stop=True, lr_reducer=True): 36 | 37 | callback_list = list() 38 | 39 | if early_stop == True: 40 | callback_list.append(tf.keras.callbacks.EarlyStopping(min_delta=0, patience=20, verbose=2, mode='auto')) 41 | if lr_reducer == True: 42 | callback_list.append(tf.keras.callbacks.ReduceLROnPlateau(factor=0.1, cooldown=0, patience=10, min_lr=0.5e-6)) 43 | 44 | 45 | return callback_list 46 | 47 | def train(self): 48 | self.model.compile(loss= self.loss, optimizer=self.optimizer, metrics=[self.metrics]) 49 | callback_list = self.get_callback_list() 50 | 51 | if self.data_augmentation == True: 52 | 53 | datagen = ImageDataGenerator( featurewise_center=False, # set input mean to 0 over the dataset 54 | samplewise_center=False, # set each sample mean to 0 55 | featurewise_std_normalization=False, # divide inputs by std of the dataset 56 | samplewise_std_normalization=False, # divide each input by its std 57 | zca_whitening=False, # apply ZCA whitening 58 | rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180) 59 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 60 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 61 | horizontal_flip=True, # randomly flip images 62 | vertical_flip=False) # randomly flip images 63 | 64 | datagen.fit(self.x_train) 65 | 66 | self.model.fit(datagen.flow(self.x_train, self.y_train, batch_size=self.batch_size), \ 67 | steps_per_epoch=len(self.x_train) / self.batch_size,epochs=self.epochs, \ 68 | validation_data=(self.x_test,self.y_test), callbacks=callback_list) 69 | 70 | else: 71 | self.model.fit(self.x_train, self.y_train, batch_size=self.batch_size, epochs=self.epochs, \ 72 | validation_data=(self.x_test,self.y_test), callbacks=callback_list) 73 | 74 | 75 | def test(self): 76 | self.model.compile(optimizer=self.optimizer, loss=self.loss, metrics=self.metrics) 77 | self.val_dataset = tf.data.Dataset.from_tensor_slices((self.x_test, self.y_test)).batch(self.batch_size) 78 | scores = self.model.evaluate(self.val_dataset,batch_size=self.batch_size) 79 | print('Test loss:', scores[0]) 80 | print('Test accuracy:', scores[1]) 81 | 82 | #save model 83 | tf.keras.models.save_model( 84 | model=self.model, 85 | filepath=self.save_path, 86 | include_optimizer=False 87 | ) 88 | 89 | def build(self): 90 | print(f'loading {self.model_name}..\n') 91 | self.train() 92 | self.test() 93 | 94 | 95 | def main(): 96 | 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['cifar10', 'cifar100'] help='Dataset [ "cifar10", "cifar100" ]') 99 | parser.add_argument('--model_name', type=str, default='mobilenetv2', help='Model name') 100 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size') 101 | parser.add_argument('--epochs', type=int, default=200, help='Epochs') 102 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 103 | parser.add_argument('--save_dir', type=str, default='./saved_models/', help='Saved model path') 104 | parser.add_argument('--data_augmentation', type=bool, default=True, help='Saved model path') 105 | 106 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 107 | os.environ['TF2_BEHAVIOR'] = '1' 108 | os.environ['CUDA_VISIBLE_DEVICES']= '0' 109 | 110 | config = tf.compat.v1.ConfigProto() 111 | config.gpu_options.per_process_gpu_memory_fraction = 0.8 112 | tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config)) 113 | 114 | args = parser.parse_args() 115 | 116 | trainer=TrainTeacher(dataset_name=args.dataset_name, 117 | model_name=args.model_name, 118 | batch_size=args.batch_size, 119 | epochs=args.epochs, 120 | lr=args.lr, 121 | save_dir=args.save_dir, 122 | data_augmentation=args.data_augmentation, 123 | metrics='accuracy' 124 | ) 125 | 126 | trainer.build() 127 | 128 | if __name__ == '__main__': 129 | 130 | main() -------------------------------------------------------------------------------- /core/eagleeye.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | import tensorflow as tf 5 | import tensorflow.keras as k 6 | import tensorflow.keras.backend as b 7 | from tensorflow.python.keras.utils.layer_utils import count_params 8 | from tensorflow.keras.optimizers import Adam 9 | from tensorflow.keras.preprocessing.image import ImageDataGenerator 10 | 11 | from core.utils import load_network ,count_flops 12 | from network.utils import load_dataset, lr_scheduler 13 | from core.strategy_generation import get_strartegy_generation 14 | from core.pruner.l1norm import l1_pruning 15 | 16 | class EagleEye(object): 17 | def __init__(self, 18 | dataset_name, 19 | model_path, 20 | bs, 21 | epochs, 22 | lr, 23 | min_rate, 24 | max_rate, 25 | flops_target, 26 | num_candidates, 27 | result_dir, 28 | data_augmentation): 29 | 30 | self.dataset_name = dataset_name 31 | self.bs = bs 32 | self.epochs = epochs 33 | self.lr = lr 34 | self.min_rate = min_rate 35 | self.max_rate = max_rate 36 | self.flops_target = flops_target 37 | self.result_dir = result_dir 38 | if not os.path.exists(self.result_dir): 39 | os.makedirs(self.result_dir) 40 | self.model_path = model_path 41 | self.num_candidates = num_candidates 42 | self.data_augmentation = data_augmentation 43 | 44 | #Load a base model 45 | self.net = load_network(model_path) 46 | 47 | #Load a dataset 48 | self.x_train, self.y_train, self.x_test, self.y_test, self.num_classes, self.img_shape = load_dataset(self.dataset_name, self.bs) 49 | 50 | def build(self): 51 | 52 | pruned_model_list = list() 53 | val_acc_list = list() 54 | for i in range(self.num_candidates): 55 | # Pruning strategy 56 | channel_config = get_strartegy_generation(self.net, self.min_rate, self.max_rate) 57 | 58 | #Get pruned model using l1 pruning 59 | pruned_model = l1_pruning(self.net, channel_config) 60 | pruned_model_list.append(pruned_model) 61 | 62 | # Adpative BN-statistics 63 | slice_idx = int(np.shape(self.x_train)[0]/30) 64 | sliced_x_train = self.x_train[:slice_idx, :, :, :] 65 | sliced_y_train = self.y_train[:slice_idx, :] 66 | 67 | sliced_train_dataset = tf.data.Dataset.from_tensor_slices((sliced_x_train, sliced_y_train)).batch(64) 68 | 69 | max_iters=10 70 | for j in range(max_iters): 71 | for x_batch, y_batch in sliced_train_dataset: 72 | output = pruned_model(x_batch, training=True) 73 | 74 | 75 | #Evaluate top-1 accuracy for prunned model 76 | small_val_datsaet = tf.data.Dataset.from_tensor_slices((sliced_x_train, sliced_y_train)).batch(64) 77 | small_val_acc = k.metrics.CategoricalAccuracy() 78 | for x_batch, y_batch in small_val_datsaet: 79 | output = pruned_model(x_batch) 80 | 81 | small_val_acc.update_state(y_batch, output) 82 | 83 | small_val_acc = small_val_acc.result().numpy() 84 | print(f'Adaptive-BN-based accuracy for {i}-th prunned model: {small_val_acc}') 85 | 86 | val_acc_list.append(small_val_acc) 87 | 88 | #Select the best candidate model 89 | val_acc_np = np.array(val_acc_list) 90 | best_candidate_idx = np.argmax(val_acc_np) 91 | best_model = pruned_model_list[best_candidate_idx] 92 | print(f'\n The best candidate is {best_candidate_idx}-th prunned model (Acc: {val_acc_np[best_candidate_idx]})') 93 | 94 | #Fine tuning 95 | metrics = 'accuracy' 96 | optimizer = Adam(learning_rate=self.lr) 97 | loss = tf.keras.losses.CategoricalCrossentropy() 98 | 99 | 100 | def get_callback_list(save_path, early_stop=True, lr_reducer=True): 101 | callback_list=list() 102 | 103 | if lr_reducer == True: 104 | callback_list.append(tf.keras.callbacks.ReduceLROnPlateau(factor=0.1, cooldown=0, patience=10, mode='auto', epsilon=0.0001, min_lr=0)) 105 | if early_stop == True: 106 | callback_list.append(tf.keras.callbacks.EarlyStopping(min_delta=0, patience=20, verbose=2, mode='auto')) 107 | 108 | return callback_list 109 | 110 | 111 | best_model.compile(loss=loss, optimizer=optimizer, metrics=[metrics]) 112 | self.net.compile(loss= loss, optimizer=optimizer, metrics=[metrics]) 113 | callback_list = get_callback_list(self.result_dir) 114 | 115 | if self.data_augmentation == True: 116 | 117 | datagen = ImageDataGenerator( featurewise_center=False, # set input mean to 0 over the dataset 118 | samplewise_center=False, # set each sample mean to 0 119 | featurewise_std_normalization=False, # divide inputs by std of the dataset 120 | samplewise_std_normalization=False, # divide each input by its std 121 | zca_whitening=False, # apply ZCA whitening 122 | rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180) 123 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 124 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 125 | horizontal_flip=True, # randomly flip images 126 | vertical_flip=False) # randomly flip images 127 | 128 | datagen.fit(self.x_train) 129 | 130 | best_model.fit(datagen.flow(self.x_train, self.y_train, batch_size=self.bs), \ 131 | steps_per_epoch=len(self.x_train) / self.bs,epochs=self.epochs, \ 132 | validation_data=(self.x_test,self.y_test), callbacks=callback_list) 133 | 134 | else: 135 | best_model.fit(self.x_train, self.y_train, batch_size=self.bs, epochs=self.epochs,\ 136 | validation_data=(self.x_test,self.y_test), callbacks=callback_list) 137 | 138 | 139 | 140 | #Get flops and parameters of the base model and pruned model 141 | params_prev = count_params(self.net.trainable_weights) 142 | flops_prev = count_flops(self.net) 143 | scores_prev = self.net.evaluate(self.x_test, self.y_test, batch_size=self.bs, verbose=0) 144 | print(f'\nTest loss (on base model): {scores_prev[0]}') 145 | print(f'Test accuracy (on base model): {scores_prev[1]}') 146 | print(f'The number of parameters (on base model): {params_prev}') 147 | print(f'The number of flops (on base model): {flops_prev}') 148 | params_after = count_params(best_model.trainable_weights) 149 | flops_after = count_flops(best_model) 150 | scores_after = best_model.evaluate(self.x_test, self.y_test, batch_size=self.bs, verbose=0) 151 | print(f'\nTest loss (on pruned model): {scores_after[0]}') 152 | print(f'Test accuracy (on pruned model): {scores_after[1]}') 153 | print(f'The number of parameters (on pruned model): {params_after}') 154 | print(f'The number of flops (on prund model): {flops_after}') 155 | 156 | #save the best candidate model 157 | slash_idx = self.model_path.rfind('/') 158 | ext_idx = self.model_path.rfind('.') 159 | save_name = self.model_path[slash_idx:ext_idx] 160 | tf.keras.models.save_model( 161 | model=best_model, 162 | filepath=self.result_dir+save_name+'_pruned.h5', 163 | include_optimizer=False 164 | ) 165 | 166 | 167 | 168 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Eagleeye: fast sub-net evaluation for efficient neural network pruning 3 | 4 | 5 | ![Python version support](https://img.shields.io/badge/python-3.6-blue.svg) 6 | ![Tensorflow version support](https://img.shields.io/badge/tensorflow-2.3.0-red.svg) 7 | 8 | :star: Star us on GitHub — it helps!! 9 | 10 | 11 | Tensorflow keras implementation for *[EagleEye: Fast Sub-net Evaluation for Efficient Neural Network Pruning](https://arxiv.org/abs/2007.02491)* 12 | 13 | ## Install 14 | 15 | You will need a machine with a GPU and CUDA installed. 16 | Then, you prepare runtime environment: 17 | 18 | ```shell 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ## Use 23 | 24 | ### Train base model 25 | 26 | If you want to train a network for yourself: 27 | 28 | ```shell 29 | python train_network.py --dataset_name=cifar10 --model_name=resnet34 30 | ``` 31 | 32 | Arguments: 33 | 34 | - `dataset_name` - Select a dataset ['cifar10' or 'cifar100'] 35 | - `model_name` - Trainable network names 36 | - Available list 37 | - VGG: ['vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn'] 38 | - ResNet: ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] 39 | - MobileNet: ['mobilenetv2'] 40 | - DenseNet: ['densenet40_f12','densenet100_f12'] 41 | - `batch_size` - Batch size 42 | - `epochs` - The number of epochs 43 | - `lr` - Learning rate 44 | - Default is 0.001 (set 0.01 when training mobilenetv2) 45 | 46 | Else if you want to use a pre-trained model, download the model through a download link in the below result section. 47 | Please put the downloaded models in the directory of `./saved_models/`. 48 | 49 | 50 | ### Prune the model using EagleEye!! 51 | 52 | Next, you can prune the trained model as following: 53 | 54 | ```shell 55 | python main.py --dataset_name=cifar10 --model_path=./saved_models/cifar10_resnet34.h5 --epochs=100 --min_rate=0.0 --max_rate=0.5 --num_candidates=15 56 | ``` 57 | 58 | Arguments: 59 | 60 | - `dataset_name` - Select a dataset ['cifar10' or 'cifar100'] 61 | - `model_path` - Model path 62 | - `bs` - Batch size 63 | - `epochs` - The number of epochs 64 | - `lr` - Learning rate 65 | - `min_rate` - Minimum rate of search space 66 | - `max_rate` - Maximum rate of search space 67 | - `num_candidates` - The number of candidates 68 | 69 | 70 | ## Result 71 | 72 | The result of progress looks like as following: 73 | ``` 74 | Adaptive-BN-based accuracy for 0-th prunned model: 0.08163265138864517 75 | Adaptive-BN-based accuracy for 1-th prunned model: 0.20527011036872864 76 | Adaptive-BN-based accuracy for 2-th prunned model: 0.10084033757448196 77 | ... 78 | Adaptive-BN-based accuracy for 13-th prunned model: 0.10804321616888046 79 | Adaptive-BN-based accuracy for 14-th prunned model: 0.11284513771533966 80 | 81 | The best candidate is 1-th prunned model (Acc: 0.20527011036872864) 82 | 83 | Epoch 1/200 84 | 196/195 [==============================] - 25s 125ms/step - loss: 1.7189 - accuracy: 0.4542 - val_loss: 1.8755 - val_accuracy: 0.4265 85 | ... 86 | 87 | Test loss (on base model): 0.3912913203239441 88 | Test accuracy (on base model): 0.9172999858856201 89 | The number of parameters (on base model): 33646666 90 | The number of flops (on base model): 664633404 91 | 92 | Test loss (on pruned model): 0.520440936088562 93 | Test accuracy (on pruned model): 0.9136999845504761 94 | The number of parameters (on pruned model): 27879963 95 | The number of flops (on prund model): 401208200 96 | ``` 97 | 98 | Then, the pruned model will be saved in `./result/`(default) folder. 99 | A result example is represented when before and after pruning: 100 | 101 | drawing 102 | 103 | 104 | ### ResNet 34 on cifar10 105 | 106 | |Model|Acc|[min_rate, max_rate]|Flops|Param num|File size|Download| 107 | |-----|---|--------------------|-----|---------|---------|--------| 108 | |Original|90.11%|None|145M|1.33M|5.44MB|[cifar10_resnet34.h5](https://drive.google.com/file/d/1SJS61fUh_GsnlBI3WB__JtdO70Zfj7Ms/view?usp=sharing)| 109 | |Pruned|90.50%|[0, 0.5]|106M|0.86M|3.64MB|[cifar10_resnet34_pruned0.5.h5](https://drive.google.com/file/d/1VRTAiIvF7B7-AejxLunaCtOWFTan5p0U/view?usp=sharing)| 110 | |Pruned|89.02%|[0, 0.7]|69M|0.55M|2.60MB|[cifar10_resnet34_pruned0.7.h5](https://drive.google.com/file/d/1z77mbXxagEyc9TKxDXISgNmQ-74Tkc0_/view?usp=sharing)| 111 | |Pruned|89.30%|[0, 0.9]|62M|0.54M|2.52MB|[cifar10_resnet34_pruned0.9.h5](https://drive.google.com/file/d/1uBJoaovFkEwSbaF_AggxUP70MShPOJSc/view?usp=sharing)| 112 | 113 | ### ResNet 18 on cifar10 114 | 115 | |Model|Acc|[min_rate, max_rate]|Flops|Param num|File size|Download| 116 | |-----|---|--------------------|-----|---------|---------|--------| 117 | |Original|88.30%|None|70.2M|0.70M|2.87MB|[cifar10_resnet18.h5](https://drive.google.com/file/d/1fu_DlI-YLm3IunHFmq-UFi-p4ecXBUq6/view?usp=sharing)| 118 | |Pruned|88.02%|[0, 0.5]|50.9M|0.47M|2.01MB|[cifar10_resnet18_pruned0.5.h5](https://drive.google.com/file/d/187fYZvDfHj8w_YM_JguPyet0SsPD8GXg/view?usp=sharing)| 119 | |Pruned|87.09%|[0, 0.7]|33.1M|0.26M|1.20MB|[cifar10_resnet18_pruned0.7.h5](https://drive.google.com/file/d/18nrWoX-1TKtLFbiedOkJywHxUS6PGpfo/view?usp=sharing)| 120 | |Pruned|85.65%|[0, 0.9]|21.4M|0.08M|0.51MB|[cifar10_resnet18_pruned0.9.h5](https://drive.google.com/file/d/1JjoqWf3motY_swaQ4rBIQMRTteJo2zsL/view?usp=sharing)| 121 | 122 | 123 | ### VGG 16_bn on cifar10 124 | 125 | |Model|Acc|[min_rate, max_rate]|Flops|Param num|File size|Download| 126 | |-----|---|--------------------|-----|---------|---------|--------| 127 | |Original|91.99%|None|664M|33.64M|128MB|[cifar10_vgg16_bn.h5](https://drive.google.com/file/d/1YZJ-I5V2RZOwMYwwY5T7VhQd48wPsTQb/view?usp=sharing)| 128 | |Pruned|91.44%|[0, 0.5]|394M|26.27M|100MB|[cifar10_vgg16_bn_pruned0.5.h5](https://drive.google.com/file/d/1T_dHGEJphXzufahiimdSH6U-CCaNV9po/view?usp=sharing)| 129 | |Pruned|90.81%|[0, 0.7]|298M|25.18M|96.2MB|[cifar10_vgg16_bn_pruned0.7.h5](https://drive.google.com/file/d/1u15n_rRYd-ARaIF-IPP57MAFyoMfU7Q7/view?usp=sharing)| 130 | |Pruned|90.42%|[0, 0.9]|293M|21.9M|83.86MB|[cifar10_vgg16_bn_pruned0.9.h5](https://drive.google.com/file/d/1Hm1Ui068Cd7fdVg_-VytZ7sstTMoP9MK/view?usp=sharing)| 131 | 132 | ### MobileNetv2 on cifar10 133 | 134 | |Model|Acc|[min_rate, max_rate]|Flops|Param num|File size|Download| 135 | |-----|---|--------------------|-----|---------|---------|--------| 136 | |Original|90.14%|None|175M|2.23M|9.10MB|[cifar10_mobilenetv2_bn.h5](https://drive.google.com/file/d/1zVtZ2jwNTyn3Bo0D8V0xfRPpEjWZg09s/view?usp=sharing)| 137 | |Pruned|90.57%|[0, 0.5]|126M|1.62M|6.72MB|[cifar10_mobilenetv2_pruned0.5.h5](https://drive.google.com/file/d/1MLg909rak-78fZ5gHOsLMFBy2deXzZkJ/view?usp=sharing)| 138 | |Pruned|89.13%|[0, 0.7]|81M|0.90M|3.96MB|[cifar10_mobilenetv2_pruned0.7.h5](https://drive.google.com/file/d/1CDYObh6cJUIRS0Gc4m6BmVlQejcVs8We/view?usp=sharing)| 139 | |Pruned|90.54%|[0, 0.9]|70M|0.62M|2.92MB|[cifar10_mobilenetv2_pruned0.9.h5](https://drive.google.com/file/d/1GjUXOepwrMkBOYm5c239wJP64sHLrb-8/view?usp=sharing)| 140 | 141 | ### DenseNet40 on cifar10 142 | 143 | |Model|Acc|[min_rate, max_rate]|Flops|Param num|File size|Download| 144 | |-----|---|--------------------|-----|---------|---------|--------| 145 | |Original|91.32%|None|51M|1.00M|4.27MB|[cifar10_densenet40_bn.h5](https://drive.google.com/file/d/1gHUozIUE9pLGMekbP9lM5oqcw_Q_rgYX/view?usp=sharing)| 146 | |Pruned|91.13%|[0, 0.5]|32M|0.68M|3.07MB|[cifar10_densenet40_pruned0.5.h5](https://drive.google.com/file/d/1GCC3KSiJsHjUjQsO2ZL2L0HjUvBYP4F9/view?usp=sharing)| 147 | |Pruned|90.95%|[0, 0.7]|24M|0.54M|2.49MB|[cifar10_densenet40_pruned0.7.h5](https://drive.google.com/file/d/1F-5pk36kTMSqgs8zqaWWl8mC06AG5D_Z/view?usp=sharing)| 148 | |Pruned|90.31%|[0, 0.9]|20M|0.41M|2.00MB|[cifar10_densenet40_pruned0.9.h5](https://drive.google.com/file/d/1LLH8s5SAo3kyoQ556W_ePDDTuh8DomMO/view?usp=sharing)| 149 | 150 | ### DenseNet100 on cifar10 151 | 152 | |Model|Acc|[min_rate, max_rate]|Flops|Param num|File size|Download| 153 | |-----|---|--------------------|-----|---------|---------|--------| 154 | |Original|93.26%|None|2540M|3.98M|16.4MB|[cifar10_densenet100_bn.h5](https://drive.google.com/file/d/1GmCfYv32MnoWvSyoYc7M26AhNGc2xRSI/view?usp=sharing)| 155 | |Pruned|93.12%|[0, 0.5]|1703M|2.74M|11.6MB|[cifar10_densenet100_pruned0.5.h5](https://drive.google.com/file/d/1RLNUeS4DKdSAGf1iUaY9FoDn1LEymdUJ/view?usp=sharing)| 156 | |Pruned|93.12%|[0, 0.7]|1467M|2.41M|10.4MB|[cifar10_densenet100_pruned0.7.h5](https://drive.google.com/file/d/17KCmlrBfuecI0Qke45W78wCF6WnauVkp/view?usp=sharing)| 157 | |Pruned|93.00%|[0, 0.9]|1000M|1.65M|7.44MB|[cifar10_densenet100_pruned0.9.h5](https://drive.google.com/file/d/1q5yvwSd146wYm2hnTWfTac3mQ-F1hzhe/view?usp=sharing)| 158 | 159 | 160 | :no_entry: If you run this code, the result would be a bit different from mine. 161 | 162 | 163 | ## Understanding this paper 164 | 165 | :white_check_mark: Check my blog!! 166 | [Here](https://da2so.github.io/2020-10-25-EagleEye_Fast_Sub_net_Evaluation_for_Efficient_Neur_Network_Pruning/) -------------------------------------------------------------------------------- /core/pruner/graph_wrapper.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import sys 3 | import numpy as np 4 | import inspect 5 | 6 | from tensorflow.python.keras import layers 7 | from tensorflow.keras.models import Model 8 | import tensorflow as tf 9 | 10 | class GraphWrapper(object): 11 | def __init__(self, model): 12 | self.model = model 13 | self.layer_info = defaultdict() 14 | self.input_shape = list(model.layers[0].input.shape[1:]) 15 | 16 | #construct 17 | for idx, layer in enumerate(self.model.layers): 18 | if idx == 0: 19 | self.layer_info[id(layer.output)] = {'count': 0} 20 | continue 21 | 22 | self.layer_info[id(layer.output)] = {'layer_input': layer.input, 'count': 0} 23 | 24 | if isinstance(layer.input, list): 25 | for in_layer in layer.input: 26 | self.layer_info[id(in_layer)]['count'] += 1 27 | else: 28 | self.layer_info[id(layer.input)]['count'] += 1 29 | 30 | #check if a conv layer determine the number of classes (ouput classes) 31 | for idx, layer in enumerate(reversed(self.model.layers)): 32 | 33 | if type(layer) is layers.Dense: 34 | self.is_lastConv2D = False 35 | break 36 | elif type(layer) is layers.Conv2D: 37 | self.is_lastConv2D = True 38 | self.lastConv2D = len(self.model.layers)-(idx+1) 39 | break 40 | else: 41 | continue 42 | 43 | 44 | def build(self,del_layer_dict): 45 | prev_prune_idx = None 46 | first_dense = True 47 | #output processing for keeping the number of classes 48 | if self.is_lastConv2D: 49 | del_layer_dict[self.lastConv2D] = [] 50 | 51 | #initialize input 52 | input = layers.Input(shape=self.input_shape) 53 | x = input 54 | 55 | for idx, layer in enumerate(self.model.layers): 56 | if idx == 0: 57 | self.layer_info[id(layer.output)].update({'out': x}) 58 | continue 59 | 60 | # reconstruct each layer 61 | if type(layer) is layers.convolutional.Conv2D: 62 | prev_prune_idx=self.get_prev_pruned_idx(layer.input, self.layer_info) 63 | 64 | pruned_weights = self.get_conv_pruned_weights(layer, del_layer_dict[idx], prev_prune_idx) 65 | cutted_channl_num = len (del_layer_dict[idx]) 66 | 67 | attr_name = inspect.getargspec(layers.convolutional.Conv2D.__init__)[0] 68 | attr_name.pop(0) 69 | attr_value = {attr: getattr(layer ,attr) for attr in attr_name} 70 | attr_value['filters'] -= cutted_channl_num 71 | recon_layer = layers.convolutional.Conv2D(**attr_value) 72 | 73 | self.layer_info[id(layer.output)].update({'pruned_idx': del_layer_dict[idx]}) 74 | 75 | elif type(layer) is layers.convolutional.DepthwiseConv2D: 76 | prev_prune_idx = self.get_prev_pruned_idx(layer.input, self.layer_info) 77 | 78 | pruned_weights = self.get_depthwiseconv_pruned_weights(layer, prev_prune_idx) 79 | 80 | recon_layer = layer.__class__.from_config(layer.get_config()) 81 | 82 | 83 | elif type(layer) is layers.normalization_v2.BatchNormalization: 84 | prev_prune_idx = self.get_prev_pruned_idx(layer.input, self.layer_info) 85 | 86 | pruned_weights = self.get_batchnorm_pruned_weights(layer, prev_prune_idx) 87 | 88 | cutted_channl_num = len(prev_prune_idx) 89 | config = layer.get_config() 90 | config['gamma_regularizer'] = None 91 | recon_layer = layers.normalization_v2.BatchNormalization.from_config(config) 92 | 93 | 94 | elif type(layer) is layers.Dense: 95 | if first_dense == True: 96 | prev_prune_idx = self.get_prev_pruned_idx(layer.input, self.layer_info) 97 | 98 | pruned_weights = self.get_dense_pruned_weights(layer, prev_prune_idx) 99 | 100 | 101 | first_dense = False 102 | recon_layer = layer.__class__.from_config(layer.get_config()) 103 | elif type(layer) is layers.Reshape: 104 | prev_prune_idx = self.get_prev_pruned_idx(layer.input, self.layer_info) 105 | 106 | prev_prune_num = len( prev_prune_idx) 107 | 108 | config = layer.get_config() 109 | original_shape = config['target_shape'] 110 | original_shape = list(original_shape) 111 | original_shape[-1] = original_shape[-1]-prev_prune_num 112 | new_shape = tuple(original_shape) 113 | config['target_shape'] = new_shape 114 | 115 | recon_layer = layer.__class__.from_config(config) 116 | else: 117 | 118 | if type(layer) is layers.Add or \ 119 | type(layer) is layers.Subtract or \ 120 | type(layer) is layers.Multiply or \ 121 | type(layer) is layers.Average or \ 122 | type(layer) is layers.Maximum or \ 123 | type(layer) is layers.Minimum: 124 | pruned_idx = self.set_pruned_idx(layer.input, self.layer_info) 125 | self.layer_info[id(layer.output)].update({'pruned_idx': pruned_idx}) 126 | 127 | 128 | if type(layer) is layers.Concatenate: 129 | pruned_idx = self.set_pruned_idx_for_concat(layer.input, self.layer_info) 130 | self.layer_info[id(layer.output)].update({'pruned_idx': pruned_idx}) 131 | 132 | recon_layer = layer.__class__.from_config(layer.get_config()) 133 | 134 | 135 | # connect layers 136 | if isinstance(layer.input, list): 137 | input_list = [] 138 | for in_layer in layer.input: 139 | input_list.append(self.layer_info[id(in_layer)]['out']) 140 | self.layer_info[id(in_layer)]['count'] -= 1 141 | self.del_key(in_layer, self.layer_info) 142 | x = input_list 143 | else: 144 | x = self.layer_info[id(layer.input)]['out'] 145 | self.layer_info[id(layer.input)]['count'] -= 1 146 | self.del_key(layer.input, self.layer_info) 147 | 148 | x = recon_layer(x) 149 | self.layer_info[id(layer.output)]['out'] = x 150 | 151 | try: 152 | recon_layer.set_weights(pruned_weights) 153 | except: 154 | pass 155 | 156 | model = Model(inputs=input, outputs=x) 157 | 158 | return model 159 | 160 | 161 | def get_conv_pruned_weights(self, layer, prune_idx, prev_pruned_idx=None): 162 | 163 | weights = layer.get_weights() 164 | 165 | #prune kernel 166 | kernel = weights[0] 167 | pruned_kernel= np.delete(kernel, prune_idx, axis=-1) 168 | 169 | try: 170 | pruned_kernel = np.delete(pruned_kernel, prev_pruned_idx, axis=-2) 171 | except: 172 | pass 173 | 174 | #prune bias 175 | prunned_bias = None 176 | if layer.use_bias: 177 | bias = weights[1] 178 | prunned_bias = np.delete(bias, prune_idx) 179 | 180 | if layer.use_bias == True: 181 | return [pruned_kernel, prunned_bias] 182 | else: 183 | return [pruned_kernel] 184 | 185 | def get_dense_pruned_weights(self, layer, prev_pruned_idx=None): 186 | weights = layer.get_weights() 187 | kernel = weights[0] 188 | pruned_kernel = np.delete(kernel, prev_pruned_idx, axis=-2) 189 | 190 | bias = weights[1] 191 | 192 | return [pruned_kernel, bias] 193 | 194 | 195 | def get_depthwiseconv_pruned_weights(self, layer, prev_pruned_idx=None): 196 | weights = layer.get_weights() 197 | kernel = weights[0] 198 | pruned_kernel = np.delete(kernel, prev_pruned_idx, axis=-2) 199 | 200 | #prune bias 201 | prunned_bias = None 202 | if layer.use_bias: 203 | bias = weights[1] 204 | 205 | if layer.use_bias == True: 206 | return [pruned_kernel, bias] 207 | else: 208 | return [pruned_kernel] 209 | 210 | def get_batchnorm_pruned_weights(self, layer, prune_idx): 211 | weights = layer.get_weights() 212 | pruned_weights = [np.delete(w, prune_idx) for w in weights] 213 | 214 | return pruned_weights 215 | 216 | def get_prev_pruned_idx(self, in_layer, layer_info): 217 | 218 | while 1: 219 | if 'layer_input' not in layer_info[id(in_layer)]: 220 | return None 221 | if 'pruned_idx' not in layer_info[id(in_layer)]: 222 | in_layer = layer_info[id(in_layer)]['layer_input'] 223 | else: 224 | prev_pruned_idx = layer_info[id(in_layer)]['pruned_idx'] 225 | return prev_pruned_idx 226 | 227 | def set_pruned_idx_for_concat(self, layer_list, layer_info): 228 | pruned_idx_set = set() 229 | is_second_input = False 230 | for in_layer in layer_list: 231 | 232 | while 1: 233 | if 'pruned_idx' not in layer_info[id(in_layer)]: 234 | in_layer = layer_info[id(in_layer)]['layer_input'] 235 | else: 236 | if is_second_input == False: 237 | prev_layer_len = in_layer.shape[3] 238 | pruned_idx_set.update(layer_info[id(in_layer)]['pruned_idx']) 239 | is_second_input = True 240 | else: 241 | pruned_idx = layer_info[id(in_layer)]['pruned_idx']+prev_layer_len 242 | pruned_idx_set.update(pruned_idx) 243 | prev_layer_len = in_layer.shape[3] 244 | break 245 | 246 | pruned_idx_list = list(pruned_idx_set) 247 | return pruned_idx_list 248 | 249 | 250 | def set_pruned_idx(self, layer_list, layer_info): 251 | pruned_idx_set = set() 252 | for in_layer in layer_list: 253 | 254 | while 1: 255 | if 'pruned_idx' not in layer_info[id(in_layer)]: 256 | in_layer = layer_info[id(in_layer)]['layer_input'] 257 | else: 258 | pruned_num = len(layer_info[id(in_layer)]['pruned_idx']) 259 | pruned_idx_set.update(layer_info[id(in_layer)]['pruned_idx']) 260 | break 261 | 262 | pruned_idx_list = list( pruned_idx_set) [:pruned_num] 263 | return pruned_idx_list 264 | 265 | def del_key(self, layer_input, layer_info): 266 | if layer_info[id(layer_input)]['count'] == 0: 267 | del layer_info[id(layer_input)]['out'] --------------------------------------------------------------------------------