├── .DS_Store
├── assets
│   └── fig1.png
├── .idea
│   ├── misc.xml
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   ├── Eagleeye_Tensorflow.iml
│   └── workspace.xml
├── network
│   ├── __init__.py
│   ├── utils.py
│   ├── vgg.py
│   ├── densenet.py
│   ├── resnet.py
│   └── mobilenetv2.py
├── requirements.txt
├── core
│   ├── pruner
│   │   ├── l1norm.py
│   │   └── graph_wrapper.py
│   ├── utils.py
│   ├── strategy_generation.py
│   └── eagleeye.py
├── main.py
├── train_network.py
└── README.md
/assets/fig1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/da2so/Eagleeye_Tensorflow/HEAD/assets/fig1.png
--------------------------------------------------------------------------------
/network/__init__.py:
--------------------------------------------------------------------------------
1 | from network.mobilenetv2 import mobilenetv2
2 | from network.vgg import vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn
3 | from network.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
4 | from network.densenet import densenet40_f12, densenet100_f12, densenet100_f24
5 |
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.18.5
2 | tensorflow-gpu==2.3.0
3 | absl-py==0.11.0
4 | astor==0.8.1
5 | astunparse==1.6.3
6 | chardet==3.0.4
7 | h5py==2.10.0
8 | importlib-metadata==2.0.0
9 | Keras-Applications==1.0.8
10 | Keras-Preprocessing==1.1.2
11 | matplotlib==3.3.2
12 | pandas==1.1.4
13 | Pillow==8.0.1
14 | pydot==1.4.1
15 | tensorflow-estimator==2.3.0
16 | six==1.15.0
17 | argparse==1.4.0
18 |
--------------------------------------------------------------------------------
/core/pruner/l1norm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from tensorflow.keras import layers
4 | import tensorflow as tf
5 |
6 | from core.pruner.graph_wrapper import GraphWrapper
7 |
8 |
9 | def l1_pruning(model, channel_config):
10 | del_layer_dict={}
11 |
12 | for idx, prune_rate in channel_config.items():
13 | w_copy=model.layers[idx].get_weights()[0]
14 |
15 | #l1 norm for weights of a layer
16 | w_copy = tf.math.abs(w_copy)
17 | w_copy = tf.math.reduce_sum(w_copy, axis=[0,1,2])
18 |
19 | num_delete = int(np.shape(w_copy)[0]*(prune_rate))
20 | del_list = tf.math.top_k(tf.reshape(-w_copy, [-1]), num_delete)[1].numpy()
21 |
22 | del_layer_dict[idx]=del_list
23 |
24 |     #Reconstruct the base model using the graph wrapper
25 | graph_obj=GraphWrapper(model)
26 | pruned_model=graph_obj.build(del_layer_dict)
27 |
28 | return pruned_model
29 |
--------------------------------------------------------------------------------
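The channel-selection step above is compact but subtle: `top_k` on the *negated* per-channel L1 norms yields the indices of the channels with the *smallest* norms, i.e. the ones to delete. A minimal standalone sketch of that trick (toy kernel shapes, not tied to the repository):

```python
import numpy as np
import tensorflow as tf

# Toy conv kernel: (kh, kw, in_ch, out_ch) = (3, 3, 2, 4), i.e. 4 output channels.
kernel = np.random.rand(3, 3, 2, 4).astype('float32')

# Per-channel L1 norm, computed exactly as in l1_pruning above.
l1 = tf.math.reduce_sum(tf.math.abs(kernel), axis=[0, 1, 2])

# top_k on the negated norms returns the indices of the smallest norms,
# i.e. the channels to prune at prune_rate = 0.5.
num_delete = int(np.shape(l1)[0] * 0.5)
del_idx = tf.math.top_k(tf.reshape(-l1, [-1]), num_delete)[1].numpy()
print(del_idx)  # e.g. [2 0]: the two weakest output channels
```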
/network/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import tensorflow as tf
4 | from tensorflow.keras.datasets import cifar10, cifar100
5 | from tensorflow.keras.utils import to_categorical
6 |
7 | import network
8 |
9 |
10 | def load_dataset(dataset_name, batch_size):
11 |
12 | if dataset_name =='cifar10':
13 | (x_train, y_train), (x_test, y_test) = cifar10.load_data()
14 | num_classes=10
15 | img_shape=[32,32,3]
16 |
17 | elif dataset_name == 'cifar100':
18 | (x_train, y_train), (x_test, y_test) = cifar100.load_data()
19 | num_classes=100
20 | img_shape=[32,32,3]
21 | elif dataset_name == 'imagenet':
22 | raise ValueError('Not yet implemented')
23 | else:
24 | raise ValueError('Invalid dataset name : {}'.format(dataset_name))
25 |
26 | img_shape=[32,32,3]
27 | normalize = [ [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]]
28 |
29 | x_train = x_train.astype('float32')
30 | x_test = x_test.astype('float32')
31 |
32 | x_train /= 255.
33 | x_test /= 255.
34 |
35 | mean= normalize[0]
36 | std= normalize[1]
37 |
38 | for i in range(3):
39 | x_train[:,:,:,i] = (x_train[:,:,:,i] - mean[i]) / std[i]
40 | x_test[:,:,:,i] = (x_test[:,:,:,i] - mean[i]) / std[i]
41 |
42 |
43 | y_train = to_categorical(y_train)
44 | y_test = to_categorical(y_test)
45 |
46 | return x_train ,y_train , x_test, y_test,num_classes,img_shape
47 |
48 | def load_model_arch(model_name,img_shape,num_classes):
49 |
50 | try:
51 | model_func = getattr(network, model_name)
52 | except:
53 | raise ValueError('Invalid model name : {}'.format(model_name))
54 |
55 | model = model_func(img_shape,num_classes)
56 | return model
57 | def lr_scheduler(epoch):
58 | lr = 1e-3
59 |     if epoch > 180:
60 |         lr *= 0.001
61 |     elif epoch > 120:
62 |         lr *= 0.01
63 |     elif epoch > 80:
64 |         lr *= 0.1
65 |
66 | return lr
--------------------------------------------------------------------------------
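`lr_scheduler` implements a step decay (1e-3, divided by 10 after epochs 80, 120, and 180) but is not wired into any callback list in this repository; `train_network.py` relies on `ReduceLROnPlateau` instead. A minimal sketch of how it could be attached, assuming the repository root is on `sys.path`:

```python
import tensorflow as tf
from network.utils import lr_scheduler

# Standard Keras callback that calls lr_scheduler(epoch) before every epoch.
lr_callback = tf.keras.callbacks.LearningRateScheduler(lr_scheduler, verbose=1)
# model.fit(x_train, y_train, epochs=200, callbacks=[lr_callback])
```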
/core/utils.py:
--------------------------------------------------------------------------------
1 | #from keras_flops import get_flops
2 |
3 | import tensorflow as tf
4 | from tensorflow.keras.applications.resnet50 import ResNet50
5 | from tensorflow.keras.applications.vgg19 import VGG19
6 | from tensorflow.keras.applications.resnet_v2 import ResNet50V2
7 | from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2
8 | from tensorflow.keras.applications.mobilenet import MobileNet
9 | from tensorflow.keras.datasets import cifar10, cifar100
10 | from tensorflow.python.framework.convert_to_constants import (
11 | convert_variables_to_constants_v2_as_graph,
12 | )
13 | from tensorflow.keras import Sequential, Model
14 |
15 | def load_network(model_path):
16 | if '.h5' in model_path:
17 | model = tf.keras.models.load_model(model_path)
18 | else:
19 |         if 'resnet50v2' in model_path:
20 |             model = ResNet50V2(weights='imagenet')
21 |         elif 'mobilenetv2' in model_path:  # must precede the plain 'mobilenet' check
22 |             model = MobileNetV2(weights='imagenet')
23 |         elif 'mobilenet' in model_path:
24 |             model = MobileNet(weights='imagenet')
25 |         elif 'vgg19' in model_path:
26 |             model = VGG19(weights='imagenet')
27 |         else:
28 |             raise ValueError('Invalid model name : {}'.format(model_path))
29 | return model
30 |
31 |
32 | def get_flops(model, batch_size):
33 |
34 | if not isinstance(model, (Sequential, Model)):
35 |         raise KeyError(
36 |             "model argument must be a tf.keras.Model or tf.keras.Sequential instance"
37 |         )
38 |
39 | if batch_size is None:
40 | batch_size = 1
41 |
42 | inputs = [
43 | tf.TensorSpec([batch_size] + inp.shape[1:], inp.dtype) for inp in model.inputs
44 | ]
45 | real_model = tf.function(model).get_concrete_function(inputs)
46 | frozen_func, _ = convert_variables_to_constants_v2_as_graph(real_model)
47 |
48 | # Calculate FLOPS with tf.profiler
49 | run_meta = tf.compat.v1.RunMetadata()
50 | opts = (tf.compat.v1.profiler.ProfileOptionBuilder()
51 | .with_empty_output()
52 | .build())
53 | flops = tf.compat.v1.profiler.profile(
54 | graph=frozen_func.graph, run_meta=run_meta, cmd="scope", options=opts
55 | )
56 |
57 | return flops.total_float_ops
58 |
59 | def count_flops(model):
60 | return get_flops(model, batch_size=1)
--------------------------------------------------------------------------------
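`get_flops` freezes the model into a constant graph and asks the TF1-style profiler for its total float operations. A minimal usage sketch (any functional `tf.keras` model works; this assumes the repository root is on `sys.path`):

```python
import tensorflow as tf
from core.utils import count_flops

# Tiny functional model, just to exercise the profiler path.
inp = tf.keras.layers.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(8, 3, padding='same')(inp)
out = tf.keras.layers.GlobalAveragePooling2D()(x)
model = tf.keras.Model(inp, out)

print(count_flops(model))  # total float ops reported by tf.compat.v1.profiler
```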
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
4 |
5 | import tensorflow as tf
6 |
7 | from core.eagleeye import EagleEye
8 |
9 | def main():
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['cifar10','cifar100'])
13 | parser.add_argument('--model_path', type=str, default='./saved_models/cifar10_resnet34.h5', help = 'model path')
14 | parser.add_argument('--bs', type=int, default=128, help = 'batch size')
15 | parser.add_argument('--epochs', type=int, default=100, help='epoch while fine-tuning')
16 | parser.add_argument('--lr', type=float, default=0.001 , help='learning rate')
17 | parser.add_argument('--min_rate', type=float, default=0.0, help='define minimum of search space')
18 | parser.add_argument('--max_rate', type=float, default=0.5, help='define maximum of search space')
19 | parser.add_argument('--flops_target', type=float, default=0.5, help='flops constraints for pruning')
20 | parser.add_argument('--num_candidates', type=int, default=15, help='the number of candidates')
21 | parser.add_argument('--data_augmentation', type=bool, default=True, help='do data augmentation?')
22 |     parser.add_argument('--result_dir', type=str, default='./result', help='result directory for a pruned model')
23 |
24 |
25 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
26 | os.environ['TF2_BEHAVIOR'] = '1'
27 | os.environ['CUDA_VISIBLE_DEVICES']= '0'
28 |
29 | config = tf.compat.v1.ConfigProto()
30 | config.gpu_options.per_process_gpu_memory_fraction = 0.8
31 | tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))
32 |
33 | args = parser.parse_args()
34 |
35 | eagleeye_obj=EagleEye(dataset_name =args.dataset_name,
36 | model_path=args.model_path,
37 | bs=args.bs,
38 | epochs=args.epochs,
39 | lr=args.lr,
40 | min_rate=args.min_rate,
41 | max_rate=args.max_rate,
42 | flops_target=args.flops_target,
43 | num_candidates=args.num_candidates,
44 | result_dir=args.result_dir,
45 | data_augmentation=args.data_augmentation
46 | )
47 |
48 | eagleeye_obj.build()
49 |
50 | if __name__ == '__main__':
51 | main()
--------------------------------------------------------------------------------
/network/vgg.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import backend
3 | from tensorflow.keras import layers
4 | from tensorflow.keras.models import Model
5 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, ReLU, AveragePooling2D ,Add, ZeroPadding2D
6 | from tensorflow.keras.layers import Input, Flatten, BatchNormalization, Activation
7 |
8 |
9 | class VGG(object):
10 | def build(input_shape, num_classes, cfg, batch_norm):
11 | rows = input_shape[0]
12 | cols = input_shape[1]
13 |
14 | img_input = layers.Input(shape=input_shape)
15 |
16 | x=img_input
17 |
18 | for v in cfg:
19 | if v == 'M':
20 | x = layers.MaxPooling2D((2, 2), strides=(2, 2))(x)
21 | else:
22 | x = layers.Conv2D(v, (3, 3),
23 | padding='same')(x)
24 | if batch_norm == True:
25 | x = layers.BatchNormalization(axis=-1)(x)
26 | x = Activation("relu")(x)
27 |
28 |
29 | x = layers.Flatten(name='flatten')(x)
30 | x = layers.Dense(4096, activation='relu', name='fc1')(x)
31 | x = layers.Dense(4096, activation='relu', name='fc2')(x)
32 | x = layers.Dense(num_classes, activation='softmax', name='predictions')(x)
33 |
34 | model = Model(img_input, x)
35 |
36 | return model
37 |
38 | cfgs = {
39 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
40 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
41 | 'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
42 | }
43 |
44 | def vgg11(input_shape, classes):
45 | return VGG.build(input_shape=input_shape, num_classes=classes, cfg=cfgs['A'], batch_norm=False)
46 |
47 | def vgg11_bn(input_shape, classes):
48 | return VGG.build(input_shape=input_shape, num_classes=classes, cfg=cfgs['A'], batch_norm=True)
49 |
50 | def vgg13(input_shape, classes):
51 | return VGG.build(input_shape=input_shape, num_classes=classes, cfg=cfgs['B'], batch_norm=False)
52 |
53 | def vgg13_bn(input_shape, classes):
54 | return VGG.build(input_shape=input_shape, num_classes=classes, cfg=cfgs['B'], batch_norm=True)
55 |
56 | def vgg16(input_shape, classes):
57 | return VGG.build(input_shape=input_shape, num_classes=classes, cfg=cfgs['C'], batch_norm=False)
58 |
59 | def vgg16_bn(input_shape, classes):
60 | return VGG.build(input_shape=input_shape, num_classes=classes, cfg=cfgs['C'], batch_norm=True)
61 |
62 |
63 |
64 |
65 |
66 |
--------------------------------------------------------------------------------
/core/strategy_generation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | from collections import defaultdict
4 |
5 | from tensorflow.python.keras import layers
6 |
7 | def get_strategy_generation(model, min_rate, max_rate):
8 |
9 | np.random.seed(int(time.time()))
10 | channel_config = defaultdict()
11 | layer_info = defaultdict()
12 |
13 | #Random strategy
14 | for i, layer in enumerate(model.layers):
15 |
16 | layer_info[id(layer.output)] = {'layer_input':layer.input, 'idx':i}
17 | if isinstance(layer, layers.convolutional.Conv2D):
18 | random_rate = (max_rate - min_rate) * (np.random.rand(1)) + min_rate
19 | layer_info[id(layer.output)].update( {'prune_rate':random_rate})
20 | channel_config[i] = random_rate
21 |
22 |
23 | #In case of some architectures (such as resnet)
24 | for i, layer in enumerate(model.layers):
25 |
26 | if isinstance(layer, layers.Add) or \
27 | isinstance(layer, layers.Subtract) or \
28 | isinstance(layer, layers.Multiply) or \
29 | isinstance(layer, layers.Average) or \
30 | isinstance(layer, layers.Maximum) or \
31 | isinstance(layer, layers.Minimum):
32 |
33 | idx_list = list()
34 | prune_list = list()
35 | is_specific_layer = -1
36 |
37 |             for idx, in_layer in enumerate(layer.input):
38 |                 # Walk back to the nearest conv layer or merge layer (i.e. Add, Subtract, ...)
39 |                 while 'prune_rate' not in layer_info[id(in_layer)]:
40 |                     in_layer = layer_info[id(in_layer)]['layer_input']
41 | 
42 |                 # Find the merge layers mentioned in the comment above
43 |                 if layer_info[id(in_layer)]['idx'] not in channel_config:
44 |                     is_specific_layer = idx
45 |
46 | idx_list.append(layer_info[id(in_layer)]['idx'])
47 | prune_list.append(layer_info[id(in_layer)]['prune_rate'])
48 |
49 |             # If one of the inputs is such a merge layer, reuse its rate; otherwise take the minimum
50 | if is_specific_layer != -1:
51 | min_prune_rate = prune_list[is_specific_layer]
52 | else:
53 | min_prune_rate = min(prune_list)
54 |
55 | #Rearrange the prune rate for conv layers
56 | for idx in idx_list:
57 | if isinstance(model.layers[idx], layers.convolutional.Conv2D):
58 | channel_config[idx] = min_prune_rate
59 | layer_info[id(layer.output)]['prune_rate'] = min_prune_rate
60 |
61 | return channel_config
62 |
--------------------------------------------------------------------------------
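The second loop above is what keeps residual-style networks consistent: every conv feeding a merge layer (`Add`, `Subtract`, ...) is forced onto a single pruning rate, otherwise the merged tensors would end up with mismatched channel counts. A minimal sketch, assuming the repository root is on `sys.path`:

```python
from tensorflow.keras import layers, Model
from core.strategy_generation import get_strategy_generation

inp = layers.Input(shape=(32, 32, 3))
x = layers.Conv2D(16, 3, padding='same')(inp)
y = layers.Conv2D(16, 3, padding='same')(x)
out = layers.Add()([x, y])  # residual-style merge of the two conv outputs
model = Model(inp, out)

config = get_strategy_generation(model, min_rate=0.0, max_rate=0.5)
print(config)  # both conv layers (indices 1 and 2) share one random rate
```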
/network/densenet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import os
6 |
7 | import tensorflow as tf
8 | from tensorflow.keras import backend
9 | from tensorflow.keras import layers
10 | from tensorflow.keras.models import Model
11 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, ReLU, AveragePooling2D ,Add, ZeroPadding2D
12 | from tensorflow.keras.layers import Input, Flatten, BatchNormalization, Activation
13 |
14 |
15 | def dense_block(x, filter_num, blocks, name):
16 |
17 | for i in range(blocks):
18 | x = conv_block(x, filter_num, name=name + '_block' + str(i + 1))
19 | return x
20 |
21 |
22 | def transition_block(x, reduction, name):
23 |
24 | bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
25 | x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
26 | name=name + '_bn')(x)
27 | x = layers.Activation('relu', name=name + '_relu')(x)
28 | x = layers.Conv2D(int(backend.int_shape(x)[bn_axis] * reduction), 1,
29 | use_bias=False,
30 | padding='same',
31 | name=name + '_conv')(x)
32 | x = layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x)
33 | return x
34 |
35 |
36 | def conv_block(x, filter_num, name):
37 |
38 | bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
39 | x1 = layers.BatchNormalization(axis=bn_axis,
40 | epsilon=1.001e-5,
41 | name=name + '_0_bn')(x)
42 | x1 = layers.Activation('relu', name=name + '_0_relu')(x1)
43 | x1 = layers.Conv2D(filter_num, 3,
44 | use_bias=False,
45 | padding='same',
46 | name=name + '_1_conv')(x1)
47 | x = layers.Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])
48 | return x
49 |
50 |
51 | def DenseNet(blocks,
52 | filter_num=None,
53 | input_shape=None,
54 | reduce=None,
55 | classes=10,
56 | **kwargs):
57 |
58 |
59 | img_input = layers.Input(shape=input_shape)
60 |
61 |
62 | bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1
63 |
64 | x = layers.Conv2D(filter_num, 3, use_bias=False, padding='same',name='conv1/conv')(img_input)
65 |
66 | x = dense_block(x, filter_num, blocks[0], name='conv2')
67 | x = transition_block(x, reduce, name='pool2')
68 | x = dense_block(x, filter_num, blocks[1], name='conv3')
69 | x = transition_block(x, reduce, name='pool3')
70 | x = dense_block(x, filter_num, blocks[2], name='conv4')
71 |
72 |
73 | x = layers.BatchNormalization(
74 | axis=bn_axis, epsilon=1.001e-5, name='bn')(x)
75 | x = layers.Activation('relu', name='relu')(x)
76 |
77 | x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
78 | x = layers.Dense(classes, activation='softmax', name='fc1000')(x)
79 |
80 |
81 | # Create model.
82 | if blocks == [12, 12, 12] and filter_num == 12:
83 | model = Model(img_input, x, name='densenet40_f12')
84 | elif blocks == [32, 32, 32] and filter_num == 12 :
85 | model = Model(img_input, x, name='densenet100_f12')
86 | elif blocks == [32, 32, 32] and filter_num == 24:
87 | model = Model(img_input, x, name='densenet100_f24')
88 | else:
89 | model = Model(img_input, x, name='densenet')
90 |
91 | model.summary()
92 | return model
93 |
94 |
95 | def densenet40_f12(input_shape=None,classes=10):
96 | return DenseNet(blocks = [12, 12, 12],
97 | input_shape = input_shape,
98 | filter_num = 12,
99 | reduce = 1.0,
100 | classes = classes
101 | )
102 |
103 | def densenet100_f12(input_shape=None,classes=10):
104 | return DenseNet(blocks = [32, 32, 32],
105 | input_shape = input_shape,
106 | filter_num = 12,
107 | reduce = 0.5,
108 | classes = classes
109 | )
110 | def densenet100_f24(input_shape = None,classes=10):
111 | return DenseNet(blocks = [32, 32, 32],
112 | input_shape = input_shape,
113 | filter_num = 24,
114 | reduce = 0.5,
115 | classes = classes
116 | )
117 |
118 |
--------------------------------------------------------------------------------
/network/resnet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import six
3 |
4 | import tensorflow as tf
5 | from tensorflow.keras import backend
6 | from tensorflow.keras import layers
7 | from tensorflow.keras.models import Model
8 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, ReLU, AveragePooling2D ,Add, ZeroPadding2D
9 | from tensorflow.keras.layers import Input, Flatten, BatchNormalization, Activation
10 | from tensorflow.keras.regularizers import l2
11 |
12 | global ROW_AXIS
13 | global COL_AXIS
14 | global CHANNEL_AXIS
15 |
16 | ROW_AXIS=1
17 | COL_AXIS=2
18 | CHANNEL_AXIS=3
19 |
20 | def _bn_relu(input):
21 | """
22 | BN -> relu block
23 | """
24 | norm = BatchNormalization()(input)
25 | return Activation("relu")(norm)
26 |
27 | def _shortcut(input, residual,init_strides):
28 | """
29 | Adds a shortcut between input and residual block and merges them with "sum"
30 | """
31 |
32 | input_shape = np.shape(input)
33 | residual_shape = np.shape(residual)
34 | equal_channels = input_shape[CHANNEL_AXIS] == residual_shape[CHANNEL_AXIS]
35 |
36 | shortcut = input
37 | # 1 X 1 conv if shape is different. Else identity.
38 | if init_strides != 1 or not equal_channels:
39 | shortcut = Conv2D(filters=residual_shape[CHANNEL_AXIS],
40 | kernel_size=1,
41 | strides=init_strides,
42 | padding="same", kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(input)
43 |
44 |
45 | out=Add()([shortcut, residual])
46 |
47 | return Activation("relu")(out)
48 |
49 | def basic_block(filters, init_strides=1, is_first_block_of_first_layer=False):
50 | """
51 | Basic 3 X 3 convolution blocks for use on resnets with layers <= 34.
52 | """
53 | def f(input):
54 |
55 | out = Conv2D(filters=filters, kernel_size=3, strides=init_strides,padding="same",kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(input)
56 | out= _bn_relu(out)
57 |
58 | out = Conv2D(filters=filters, kernel_size=3, strides=1,padding="same", kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(out)
59 | residual = BatchNormalization()(out)
60 |
61 | return _shortcut(input, residual,init_strides)
62 |
63 | return f
64 |
65 |
66 | def bottleneck(filters, init_strides=1, is_first_block_of_first_layer=False):
67 | """
68 | Bottleneck architecture for > 34 layer resnet.
69 | A final conv layer of filters * 4
70 | """
71 | def f(input):
72 |
73 | out = Conv2D(filters=filters, kernel_size=1, strides=1,padding="same", kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(input)
74 | out = _bn_relu(out)
75 |
76 | out = Conv2D(filters=filters, kernel_size=3, strides=init_strides,padding="same", kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(out)
77 | out = _bn_relu(out)
78 |
79 | out = Conv2D(filters=filters*4, kernel_size=1, strides=1,padding="same", kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(out)
80 | residual = BatchNormalization()(out)
81 |
82 | return _shortcut(input, residual,init_strides)
83 |
84 | return f
85 |
86 |
87 | def _residual_block(block_function, filters, repetitions, is_first_layer=False):
88 | """
89 | Builds a residual block with repeating bottleneck blocks.
90 | """
91 | def f(input):
92 | for i in range(repetitions):
93 |
94 | init_strides = 1
95 | if i == 0 and not is_first_layer:
96 | init_strides = 2
97 | input = block_function(filters=filters, init_strides=init_strides,
98 | is_first_block_of_first_layer=(is_first_layer and i == 0))(input)
99 | return input
100 |
101 | return f
102 |
103 |
104 | def _get_block(identifier):
105 | if isinstance(identifier, six.string_types):
106 | res = globals().get(identifier)
107 | if not res:
108 | raise ValueError('Invalid {}'.format(identifier))
109 | return res
110 | return identifier
111 |
112 |
113 | class ResNet(object):
114 | def build(input_shape, num_outputs, block_fn, repetitions):
115 |
116 | # Load block function
117 | block_fn = _get_block(block_fn)
118 |
119 | filters_num = 16
120 |
121 | input = Input(shape=input_shape)
122 | conv1 = Conv2D(filters=filters_num, kernel_size=3, strides=1, padding='same', kernel_initializer='he_normal',kernel_regularizer=l2(1e-4))(input)
123 | norm1 = _bn_relu(conv1)
124 |
125 | block= norm1
126 | for i, r in enumerate(repetitions):
127 |
128 | block = _residual_block(block_fn, filters=filters_num, repetitions=r, is_first_layer=(i == 0))(block)
129 | filters_num*= 2
130 |
131 | block= _bn_relu(block)
132 | block_shape = np.shape(block)
133 | pool2 = AveragePooling2D(pool_size=block_shape[ROW_AXIS])(block)
134 | flatten1 = Flatten()(pool2)
135 | dense = Dense(units=num_outputs,kernel_initializer='he_normal')(flatten1)
136 | out = Activation("softmax")(dense)
137 | model = Model(inputs=input, outputs=out)
138 | return model
139 |
140 |
141 |
142 | def resnet18(input_shape, num_outputs):
143 | return ResNet.build(input_shape, num_outputs, basic_block, [2, 2, 2, 2])
144 |
145 | def resnet34(input_shape, num_outputs):
146 | return ResNet.build(input_shape, num_outputs, basic_block, [3, 4, 6, 3])
147 |
148 | def resnet50(input_shape, num_outputs):
149 | return ResNet.build(input_shape, num_outputs, bottleneck, [3, 4, 6, 3])
150 |
151 | def resnet101(input_shape, num_outputs):
152 | return ResNet.build(input_shape, num_outputs, bottleneck, [3, 4, 23, 3])
153 |
154 | def resnet152(input_shape, num_outputs):
155 | return ResNet.build(input_shape, num_outputs, bottleneck, [3, 8, 36, 3])
--------------------------------------------------------------------------------
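A quick sanity check for these builders (illustrative; assumes the repository root is on `sys.path`):

```python
from network.resnet import resnet18

# CIFAR-sized ResNet-18: 3x3 stem, widths 16/32/64/128 over four stages.
model = resnet18(input_shape=[32, 32, 3], num_outputs=10)
model.summary()  # roughly 0.7M parameters, matching the README's ResNet 18 table
```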
/network/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.keras import backend
3 | from tensorflow.keras import layers
4 | from tensorflow.keras.models import Model
5 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, ReLU, AveragePooling2D ,Add, DepthwiseConv2D
6 | from tensorflow.keras.layers import Input, Flatten, BatchNormalization, Activation, Dropout, GlobalAveragePooling2D
7 | from tensorflow.keras import regularizers
8 |
9 |
10 |
11 | def _make_divisible( v, divisor, min_value=None):
12 | if min_value is None:
13 | min_value = divisor
14 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
15 |
16 | # Make sure that round down does not go down by more than 10%.
17 | if new_v < 0.9 * v:
18 | new_v += divisor
19 | return new_v
20 |
21 | def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id):
22 | prefix = 'block_{}_'.format(block_id)
23 | channel_axis = 1 if backend.image_data_format() == 'channels_first' else -1
24 |
25 | in_channels = backend.int_shape(inputs)[channel_axis]
26 | pointwise_conv_filters = int(filters * alpha)
27 | pointwise_filters = _make_divisible(pointwise_conv_filters, 8)
28 | x = inputs
29 |
30 | # Expand
31 | if block_id:
32 | x = Conv2D(expansion * in_channels, kernel_size=1, strides=1, padding='same', use_bias=False, activation=None, kernel_initializer="he_normal", kernel_regularizer=regularizers.l2(4e-5), name=prefix + 'expand')(x)
33 | x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + 'expand_BN')(x)
34 | x = ReLU(6., name=prefix + 'expand_relu')(x)
35 | else:
36 | prefix = 'expanded_conv_'
37 |
38 | # Depthwise
39 | x = DepthwiseConv2D(kernel_size=3, strides=stride, activation=None, use_bias=False, padding='same', kernel_initializer="he_normal", depthwise_regularizer=regularizers.l2(4e-5), name=prefix + 'depthwise')(x)
40 | x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + 'depthwise_BN')(x)
41 | x = ReLU(6., name=prefix + 'depthwise_relu')(x)
42 |
43 | # Project
44 | x = Conv2D(pointwise_filters, kernel_size=1, strides=1, padding='same', use_bias=False, activation=None, kernel_initializer="he_normal", kernel_regularizer=regularizers.l2(4e-5), name=prefix + 'project')(x)
45 | x = BatchNormalization(epsilon=1e-3, momentum=0.999, name=prefix + 'project_BN')(x)
46 |
47 |
48 | if in_channels == pointwise_filters and stride == 1:
49 | return Add(name=prefix + 'add')([inputs, x])
50 | return x
51 |
52 | class MobileNetV2(object):
53 | def build(input_shape=None,alpha=1.0, classes=10,**kwargs):
54 | rows = input_shape[0]
55 | cols = input_shape[1]
56 | print(input_shape)
57 | img_input = layers.Input(shape=input_shape)
58 |
59 | first_block_filters =_make_divisible(32 * alpha, 8)
60 |
61 | x = Conv2D(first_block_filters, kernel_size=3, strides=1, padding='same', use_bias=False, kernel_initializer="he_normal", kernel_regularizer=regularizers.l2(4e-5), name='Conv1')(img_input)
62 | #x = BatchNormalization(epsilon=1e-3, momentum=0.999, name='bn_Conv1')(x)
63 | #x = ReLU(6., name='Conv1_relu')(x)
64 |
65 | x = _inverted_res_block(x, filters=16, alpha=alpha, stride=1, expansion=1, block_id=0 )
66 |
67 | x = _inverted_res_block(x, filters=24, alpha=alpha, stride=1, expansion=6, block_id=1 )
68 | x = _inverted_res_block(x, filters=24, alpha=alpha, stride=1, expansion=6, block_id=2 )
69 |
70 | x = _inverted_res_block(x, filters=32, alpha=alpha, stride=2, expansion=6, block_id=3 )
71 | x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1, expansion=6, block_id=4 )
72 | x = _inverted_res_block(x, filters=32, alpha=alpha, stride=1, expansion=6, block_id=5 )
73 |
74 | x = _inverted_res_block(x, filters=64, alpha=alpha, stride=2, expansion=6, block_id=6 )
75 | x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=7 )
76 | x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=8 )
77 | x = _inverted_res_block(x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=9 )
78 | x = Dropout(rate=0.25)(x)
79 |
80 | x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=10)
81 | x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=11)
82 | x = _inverted_res_block(x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=12)
83 | x = Dropout(rate=0.25)(x)
84 |
85 | x = _inverted_res_block(x, filters=160, alpha=alpha, stride=2, expansion=6, block_id=13)
86 | x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, expansion=6, block_id=14)
87 | x = _inverted_res_block(x, filters=160, alpha=alpha, stride=1, expansion=6, block_id=15)
88 | x = Dropout(rate=0.25)(x)
89 |
90 | x = _inverted_res_block(x, filters=320, alpha=alpha, stride=1, expansion=6, block_id=16)
91 | x = Dropout(rate=0.25)(x)
92 |
93 | if alpha > 1.0:
94 | last_block_filters = _make_divisible(1280 * alpha, 8)
95 | else:
96 | last_block_filters = 1280
97 |
98 | x = layers.Conv2D(last_block_filters,
99 | kernel_size=1,
100 | use_bias=False,
101 | name='Conv_1')(x)
102 | x = layers.BatchNormalization( epsilon=1e-3,
103 | momentum=0.999,
104 | name='Conv_1_bn')(x)
105 | x = layers.ReLU(6., name='out_relu')(x)
106 |
107 | x = layers.GlobalAveragePooling2D()(x)
108 | x = layers.Dense(classes, activation='softmax',
109 | use_bias=True, name='Logits')(x)
110 | # Create model.
111 | model = Model(img_input, x, name='mobilenetv2_%0.2f_%s' % (alpha, rows))
112 |
113 |
114 | return model
115 |
116 |
117 |
118 | def mobilenetv2(input_shape, classes):
119 | return MobileNetV2.build(input_shape=input_shape, classes=classes)
--------------------------------------------------------------------------------
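`_make_divisible` rounds a channel count to the nearest multiple of `divisor`, then bumps it up one step whenever rounding would lose more than 10% of the original value. A few hand-checked examples (illustrative only):

```python
from network.mobilenetv2 import _make_divisible

print(_make_divisible(32 * 0.5, 8))  # 16: already a multiple of 8
print(_make_divisible(10, 8))        # 16: nearest multiple is 8, but 8 < 0.9 * 10, so bump up
print(_make_divisible(18, 8))        # 24: 16 < 0.9 * 18 = 16.2, so bump up
```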
/train_network.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
3 | import numpy as np
4 | import argparse
5 |
6 | import tensorflow as tf
7 | from tensorflow.keras.preprocessing.image import ImageDataGenerator
8 | from tensorflow.keras.optimizers import Adam , SGD
9 |
10 | from network.utils import load_dataset, load_model_arch, lr_scheduler
11 | class TrainTeacher(object):
12 | def __init__(self,dataset_name, model_name, batch_size, epochs, lr, save_dir, data_augmentation, metrics='accuracy'):
13 | self.dataset_name = dataset_name
14 | self.batch_size = batch_size
15 | self.model_name = model_name
16 | self.data_augmentation = data_augmentation
17 | if not os.path.exists(save_dir):
18 | os.makedirs(save_dir)
19 | self.save_path = f'{save_dir}{self.dataset_name}_{self.model_name}.h5'
20 | self.epochs = epochs
21 | self.metrics = metrics
22 | self.optimizer = Adam(learning_rate=lr)
23 | #self.optimizer= SGD(lr=2e-2, momentum=0.9, decay=0.0, nesterov=False)
24 | self.loss = tf.keras.losses.CategoricalCrossentropy()
25 |
26 |
27 | #Load train and validation dataset
28 | self.x_train, self.y_train, self.x_test, self.y_test, self.num_classes , self.img_shape = load_dataset(self.dataset_name, self.batch_size)
29 |
30 | #Load model architecture
31 | self.model = load_model_arch(self.model_name,self.img_shape,self.num_classes)
32 |
33 |
34 | #Define callback function
35 | def get_callback_list(self,early_stop=True, lr_reducer=True):
36 |
37 | callback_list = list()
38 |
39 | if early_stop == True:
40 | callback_list.append(tf.keras.callbacks.EarlyStopping(min_delta=0, patience=20, verbose=2, mode='auto'))
41 | if lr_reducer == True:
42 | callback_list.append(tf.keras.callbacks.ReduceLROnPlateau(factor=0.1, cooldown=0, patience=10, min_lr=0.5e-6))
43 |
44 |
45 | return callback_list
46 |
47 | def train(self):
48 | self.model.compile(loss= self.loss, optimizer=self.optimizer, metrics=[self.metrics])
49 | callback_list = self.get_callback_list()
50 |
51 | if self.data_augmentation == True:
52 |
53 | datagen = ImageDataGenerator( featurewise_center=False, # set input mean to 0 over the dataset
54 | samplewise_center=False, # set each sample mean to 0
55 | featurewise_std_normalization=False, # divide inputs by std of the dataset
56 | samplewise_std_normalization=False, # divide each input by its std
57 | zca_whitening=False, # apply ZCA whitening
58 | rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180)
59 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width)
60 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height)
61 | horizontal_flip=True, # randomly flip images
62 | vertical_flip=False) # randomly flip images
63 |
64 | datagen.fit(self.x_train)
65 |
66 | self.model.fit(datagen.flow(self.x_train, self.y_train, batch_size=self.batch_size), \
67 | steps_per_epoch=len(self.x_train) / self.batch_size,epochs=self.epochs, \
68 | validation_data=(self.x_test,self.y_test), callbacks=callback_list)
69 |
70 | else:
71 | self.model.fit(self.x_train, self.y_train, batch_size=self.batch_size, epochs=self.epochs, \
72 | validation_data=(self.x_test,self.y_test), callbacks=callback_list)
73 |
74 |
75 | def test(self):
76 |         self.model.compile(optimizer=self.optimizer, loss=self.loss, metrics=[self.metrics])
77 | self.val_dataset = tf.data.Dataset.from_tensor_slices((self.x_test, self.y_test)).batch(self.batch_size)
78 | scores = self.model.evaluate(self.val_dataset,batch_size=self.batch_size)
79 | print('Test loss:', scores[0])
80 | print('Test accuracy:', scores[1])
81 |
82 | #save model
83 | tf.keras.models.save_model(
84 | model=self.model,
85 | filepath=self.save_path,
86 | include_optimizer=False
87 | )
88 |
89 | def build(self):
90 | print(f'loading {self.model_name}..\n')
91 | self.train()
92 | self.test()
93 |
94 |
95 | def main():
96 |
97 | parser = argparse.ArgumentParser()
98 |     parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['cifar10', 'cifar100'], help='Dataset [ "cifar10", "cifar100" ]')
99 | parser.add_argument('--model_name', type=str, default='mobilenetv2', help='Model name')
100 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
101 | parser.add_argument('--epochs', type=int, default=200, help='Epochs')
102 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
103 | parser.add_argument('--save_dir', type=str, default='./saved_models/', help='Saved model path')
104 |     parser.add_argument('--data_augmentation', type=bool, default=True, help='do data augmentation?')
105 |
106 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
107 | os.environ['TF2_BEHAVIOR'] = '1'
108 | os.environ['CUDA_VISIBLE_DEVICES']= '0'
109 |
110 | config = tf.compat.v1.ConfigProto()
111 | config.gpu_options.per_process_gpu_memory_fraction = 0.8
112 | tf.compat.v1.keras.backend.set_session(tf.compat.v1.Session(config=config))
113 |
114 | args = parser.parse_args()
115 |
116 | trainer=TrainTeacher(dataset_name=args.dataset_name,
117 | model_name=args.model_name,
118 | batch_size=args.batch_size,
119 | epochs=args.epochs,
120 | lr=args.lr,
121 | save_dir=args.save_dir,
122 | data_augmentation=args.data_augmentation,
123 | metrics='accuracy'
124 | )
125 |
126 | trainer.build()
127 |
128 | if __name__ == '__main__':
129 |
130 | main()
--------------------------------------------------------------------------------
/core/eagleeye.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 |
4 | import tensorflow as tf
5 | import tensorflow.keras as k
6 | import tensorflow.keras.backend as b
7 | from tensorflow.python.keras.utils.layer_utils import count_params
8 | from tensorflow.keras.optimizers import Adam
9 | from tensorflow.keras.preprocessing.image import ImageDataGenerator
10 |
11 | from core.utils import load_network ,count_flops
12 | from network.utils import load_dataset, lr_scheduler
13 | from core.strategy_generation import get_strategy_generation
14 | from core.pruner.l1norm import l1_pruning
15 |
16 | class EagleEye(object):
17 | def __init__(self,
18 | dataset_name,
19 | model_path,
20 | bs,
21 | epochs,
22 | lr,
23 | min_rate,
24 | max_rate,
25 | flops_target,
26 | num_candidates,
27 | result_dir,
28 | data_augmentation):
29 |
30 | self.dataset_name = dataset_name
31 | self.bs = bs
32 | self.epochs = epochs
33 | self.lr = lr
34 | self.min_rate = min_rate
35 | self.max_rate = max_rate
36 | self.flops_target = flops_target
37 | self.result_dir = result_dir
38 | if not os.path.exists(self.result_dir):
39 | os.makedirs(self.result_dir)
40 | self.model_path = model_path
41 | self.num_candidates = num_candidates
42 | self.data_augmentation = data_augmentation
43 |
44 | #Load a base model
45 | self.net = load_network(model_path)
46 |
47 | #Load a dataset
48 | self.x_train, self.y_train, self.x_test, self.y_test, self.num_classes, self.img_shape = load_dataset(self.dataset_name, self.bs)
49 |
50 | def build(self):
51 |
52 | pruned_model_list = list()
53 | val_acc_list = list()
54 | for i in range(self.num_candidates):
55 | # Pruning strategy
56 |             channel_config = get_strategy_generation(self.net, self.min_rate, self.max_rate)
57 |
58 | #Get pruned model using l1 pruning
59 | pruned_model = l1_pruning(self.net, channel_config)
60 | pruned_model_list.append(pruned_model)
61 |
62 |             # Adaptive BN statistics
63 | slice_idx = int(np.shape(self.x_train)[0]/30)
64 | sliced_x_train = self.x_train[:slice_idx, :, :, :]
65 | sliced_y_train = self.y_train[:slice_idx, :]
66 |
67 | sliced_train_dataset = tf.data.Dataset.from_tensor_slices((sliced_x_train, sliced_y_train)).batch(64)
68 |
69 | max_iters=10
70 | for j in range(max_iters):
71 | for x_batch, y_batch in sliced_train_dataset:
72 | output = pruned_model(x_batch, training=True)
73 |
74 |
75 |             #Evaluate top-1 accuracy for the pruned model
76 |             small_val_dataset = tf.data.Dataset.from_tensor_slices((sliced_x_train, sliced_y_train)).batch(64)
77 |             small_val_acc = k.metrics.CategoricalAccuracy()
78 |             for x_batch, y_batch in small_val_dataset:
79 | output = pruned_model(x_batch)
80 |
81 | small_val_acc.update_state(y_batch, output)
82 |
83 | small_val_acc = small_val_acc.result().numpy()
84 |             print(f'Adaptive-BN-based accuracy for {i}-th pruned model: {small_val_acc}')
85 |
86 | val_acc_list.append(small_val_acc)
87 |
88 | #Select the best candidate model
89 | val_acc_np = np.array(val_acc_list)
90 | best_candidate_idx = np.argmax(val_acc_np)
91 | best_model = pruned_model_list[best_candidate_idx]
92 |         print(f'\n The best candidate is {best_candidate_idx}-th pruned model (Acc: {val_acc_np[best_candidate_idx]})')
93 |
94 | #Fine tuning
95 | metrics = 'accuracy'
96 | optimizer = Adam(learning_rate=self.lr)
97 | loss = tf.keras.losses.CategoricalCrossentropy()
98 |
99 |
100 | def get_callback_list(save_path, early_stop=True, lr_reducer=True):
101 | callback_list=list()
102 |
103 | if lr_reducer == True:
104 |                 callback_list.append(tf.keras.callbacks.ReduceLROnPlateau(factor=0.1, cooldown=0, patience=10, mode='auto', min_delta=0.0001, min_lr=0))
105 | if early_stop == True:
106 | callback_list.append(tf.keras.callbacks.EarlyStopping(min_delta=0, patience=20, verbose=2, mode='auto'))
107 |
108 | return callback_list
109 |
110 |
111 | best_model.compile(loss=loss, optimizer=optimizer, metrics=[metrics])
112 | self.net.compile(loss= loss, optimizer=optimizer, metrics=[metrics])
113 | callback_list = get_callback_list(self.result_dir)
114 |
115 | if self.data_augmentation == True:
116 |
117 | datagen = ImageDataGenerator( featurewise_center=False, # set input mean to 0 over the dataset
118 | samplewise_center=False, # set each sample mean to 0
119 | featurewise_std_normalization=False, # divide inputs by std of the dataset
120 | samplewise_std_normalization=False, # divide each input by its std
121 | zca_whitening=False, # apply ZCA whitening
122 | rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180)
123 | width_shift_range=0.1, # randomly shift images horizontally (fraction of total width)
124 | height_shift_range=0.1, # randomly shift images vertically (fraction of total height)
125 | horizontal_flip=True, # randomly flip images
126 | vertical_flip=False) # randomly flip images
127 |
128 | datagen.fit(self.x_train)
129 |
130 | best_model.fit(datagen.flow(self.x_train, self.y_train, batch_size=self.bs), \
131 | steps_per_epoch=len(self.x_train) / self.bs,epochs=self.epochs, \
132 | validation_data=(self.x_test,self.y_test), callbacks=callback_list)
133 |
134 | else:
135 | best_model.fit(self.x_train, self.y_train, batch_size=self.bs, epochs=self.epochs,\
136 | validation_data=(self.x_test,self.y_test), callbacks=callback_list)
137 |
138 |
139 |
140 | #Get flops and parameters of the base model and pruned model
141 | params_prev = count_params(self.net.trainable_weights)
142 | flops_prev = count_flops(self.net)
143 | scores_prev = self.net.evaluate(self.x_test, self.y_test, batch_size=self.bs, verbose=0)
144 | print(f'\nTest loss (on base model): {scores_prev[0]}')
145 | print(f'Test accuracy (on base model): {scores_prev[1]}')
146 | print(f'The number of parameters (on base model): {params_prev}')
147 | print(f'The number of flops (on base model): {flops_prev}')
148 | params_after = count_params(best_model.trainable_weights)
149 | flops_after = count_flops(best_model)
150 | scores_after = best_model.evaluate(self.x_test, self.y_test, batch_size=self.bs, verbose=0)
151 | print(f'\nTest loss (on pruned model): {scores_after[0]}')
152 | print(f'Test accuracy (on pruned model): {scores_after[1]}')
153 | print(f'The number of parameters (on pruned model): {params_after}')
154 |         print(f'The number of flops (on pruned model): {flops_after}')
155 |
156 | #save the best candidate model
157 | slash_idx = self.model_path.rfind('/')
158 | ext_idx = self.model_path.rfind('.')
159 | save_name = self.model_path[slash_idx:ext_idx]
160 | tf.keras.models.save_model(
161 | model=best_model,
162 | filepath=self.result_dir+save_name+'_pruned.h5',
163 | include_optimizer=False
164 | )
165 |
166 |
167 |
168 |
--------------------------------------------------------------------------------
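The candidate loop above is the heart of EagleEye's adaptive-BN evaluation: a pruned sub-net inherits BatchNormalization moving statistics from the full model, so plain inference scores it unfairly; a few forward passes in training mode recalibrate `moving_mean`/`moving_variance` (no gradient step runs, so the weights stay fixed) before accuracy is measured. A condensed standalone sketch of the same procedure, with a hypothetical helper name:

```python
import tensorflow as tf

def adaptive_bn_eval(model, x_small, y_small, iters=10, bs=64):
    """Recalibrate BN statistics on a small data slice, then score top-1 accuracy."""
    ds = tf.data.Dataset.from_tensor_slices((x_small, y_small)).batch(bs)

    # Forward passes with training=True update BN moving statistics only;
    # no optimizer is involved, so the learned weights are untouched.
    for _ in range(iters):
        for x_batch, _ in ds:
            model(x_batch, training=True)

    acc = tf.keras.metrics.CategoricalAccuracy()
    for x_batch, y_batch in ds:
        acc.update_state(y_batch, model(x_batch, training=False))
    return acc.result().numpy()
```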
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # EagleEye: Fast Sub-net Evaluation for Efficient Neural Network Pruning
3 |
4 |
5 | 
6 | 
7 |
8 | :star: Star us on GitHub — it helps!!
9 |
10 |
11 | TensorFlow Keras implementation of *[EagleEye: Fast Sub-net Evaluation for Efficient Neural Network Pruning](https://arxiv.org/abs/2007.02491)*
12 |
13 | ## Install
14 |
15 | You will need a machine with a GPU and CUDA installed.
16 | Then, prepare the runtime environment:
17 |
18 | ```shell
19 | pip install -r requirements.txt
20 | ```
21 |
22 | ## Use
23 |
24 | ### Train base model
25 |
26 | If you want to train a network yourself:
27 |
28 | ```shell
29 | python train_network.py --dataset_name=cifar10 --model_name=resnet34
30 | ```
31 |
32 | Arguments:
33 |
34 | - `dataset_name` - Select a dataset ['cifar10' or 'cifar100']
35 | - `model_name` - Trainable network names
36 | - Available list
37 | - VGG: ['vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn']
38 | - ResNet: ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
39 | - MobileNet: ['mobilenetv2']
40 |     - DenseNet: ['densenet40_f12', 'densenet100_f12', 'densenet100_f24']
41 | - `batch_size` - Batch size
42 | - `epochs` - The number of epochs
43 | - `lr` - Learning rate
44 | - Default is 0.001 (set 0.01 when training mobilenetv2)
45 |
46 | If you want to use a pre-trained model instead, download it via a link in the result section below.
47 | Put the downloaded models in the `./saved_models/` directory.
48 |
49 |
50 | ### Prune the model using EagleEye!!
51 |
52 | Next, you can prune the trained model as follows:
53 |
54 | ```shell
55 | python main.py --dataset_name=cifar10 --model_path=./saved_models/cifar10_resnet34.h5 --epochs=100 --min_rate=0.0 --max_rate=0.5 --num_candidates=15
56 | ```
57 |
58 | Arguments:
59 |
60 | - `dataset_name` - Select a dataset ['cifar10' or 'cifar100']
61 | - `model_path` - Model path
62 | - `bs` - Batch size
63 | - `epochs` - The number of epochs
64 | - `lr` - Learning rate
65 | - `min_rate` - Minimum rate of search space
66 | - `max_rate` - Maximum rate of search space
67 | - `num_candidates` - The number of candidates
68 |
69 |
70 | ## Result
71 |
72 | The progress output looks like the following:
73 | ```
74 | Adaptive-BN-based accuracy for 0-th pruned model: 0.08163265138864517
75 | Adaptive-BN-based accuracy for 1-th pruned model: 0.20527011036872864
76 | Adaptive-BN-based accuracy for 2-th pruned model: 0.10084033757448196
77 | ...
78 | Adaptive-BN-based accuracy for 13-th pruned model: 0.10804321616888046
79 | Adaptive-BN-based accuracy for 14-th pruned model: 0.11284513771533966
80 | 
81 | The best candidate is 1-th pruned model (Acc: 0.20527011036872864)
82 |
83 | Epoch 1/200
84 | 196/195 [==============================] - 25s 125ms/step - loss: 1.7189 - accuracy: 0.4542 - val_loss: 1.8755 - val_accuracy: 0.4265
85 | ...
86 |
87 | Test loss (on base model): 0.3912913203239441
88 | Test accuracy (on base model): 0.9172999858856201
89 | The number of parameters (on base model): 33646666
90 | The number of flops (on base model): 664633404
91 |
92 | Test loss (on pruned model): 0.520440936088562
93 | Test accuracy (on pruned model): 0.9136999845504761
94 | The number of parameters (on pruned model): 27879963
95 | The number of flops (on pruned model): 401208200
96 | ```
97 |
98 | Then, the pruned model will be saved in the `./result/` (default) folder.
99 | An example result before and after pruning:
100 |
101 | 
102 | 
103 | 
104 | ### ResNet 34 on cifar10
105 |
106 | |Model|Acc|[min_rate, max_rate]|Flops|Param num|File size|Download|
107 | |-----|---|--------------------|-----|---------|---------|--------|
108 | |Original|90.11%|None|145M|1.33M|5.44MB|[cifar10_resnet34.h5](https://drive.google.com/file/d/1SJS61fUh_GsnlBI3WB__JtdO70Zfj7Ms/view?usp=sharing)|
109 | |Pruned|90.50%|[0, 0.5]|106M|0.86M|3.64MB|[cifar10_resnet34_pruned0.5.h5](https://drive.google.com/file/d/1VRTAiIvF7B7-AejxLunaCtOWFTan5p0U/view?usp=sharing)|
110 | |Pruned|89.02%|[0, 0.7]|69M|0.55M|2.60MB|[cifar10_resnet34_pruned0.7.h5](https://drive.google.com/file/d/1z77mbXxagEyc9TKxDXISgNmQ-74Tkc0_/view?usp=sharing)|
111 | |Pruned|89.30%|[0, 0.9]|62M|0.54M|2.52MB|[cifar10_resnet34_pruned0.9.h5](https://drive.google.com/file/d/1uBJoaovFkEwSbaF_AggxUP70MShPOJSc/view?usp=sharing)|
112 |
113 | ### ResNet 18 on cifar10
114 |
115 | |Model|Acc|[min_rate, max_rate]|Flops|Param num|File size|Download|
116 | |-----|---|--------------------|-----|---------|---------|--------|
117 | |Original|88.30%|None|70.2M|0.70M|2.87MB|[cifar10_resnet18.h5](https://drive.google.com/file/d/1fu_DlI-YLm3IunHFmq-UFi-p4ecXBUq6/view?usp=sharing)|
118 | |Pruned|88.02%|[0, 0.5]|50.9M|0.47M|2.01MB|[cifar10_resnet18_pruned0.5.h5](https://drive.google.com/file/d/187fYZvDfHj8w_YM_JguPyet0SsPD8GXg/view?usp=sharing)|
119 | |Pruned|87.09%|[0, 0.7]|33.1M|0.26M|1.20MB|[cifar10_resnet18_pruned0.7.h5](https://drive.google.com/file/d/18nrWoX-1TKtLFbiedOkJywHxUS6PGpfo/view?usp=sharing)|
120 | |Pruned|85.65%|[0, 0.9]|21.4M|0.08M|0.51MB|[cifar10_resnet18_pruned0.9.h5](https://drive.google.com/file/d/1JjoqWf3motY_swaQ4rBIQMRTteJo2zsL/view?usp=sharing)|
121 |
122 |
123 | ### VGG 16_bn on cifar10
124 |
125 | |Model|Acc|[min_rate, max_rate]|Flops|Param num|File size|Download|
126 | |-----|---|--------------------|-----|---------|---------|--------|
127 | |Original|91.99%|None|664M|33.64M|128MB|[cifar10_vgg16_bn.h5](https://drive.google.com/file/d/1YZJ-I5V2RZOwMYwwY5T7VhQd48wPsTQb/view?usp=sharing)|
128 | |Pruned|91.44%|[0, 0.5]|394M|26.27M|100MB|[cifar10_vgg16_bn_pruned0.5.h5](https://drive.google.com/file/d/1T_dHGEJphXzufahiimdSH6U-CCaNV9po/view?usp=sharing)|
129 | |Pruned|90.81%|[0, 0.7]|298M|25.18M|96.2MB|[cifar10_vgg16_bn_pruned0.7.h5](https://drive.google.com/file/d/1u15n_rRYd-ARaIF-IPP57MAFyoMfU7Q7/view?usp=sharing)|
130 | |Pruned|90.42%|[0, 0.9]|293M|21.9M|83.86MB|[cifar10_vgg16_bn_pruned0.9.h5](https://drive.google.com/file/d/1Hm1Ui068Cd7fdVg_-VytZ7sstTMoP9MK/view?usp=sharing)|
131 |
132 | ### MobileNetv2 on cifar10
133 |
134 | |Model|Acc|[min_rate, max_rate]|Flops|Param num|File size|Download|
135 | |-----|---|--------------------|-----|---------|---------|--------|
136 | |Original|90.14%|None|175M|2.23M|9.10MB|[cifar10_mobilenetv2_bn.h5](https://drive.google.com/file/d/1zVtZ2jwNTyn3Bo0D8V0xfRPpEjWZg09s/view?usp=sharing)|
137 | |Pruned|90.57%|[0, 0.5]|126M|1.62M|6.72MB|[cifar10_mobilenetv2_pruned0.5.h5](https://drive.google.com/file/d/1MLg909rak-78fZ5gHOsLMFBy2deXzZkJ/view?usp=sharing)|
138 | |Pruned|89.13%|[0, 0.7]|81M|0.90M|3.96MB|[cifar10_mobilenetv2_pruned0.7.h5](https://drive.google.com/file/d/1CDYObh6cJUIRS0Gc4m6BmVlQejcVs8We/view?usp=sharing)|
139 | |Pruned|90.54%|[0, 0.9]|70M|0.62M|2.92MB|[cifar10_mobilenetv2_pruned0.9.h5](https://drive.google.com/file/d/1GjUXOepwrMkBOYm5c239wJP64sHLrb-8/view?usp=sharing)|
140 |
141 | ### DenseNet40 on cifar10
142 |
143 | |Model|Acc|[min_rate, max_rate]|Flops|Param num|File size|Download|
144 | |-----|---|--------------------|-----|---------|---------|--------|
145 | |Original|91.32%|None|51M|1.00M|4.27MB|[cifar10_densenet40_bn.h5](https://drive.google.com/file/d/1gHUozIUE9pLGMekbP9lM5oqcw_Q_rgYX/view?usp=sharing)|
146 | |Pruned|91.13%|[0, 0.5]|32M|0.68M|3.07MB|[cifar10_densenet40_pruned0.5.h5](https://drive.google.com/file/d/1GCC3KSiJsHjUjQsO2ZL2L0HjUvBYP4F9/view?usp=sharing)|
147 | |Pruned|90.95%|[0, 0.7]|24M|0.54M|2.49MB|[cifar10_densenet40_pruned0.7.h5](https://drive.google.com/file/d/1F-5pk36kTMSqgs8zqaWWl8mC06AG5D_Z/view?usp=sharing)|
148 | |Pruned|90.31%|[0, 0.9]|20M|0.41M|2.00MB|[cifar10_densenet40_pruned0.9.h5](https://drive.google.com/file/d/1LLH8s5SAo3kyoQ556W_ePDDTuh8DomMO/view?usp=sharing)|
149 |
150 | ### DenseNet100 on cifar10
151 |
152 | |Model|Acc|[min_rate, max_rate]|Flops|Param num|File size|Download|
153 | |-----|---|--------------------|-----|---------|---------|--------|
154 | |Original|93.26%|None|2540M|3.98M|16.4MB|[cifar10_densenet100_bn.h5](https://drive.google.com/file/d/1GmCfYv32MnoWvSyoYc7M26AhNGc2xRSI/view?usp=sharing)|
155 | |Pruned|93.12%|[0, 0.5]|1703M|2.74M|11.6MB|[cifar10_densenet100_pruned0.5.h5](https://drive.google.com/file/d/1RLNUeS4DKdSAGf1iUaY9FoDn1LEymdUJ/view?usp=sharing)|
156 | |Pruned|93.12%|[0, 0.7]|1467M|2.41M|10.4MB|[cifar10_densenet100_pruned0.7.h5](https://drive.google.com/file/d/17KCmlrBfuecI0Qke45W78wCF6WnauVkp/view?usp=sharing)|
157 | |Pruned|93.00%|[0, 0.9]|1000M|1.65M|7.44MB|[cifar10_densenet100_pruned0.9.h5](https://drive.google.com/file/d/1q5yvwSd146wYm2hnTWfTac3mQ-F1hzhe/view?usp=sharing)|
158 |
159 |
160 | :no_entry: If you run this code, the results may be a bit different from mine.
161 |
162 |
163 | ## Understanding this paper
164 |
165 | :white_check_mark: Check my blog!!
166 | [Here](https://da2so.github.io/2020-10-25-EagleEye_Fast_Sub_net_Evaluation_for_Efficient_Neur_Network_Pruning/)
--------------------------------------------------------------------------------
/core/pruner/graph_wrapper.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import sys
3 | import numpy as np
4 | import inspect
5 |
6 | from tensorflow.python.keras import layers
7 | from tensorflow.keras.models import Model
8 | import tensorflow as tf
9 |
10 | class GraphWrapper(object):
11 | def __init__(self, model):
12 | self.model = model
13 | self.layer_info = defaultdict()
14 | self.input_shape = list(model.layers[0].input.shape[1:])
15 |
16 | #construct
17 | for idx, layer in enumerate(self.model.layers):
18 | if idx == 0:
19 | self.layer_info[id(layer.output)] = {'count': 0}
20 | continue
21 |
22 | self.layer_info[id(layer.output)] = {'layer_input': layer.input, 'count': 0}
23 |
24 | if isinstance(layer.input, list):
25 | for in_layer in layer.input:
26 | self.layer_info[id(in_layer)]['count'] += 1
27 | else:
28 | self.layer_info[id(layer.input)]['count'] += 1
29 |
30 |         #Check if a conv layer determines the number of classes (output classes)
31 | for idx, layer in enumerate(reversed(self.model.layers)):
32 |
33 | if type(layer) is layers.Dense:
34 | self.is_lastConv2D = False
35 | break
36 | elif type(layer) is layers.Conv2D:
37 | self.is_lastConv2D = True
38 | self.lastConv2D = len(self.model.layers)-(idx+1)
39 | break
40 | else:
41 | continue
42 |
43 |
44 | def build(self,del_layer_dict):
45 | prev_prune_idx = None
46 | first_dense = True
47 | #output processing for keeping the number of classes
48 | if self.is_lastConv2D:
49 | del_layer_dict[self.lastConv2D] = []
50 |
51 | #initialize input
52 | input = layers.Input(shape=self.input_shape)
53 | x = input
54 |
55 | for idx, layer in enumerate(self.model.layers):
56 | if idx == 0:
57 | self.layer_info[id(layer.output)].update({'out': x})
58 | continue
59 |
60 | # reconstruct each layer
61 | if type(layer) is layers.convolutional.Conv2D:
62 | prev_prune_idx=self.get_prev_pruned_idx(layer.input, self.layer_info)
63 |
64 | pruned_weights = self.get_conv_pruned_weights(layer, del_layer_dict[idx], prev_prune_idx)
65 |                 cut_channel_num = len(del_layer_dict[idx])
66 | 
67 |                 attr_name = inspect.getargspec(layers.convolutional.Conv2D.__init__)[0]
68 |                 attr_name.pop(0)
69 |                 attr_value = {attr: getattr(layer, attr) for attr in attr_name}
70 |                 attr_value['filters'] -= cut_channel_num
71 |                 recon_layer = layers.convolutional.Conv2D(**attr_value)
72 |
73 | self.layer_info[id(layer.output)].update({'pruned_idx': del_layer_dict[idx]})
74 |
75 | elif type(layer) is layers.convolutional.DepthwiseConv2D:
76 | prev_prune_idx = self.get_prev_pruned_idx(layer.input, self.layer_info)
77 |
78 | pruned_weights = self.get_depthwiseconv_pruned_weights(layer, prev_prune_idx)
79 |
80 | recon_layer = layer.__class__.from_config(layer.get_config())
81 |
82 |
83 | elif type(layer) is layers.normalization_v2.BatchNormalization:
84 | prev_prune_idx = self.get_prev_pruned_idx(layer.input, self.layer_info)
85 |
86 | pruned_weights = self.get_batchnorm_pruned_weights(layer, prev_prune_idx)
87 |
88 |                 cut_channel_num = len(prev_prune_idx)
89 | config = layer.get_config()
90 | config['gamma_regularizer'] = None
91 | recon_layer = layers.normalization_v2.BatchNormalization.from_config(config)
92 |
93 |
94 | elif type(layer) is layers.Dense:
95 | if first_dense == True:
96 | prev_prune_idx = self.get_prev_pruned_idx(layer.input, self.layer_info)
97 |
98 | pruned_weights = self.get_dense_pruned_weights(layer, prev_prune_idx)
99 |
100 |
101 | first_dense = False
102 | recon_layer = layer.__class__.from_config(layer.get_config())
103 | elif type(layer) is layers.Reshape:
104 | prev_prune_idx = self.get_prev_pruned_idx(layer.input, self.layer_info)
105 |
106 | prev_prune_num = len( prev_prune_idx)
107 |
108 | config = layer.get_config()
109 | original_shape = config['target_shape']
110 | original_shape = list(original_shape)
111 | original_shape[-1] = original_shape[-1]-prev_prune_num
112 | new_shape = tuple(original_shape)
113 | config['target_shape'] = new_shape
114 |
115 | recon_layer = layer.__class__.from_config(config)
116 | else:
117 |
118 | if type(layer) is layers.Add or \
119 | type(layer) is layers.Subtract or \
120 | type(layer) is layers.Multiply or \
121 | type(layer) is layers.Average or \
122 | type(layer) is layers.Maximum or \
123 | type(layer) is layers.Minimum:
124 | pruned_idx = self.set_pruned_idx(layer.input, self.layer_info)
125 | self.layer_info[id(layer.output)].update({'pruned_idx': pruned_idx})
126 |
127 |
128 | if type(layer) is layers.Concatenate:
129 | pruned_idx = self.set_pruned_idx_for_concat(layer.input, self.layer_info)
130 | self.layer_info[id(layer.output)].update({'pruned_idx': pruned_idx})
131 |
132 | recon_layer = layer.__class__.from_config(layer.get_config())
133 |
134 |
135 | # connect layers
136 | if isinstance(layer.input, list):
137 | input_list = []
138 | for in_layer in layer.input:
139 | input_list.append(self.layer_info[id(in_layer)]['out'])
140 | self.layer_info[id(in_layer)]['count'] -= 1
141 | self.del_key(in_layer, self.layer_info)
142 | x = input_list
143 | else:
144 | x = self.layer_info[id(layer.input)]['out']
145 | self.layer_info[id(layer.input)]['count'] -= 1
146 | self.del_key(layer.input, self.layer_info)
147 |
148 | x = recon_layer(x)
149 | self.layer_info[id(layer.output)]['out'] = x
150 |
151 | try:
152 | recon_layer.set_weights(pruned_weights)
153 | except:
154 | pass
155 |
156 | model = Model(inputs=input, outputs=x)
157 |
158 | return model
159 |
160 |
161 | def get_conv_pruned_weights(self, layer, prune_idx, prev_pruned_idx=None):
162 |
163 | weights = layer.get_weights()
164 |
165 | #prune kernel
166 | kernel = weights[0]
167 | pruned_kernel= np.delete(kernel, prune_idx, axis=-1)
168 |
169 | try:
170 | pruned_kernel = np.delete(pruned_kernel, prev_pruned_idx, axis=-2)
171 | except:
172 | pass
173 |
174 | #prune bias
175 | prunned_bias = None
176 | if layer.use_bias:
177 | bias = weights[1]
178 | prunned_bias = np.delete(bias, prune_idx)
179 |
180 | if layer.use_bias == True:
181 | return [pruned_kernel, prunned_bias]
182 | else:
183 | return [pruned_kernel]
184 |
185 | def get_dense_pruned_weights(self, layer, prev_pruned_idx=None):
186 | weights = layer.get_weights()
187 | kernel = weights[0]
188 | pruned_kernel = np.delete(kernel, prev_pruned_idx, axis=-2)
189 |
190 | bias = weights[1]
191 |
192 | return [pruned_kernel, bias]
193 |
194 |
195 | def get_depthwiseconv_pruned_weights(self, layer, prev_pruned_idx=None):
196 | weights = layer.get_weights()
197 | kernel = weights[0]
198 | pruned_kernel = np.delete(kernel, prev_pruned_idx, axis=-2)
199 |
200 | #prune bias
201 | prunned_bias = None
202 | if layer.use_bias:
203 | bias = weights[1]
204 |
205 | if layer.use_bias == True:
206 | return [pruned_kernel, bias]
207 | else:
208 | return [pruned_kernel]
209 |
210 | def get_batchnorm_pruned_weights(self, layer, prune_idx):
211 | weights = layer.get_weights()
212 | pruned_weights = [np.delete(w, prune_idx) for w in weights]
213 |
214 | return pruned_weights
215 |
216 | def get_prev_pruned_idx(self, in_layer, layer_info):
217 |
218 | while 1:
219 | if 'layer_input' not in layer_info[id(in_layer)]:
220 | return None
221 | if 'pruned_idx' not in layer_info[id(in_layer)]:
222 | in_layer = layer_info[id(in_layer)]['layer_input']
223 | else:
224 | prev_pruned_idx = layer_info[id(in_layer)]['pruned_idx']
225 | return prev_pruned_idx
226 |
227 | def set_pruned_idx_for_concat(self, layer_list, layer_info):
228 | pruned_idx_set = set()
229 | is_second_input = False
230 | for in_layer in layer_list:
231 |
232 | while 1:
233 | if 'pruned_idx' not in layer_info[id(in_layer)]:
234 | in_layer = layer_info[id(in_layer)]['layer_input']
235 | else:
236 | if is_second_input == False:
237 | prev_layer_len = in_layer.shape[3]
238 | pruned_idx_set.update(layer_info[id(in_layer)]['pruned_idx'])
239 | is_second_input = True
240 | else:
241 | pruned_idx = layer_info[id(in_layer)]['pruned_idx']+prev_layer_len
242 | pruned_idx_set.update(pruned_idx)
243 | prev_layer_len = in_layer.shape[3]
244 | break
245 |
246 | pruned_idx_list = list(pruned_idx_set)
247 | return pruned_idx_list
248 |
249 |
250 | def set_pruned_idx(self, layer_list, layer_info):
251 | pruned_idx_set = set()
252 | for in_layer in layer_list:
253 |
254 | while 1:
255 | if 'pruned_idx' not in layer_info[id(in_layer)]:
256 | in_layer = layer_info[id(in_layer)]['layer_input']
257 | else:
258 | pruned_num = len(layer_info[id(in_layer)]['pruned_idx'])
259 | pruned_idx_set.update(layer_info[id(in_layer)]['pruned_idx'])
260 | break
261 |
262 |         pruned_idx_list = list(pruned_idx_set)[:pruned_num]
263 | return pruned_idx_list
264 |
265 | def del_key(self, layer_input, layer_info):
266 | if layer_info[id(layer_input)]['count'] == 0:
267 | del layer_info[id(layer_input)]['out']
--------------------------------------------------------------------------------
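Finally, an illustrative end-to-end sketch of `GraphWrapper` on a tiny model, assuming the repository root is on `sys.path`: the wrapper replays the graph layer by layer, shrinking each layer's weights according to the deleted channel indices.

```python
import tensorflow as tf
from tensorflow.keras import layers, Model
from core.pruner.graph_wrapper import GraphWrapper

inp = layers.Input(shape=(32, 32, 3))
x = layers.Conv2D(8, 3, padding='same')(inp)
x = layers.BatchNormalization()(x)
x = layers.GlobalAveragePooling2D()(x)
out = layers.Dense(10, activation='softmax')(x)
model = Model(inp, out)

# Delete output channels 0 and 3 of model.layers[1] (the only Conv2D).
pruned = GraphWrapper(model).build({1: [0, 3]})
pruned.summary()  # the conv now has 6 filters; BN and Dense shrink to match
```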