├── .gitignore ├── Cifar_Pruning.ipynb ├── Models ├── Pruned_0.8.ckpt.data-00000-of-00001 ├── Pruned_0.8.ckpt.index ├── Pruned_0.8.ckpt.meta ├── Unpruned_best.ckpt.data-00000-of-00001 ├── Unpruned_best.ckpt.index ├── Unpruned_best.ckpt.meta ├── checkpoint ├── frozen_model.pb ├── quantized_model.tflite ├── quantized_pruned_model.tflite └── tflite │ ├── Conv_1_weights_masked_weight.npy │ ├── Conv_2_weights_masked_weight.npy │ ├── Conv_3_weights_masked_weight.npy │ ├── Conv_4_weights_masked_weight.npy │ ├── Conv_5_weights_masked_weight.npy │ ├── Conv_6_weights_masked_weight.npy │ └── Conv_weights_masked_weight.npy ├── Quantization.ipynb ├── README.md ├── References.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | cifar/ 2 | logs/ 3 | Model_Saves/ 4 | Models_Given/ 5 | __pycache__/ 6 | .ipynb_checkpoints/ 7 | 8 | *.npy 9 | *.h5 10 | *.tflite 11 | 12 | project_2_spec_updated.pdf 13 | Quantization_backup.ipynb 14 | Cifar_Pruning_backup.ipynb -------------------------------------------------------------------------------- /Models/Pruned_0.8.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/Pruned_0.8.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /Models/Pruned_0.8.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/Pruned_0.8.ckpt.index -------------------------------------------------------------------------------- /Models/Pruned_0.8.ckpt.meta: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/Pruned_0.8.ckpt.meta -------------------------------------------------------------------------------- /Models/Unpruned_best.ckpt.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/Unpruned_best.ckpt.data-00000-of-00001 -------------------------------------------------------------------------------- /Models/Unpruned_best.ckpt.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/Unpruned_best.ckpt.index -------------------------------------------------------------------------------- /Models/Unpruned_best.ckpt.meta: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/Unpruned_best.ckpt.meta -------------------------------------------------------------------------------- /Models/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "Pruned_0.99.ckpt" 2 | all_model_checkpoint_paths: "Pruned_0.1.ckpt" 3 | all_model_checkpoint_paths: "Pruned_0.2.ckpt" 4 | all_model_checkpoint_paths: "Pruned_0.3.ckpt" 5 | all_model_checkpoint_paths: "Pruned_0.4.ckpt" 6 | all_model_checkpoint_paths: "Pruned_0.5.ckpt" 7 | all_model_checkpoint_paths: "Pruned_0.6.ckpt" 8 | all_model_checkpoint_paths: "Pruned_0.7.ckpt" 9 | all_model_checkpoint_paths: "Pruned_0.8.ckpt" 10 | all_model_checkpoint_paths: "Pruned_0.9.ckpt" 11 | all_model_checkpoint_paths: "Pruned_0.95.ckpt" 12 | all_model_checkpoint_paths: "Pruned_0.99.ckpt" 13 | 
-------------------------------------------------------------------------------- /Models/frozen_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/frozen_model.pb -------------------------------------------------------------------------------- /Models/quantized_model.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/quantized_model.tflite -------------------------------------------------------------------------------- /Models/quantized_pruned_model.tflite: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/quantized_pruned_model.tflite -------------------------------------------------------------------------------- /Models/tflite/Conv_1_weights_masked_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/tflite/Conv_1_weights_masked_weight.npy -------------------------------------------------------------------------------- /Models/tflite/Conv_2_weights_masked_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/tflite/Conv_2_weights_masked_weight.npy -------------------------------------------------------------------------------- /Models/tflite/Conv_3_weights_masked_weight.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/tflite/Conv_3_weights_masked_weight.npy -------------------------------------------------------------------------------- /Models/tflite/Conv_4_weights_masked_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/tflite/Conv_4_weights_masked_weight.npy -------------------------------------------------------------------------------- /Models/tflite/Conv_5_weights_masked_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/tflite/Conv_5_weights_masked_weight.npy -------------------------------------------------------------------------------- /Models/tflite/Conv_6_weights_masked_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/tflite/Conv_6_weights_masked_weight.npy -------------------------------------------------------------------------------- /Models/tflite/Conv_weights_masked_weight.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vikranth94/Model-Compression/fb4f6492cca87315a0c1c88a9e9531b8ac3423a3/Models/tflite/Conv_weights_masked_weight.npy -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Model Compression 2 | Model Pruning and Quantization using Tensorflow 3 | -------------------------------------------------------------------------------- /References.txt: 
-------------------------------------------------------------------------------- 1 | https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/lite/tutorials/post_training_quant.ipynb#scrollTo=g8PUvLWDlmmz 2 | https://heartbeat.fritz.ai/8-bit-quantization-and-tensorflow-lite-speeding-up-mobile-inference-with-low-precision-a882dfcafbbd 3 | 4 | https://cs230-stanford.github.io/tensorflow-input-data.html 5 | https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/5_DataManagement/tensorflow_dataset_api.py 6 | https://stackoverflow.com/questions/50437234/tensorflow-dataset-shuffle-then-batch-or-batch-then-shuffle -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from natsort import natsorted 4 | import imageio 5 | import re 6 | import time 7 | import keras 8 | from keras.models import Sequential 9 | from keras.preprocessing.image import ImageDataGenerator 10 | from keras.layers import Dense, Activation, Flatten, Dropout, BatchNormalization 11 | from keras.layers import Conv2D, MaxPooling2D, AveragePooling2D, Reshape 12 | from keras.callbacks import ModelCheckpoint, TensorBoard 13 | from keras.models import load_model 14 | from sklearn.metrics import confusion_matrix 15 | from sklearn.metrics import f1_score, accuracy_score 16 | from sklearn.utils import shuffle 17 | import matplotlib.pyplot as plt 18 | import itertools 19 | import tensorflow as tf 20 | from tensorflow.contrib.model_pruning.python import pruning 21 | from tensorflow.contrib.model_pruning.python.layers import layers 22 | 23 | NAME = 'Cifar10_CNN' 24 | data_dir = 'cifar' 25 | model_dir = 'Model_Saves' 26 | num_classes = 10 27 | 28 | classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship','truck'] 29 | 30 | class_dict = { 31 | 'airplane': 0, 32 | 'automobile':1, 33 
# CIFAR-10 class-name -> integer-label table (restated in full here because the
# dump format splits the literal mid-way).
class_dict = {
    'airplane': 0,
    'automobile': 1,
    'bird': 2,
    'cat': 3,
    'deer': 4,
    'dog': 5,
    'frog': 6,
    'horse': 7,
    'ship': 8,
    'truck': 9,
}

# Reverse lookup: integer label -> class name.
inv_class_dict = {v: k for k, v in class_dict.items()}


def prepare_dataset(data_dir, folder_name):
    """Load a CIFAR image folder as (X, y) numpy arrays, with .npy caching.

    First tries the cached ``X_<folder>.npy`` / ``y_<folder>.npy`` files in the
    working directory; on a cache miss it reads every ``*.png`` under
    ``data_dir/folder_name`` (natural-sorted so ordering is stable across runs)
    and writes the cache for next time.

    Args:
        data_dir: root directory holding the image folders.
        folder_name: sub-folder (split) to load, e.g. 'train' or 'test'.

    Returns:
        (X, y): image stack of shape (N, H, W, C) and an integer label vector.
    """
    try:
        print('Loading numpy')
        X = np.load('X_{}.npy'.format(folder_name))
        y = np.load('y_{}.npy'.format(folder_name))
    except (OSError, ValueError):
        # Fix: the original used a bare ``except:`` that also swallowed
        # KeyboardInterrupt/SystemExit and hid real bugs (e.g. a NameError).
        # Only a missing cache (FileNotFoundError/OSError) or a corrupt .npy
        # (ValueError) should trigger the slow image-by-image path.
        print('Loading images')
        image_list = []
        labels = []
        pictures_dir = os.path.join(data_dir, folder_name)
        names = [d for d in os.listdir(pictures_dir) if d.endswith('.png')]
        names = natsorted(names)
        for image in names:
            image_list.append(imageio.imread(os.path.join(pictures_dir, image)))
            # Filename pattern '<index>_<classname>.png' -> class name is field 1.
            label = re.split('[._]', image)
            labels.append(class_dict[label[1]])
            print(image)
        X = np.stack(image_list, axis=0)
        y = np.array(labels)
        np.save('X_{}'.format(folder_name), X)
        np.save('y_{}'.format(folder_name), y)
    return X, y


# z-score
def z_normalization(X, mean, std):
    """Z-score ``X`` with the given mean/std; the 1e-7 guards against std == 0."""
    X = (X - mean) / (std + 1e-7)
    return X


def sample_batch(dataset, labels, batch_size):
    """Draw a random batch (with replacement) from ``dataset`` / ``labels``."""
    N = dataset.shape[0]
    indices = np.random.randint(N, size=batch_size)
    x_epoch = dataset[indices]
    y_epoch = labels[indices]
    return x_epoch, y_epoch


def set_prune_params(s, global_step=None):
    """Build a tf.contrib.model_pruning mask-update op targeting sparsity ``s``.

    Bug fix: the original referenced an undefined module-level ``global_step``
    (a guaranteed NameError outside the authoring notebook). It is now an
    explicit optional argument; when omitted, the graph's global step is
    fetched (or created), matching the notebook's evident intent.

    Args:
        s: target sparsity in [0, 1].
        global_step: optional global-step tensor used to schedule pruning.

    Returns:
        The conditional mask-update op to run alongside training.
    """
    # Get, print, and edit pruning hyperparameters.
    pruning_hparams = pruning.get_pruning_hparams()
    print("Pruning Hyperparameters:", pruning_hparams)

    # Change hyperparameters to meet our needs: prune every step for the
    # first 250 steps, ramping the sparsity schedule over the same window.
    pruning_hparams.begin_pruning_step = 0
    pruning_hparams.end_pruning_step = 250
    pruning_hparams.pruning_frequency = 1
    pruning_hparams.sparsity_function_end_step = 250
    pruning_hparams.target_sparsity = s

    if global_step is None:
        global_step = tf.train.get_or_create_global_step()

    # Create a pruning object using the pruning specification; the sparsity
    # set above takes priority over the hparam default.
    p = pruning.Pruning(pruning_hparams, global_step=global_step)
    prune_op = p.conditional_mask_update_op()
    return prune_op
def create_CNN_model(inp_shape, num_classes, p=0.2):
    """Build a Keras Sequential CNN for CIFAR-10 classification.

    Three conv/conv/pool stages (32 -> 64 -> 128 filters), each followed by
    dropout ``p``, then a small dense head with softmax over ``num_classes``.

    Args:
        inp_shape: input image shape, e.g. (32, 32, 3).
        num_classes: number of output classes.
        p: dropout rate applied after each stage and before the output layer.

    Returns:
        The (uncompiled) Sequential model; its summary is printed as a side
        effect, matching the original behavior.
    """
    stack = [
        Conv2D(32, kernel_size=(3, 3),
               activation='relu',
               input_shape=inp_shape,
               padding='same', name='Conv_1'),
        BatchNormalization(name='Bn_1'),
        Conv2D(32, kernel_size=(3, 3), activation='relu', padding='same', name='Conv_2'),
        BatchNormalization(name='Bn_2'),
        MaxPooling2D(pool_size=(2, 2), name='Max_pool_1'),
        Dropout(p, name='Drop_1'),
        Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same', name='Conv_3'),
        BatchNormalization(name='Bn_3'),
        Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same', name='Conv_4'),
        BatchNormalization(name='Bn_4'),
        MaxPooling2D(pool_size=(2, 2), name='Max_pool_2'),
        Dropout(p, name='Drop_2'),
        Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same', name='Conv_5'),
        BatchNormalization(name='Bn_5'),
        Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same', name='Conv_6'),
        BatchNormalization(name='Bn_6'),
        MaxPooling2D(pool_size=(2, 2), name='Max_pool_3'),
        Dropout(p, name='Drop_3'),
        Flatten(name='Flatten_1'),
        Dense(32, activation='relu'),
        BatchNormalization(name='Bn_7'),
        Dropout(p, name='Drop_4'),
        Dense(num_classes, activation='softmax', name='dense_out'),
    ]
    model = Sequential()
    for layer in stack:
        model.add(layer)
    print(model.summary())
    return model


def create_FCN_model(inp_shape, num_classes, p=0.2):
    """Build an all-convolutional Keras model (no dense layers) for CIFAR-10.

    Two 3x3 conv stages with strided 3x3 max pools, then 1x1 convs down to 10
    channels and a global-ish 6x6 average pool; the softmax is applied over
    the flattened (batch, 10) output.

    Args:
        inp_shape: input image shape, e.g. (32, 32, 3).
        num_classes: kept for signature parity; the final conv is hard-wired
            to 10 channels as in the original.
        p: dropout rate.

    Returns:
        The (uncompiled) Sequential model; its summary is printed as a side
        effect, matching the original behavior.
    """
    stack = [
        Conv2D(96, kernel_size=(3, 3),
               activation='relu',
               input_shape=inp_shape,
               padding='same', name='Conv_1'),
        BatchNormalization(name='Bn_1'),
        Conv2D(96, kernel_size=(3, 3), activation='relu', padding='same', name='Conv_2'),
        BatchNormalization(name='Bn_2'),
        MaxPooling2D(pool_size=(3, 3), strides=2, padding='same', name='Max_pool_1'),
        Dropout(p, name='Drop_1'),
        Conv2D(192, kernel_size=(3, 3), activation='relu', padding='same', name='Conv_3'),
        BatchNormalization(name='Bn_3'),
        Conv2D(192, kernel_size=(3, 3), activation='relu', padding='same', name='Conv_4'),
        BatchNormalization(name='Bn_4'),
        MaxPooling2D(pool_size=(3, 3), strides=2, padding='same', name='Max_pool_2'),
        Dropout(p, name='Drop_2'),
        Conv2D(192, kernel_size=(3, 3), activation='relu', padding='valid', name='Conv_5'),
        BatchNormalization(name='Bn_5'),
        Conv2D(192, kernel_size=(1, 1), activation='relu', padding='same', name='Conv_6'),
        BatchNormalization(name='Bn_6'),
        Conv2D(10, kernel_size=(1, 1), activation='relu', padding='same', name='Conv_7'),
        BatchNormalization(name='Bn_7'),
        Dropout(p, name='Drop_4'),
        AveragePooling2D(pool_size=(6, 6), strides=1, name='avg_pool'),
        Flatten(name='Flatten_1'),
        Activation('softmax', name='output'),
    ]
    model = Sequential()
    for layer in stack:
        model.add(layer)
    print(model.summary())
    return model
def tf_fcn_model(image):
    """All-convolutional CIFAR-10 graph built from prunable masked conv layers.

    Mirrors ``create_FCN_model`` in raw TF1, using tf.contrib.model_pruning's
    ``masked_conv2d`` so every conv weight carries a pruning mask.

    Args:
        image: input image batch tensor (NHWC; the spatial size must reduce to
            6x6 before the final average pool — presumably 32x32 input, TODO
            confirm against the caller).

    Returns:
        Logits tensor of shape (batch, 10). Note there is no softmax here;
        the caller is expected to apply it (or use a from-logits loss).
    """
    net = layers.masked_conv2d(image, 96, (3, 3), 1, 'SAME')
    net = tf.layers.batch_normalization(net, name='norm1-1')
    net = layers.masked_conv2d(net, 96, (3, 3), 1, 'SAME')
    net = tf.layers.batch_normalization(net, name='norm1-2')
    net = tf.layers.max_pooling2d(net, (3, 3), 2, 'SAME', name='pool1')
    net = layers.masked_conv2d(net, 192, (3, 3), 1, 'SAME')
    net = tf.layers.batch_normalization(net, name='norm2-1')
    net = layers.masked_conv2d(net, 192, (3, 3), 1, 'SAME')
    net = tf.layers.batch_normalization(net, name='norm2-2')
    net = tf.layers.max_pooling2d(net, (3, 3), 2, 'SAME', name='pool2')
    net = layers.masked_conv2d(net, 192, (3, 3), 1, 'VALID')
    net = tf.layers.batch_normalization(net, name='norm3')
    net = layers.masked_conv2d(net, 192, (1, 1), 1)
    net = tf.layers.batch_normalization(net, name='norm4')
    net = layers.masked_conv2d(net, 10, (1, 1), 1)
    net = tf.layers.batch_normalization(net, name='norm5')
    net = tf.layers.average_pooling2d(net, (6, 6), 1, name='avg_pool')
    # Squeeze the 1x1 spatial dims away: (batch, 1, 1, 10) -> (batch, 10).
    logits = tf.reshape(net, [tf.shape(net)[0], 10])
    return logits
def train_model(model, X_train, y_train, X_val, y_val, model_dir, t, batch_size=256, epochs=50):
    """Compile and fit ``model``, checkpointing the best val-accuracy weights.

    The best model (by validation accuracy) is saved to
    ``model_dir/best_<NAME>_<t>``, the last-epoch model to
    ``model_dir/final_<NAME>_<t>``, and TensorBoard logs go to
    ``logs/<NAME>_<t>``.

    Args:
        model: an uncompiled Keras model.
        X_train, y_train: training inputs and one-hot labels.
        X_val, y_val: validation inputs and one-hot labels.
        model_dir: directory for saved checkpoints.
        t: tag appended to checkpoint/log names (e.g. a timestamp).
        batch_size: minibatch size.
        epochs: number of training epochs.

    Returns:
        (model, history): the fitted model and its Keras History object.
    """
    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer='adam',
                  metrics=['accuracy'])

    # Keep only the single best checkpoint by validation accuracy
    # ('val_acc' is the metric key used by this Keras version).
    best_path = os.path.join(model_dir, 'best_{}_{}'.format(NAME, t))
    callbacks_list = [
        ModelCheckpoint(best_path, monitor='val_acc', verbose=1,
                        save_best_only=True, mode='max'),
        TensorBoard(log_dir="logs/{}_{}".format(NAME, t)),
    ]

    history = model.fit(X_train, y_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        verbose=1,
                        shuffle=True,
                        validation_data=(X_val, y_val),
                        callbacks=callbacks_list)

    # Persist the final (last-epoch) weights as well.
    model.save(os.path.join(model_dir, 'final_{}_{}'.format(NAME, t)))
    return model, history


def calculate_metrics(model, X_test, y_test_binary):
    """Evaluate ``model`` on one-hot test labels.

    Args:
        model: fitted Keras model (anything exposing ``predict``).
        X_test: test inputs.
        y_test_binary: one-hot encoded ground-truth labels.

    Returns:
        (confusion_matrix, accuracy, macro_f1, mismatch_indices, y_pred),
        where ``mismatch_indices`` is the np.where tuple of misclassified rows.
    """
    y_pred = np.argmax(model.predict(X_test), axis=1)
    y_true = np.argmax(y_test_binary, axis=1)
    return (
        confusion_matrix(y_true, y_pred),
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred, average='macro'),
        np.where(y_true != y_pred),
        y_pred,
    )
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True` (each row is
    scaled to sum to 1, i.e. per-true-class rates).
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.figure(figsize=(10, 7))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    ticks = np.arange(len(classes))
    plt.xticks(ticks, classes, rotation=45, fontsize=15)
    plt.yticks(ticks, classes, fontsize=15)

    # Annotate every cell, switching to white text on dark (above-midpoint) cells.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            plt.text(col, row, format(cm[row, col], fmt), fontsize=15,
                     horizontalalignment="center",
                     color="white" if cm[row, col] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label', fontsize=12)
    plt.xlabel('Predicted label', fontsize=12)