├── .gitignore ├── README.md ├── mnist_iterative_pruning ├── iterative_prune.py ├── log.txt ├── log_iterative.txt ├── read_image.py └── sparse_op.py └── mnist_pruning ├── __init__.py ├── read_image.py ├── seven.png ├── sparse_op.py ├── test_dense.py ├── test_sparse.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | cifar10_iterative_pruning\models 2 | cifar10_iterative_pruning\cifar10_data -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | An implementation of [Iterative Pruning](https://arxiv.org/abs/1506.02626), current on mnist only. 2 | 3 | Thanks [this repository](https://github.com/garion9013/impl-pruning-TF) 4 | 5 | ## Usage 6 | ### Iterative Pruning 7 | ``` 8 | cd mnist_iterative_pruning 9 | python iterative_prune.py -1 -2 -3 10 | ``` 11 | this would train a convolution model on mnist. Then do pruning on fc layer and retraining for 20 times. Finally fc layers would be transformed to a sparse format and saved. 12 | 13 | ## Performance 14 | 15 | we have a pretty good pruning performance, keeping accuracy at 0.987 while pruning 99.77% weights in fc layer. 16 | 17 | |weight kept ratio|accuracy| 18 | |-----------------|--------| 19 | |1 |0.99 | 20 | |0.7 |0.991 | 21 | |0.49 |0.993 | 22 | |0.24 |0.994 | 23 | |0.117 |0.993 | 24 | |0.057 |0.994 | 25 | |0.013 |0.993 | 26 | |0.009 |0.992 | 27 | |0.0047 |0.99 | 28 | |0.0023 |0.987 | 29 | |0.0016 |0.889 | 30 | |0.0011 |0.886 | 31 | |0.00079 |0.677 | 32 | |0.00056 |0.409 | 33 | 34 | in term of inference time, dense vs sparse: 1.47 vs 0.68 35 | -------------------------------------------------------------------------------- /mnist_iterative_pruning/iterative_prune.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tensorflow as tf 3 | import numpy as np 4 | import argparse 5 | from sparse_op import sparse_dense_matmul_b 6 | import read_image 7 | from tensorflow.examples.tutorials.mnist import input_data 8 | import time 9 | 10 | argparser = argparse.ArgumentParser() 11 | argparser.add_argument("-1", "--train", action="store_true", 12 | help="train dense MNIST model with 20000 iterations") 13 | argparser.add_argument("-2", "--prune", action="store_true", 14 | help="prune model and retrain, iteratively") 15 | argparser.add_argument("-3", "--sparse", action="store_true", 16 | help="transform model to a sparse format and save it") 17 | argparser.add_argument("-m", "--checkpoint", default="./models/model_ckpt_dense", 18 | help="Target checkpoint model file for 2nd and 3rd round") 19 | argparser.add_argument("-i", "--iteration", default="20", 20 | help="iteration number") 21 | argparser.add_argument("-p", "--percentage", default="0.7", 22 | help="pruning percentage in every iteration") 23 | args = argparser.parse_args() 24 | if(args.train or args.prune or args.sparse) == False: 25 | argparser.print_help() 26 | sys.exit() 27 | mnist = input_data.read_data_sets('H:/data/', one_hot=True) 28 | sess = tf.Session() 29 | file_obj = open("log.txt", "w") 30 | 31 | def dense_cnn_model(image, weights, keep_prob): 32 | def conv2d(x, W): 33 | return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding="SAME") 34 | def max_pool_2x2(x): 35 | return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME") 36 | x_image = tf.reshape(image, [-1,28,28,1]) 37 | h_conv1 = tf.nn.relu(conv2d(x_image, weights["w_conv1"]) + weights["b_conv1"]) 38 | #[-1,14,14,32] 39 | h_pool1 = max_pool_2x2(h_conv1) 40 | h_conv2 = tf.nn.relu(conv2d(h_pool1, weights["w_conv2"]) + weights["b_conv2"]) 41 | #[-1,7,7,64] 42 | h_pool2 = max_pool_2x2(h_conv2) 43 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 44 | h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, weights["w_fc1"]) + weights["b_fc1"]) 45 | h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob=keep_prob) 46 | #[-1,10] 47 | logit = tf.matmul(h_fc1_dropout, weights["w_fc2"]) + weights["b_fc2"] 48 | return h_pool2_flat, h_fc1, logit 49 | 50 | def test(predict_logit): 51 | correct_prediction = tf.equal(tf.argmax(predict_logit,1), tf.argmax(y_,1)) 52 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 53 | result = 0 54 | for i in range(20): 55 | batch = mnist.test.next_batch(500) 56 | result = result + sess.run(accuracy, feed_dict={x:batch[0], y_:batch[1], keep_prob : 1.0}) 57 | result = result / 20.0 58 | return result 59 | 60 | def prune(weights, th): 61 | ''' 62 | :param weights: weight Variable 63 | :param th: float value, weight under th will be pruned 64 | :return: sparse_weight: weight matrix after pruning, which is a 2d numpy array 65 | ~under_threshold: boolean matrix in same shape with sparse_weight indicating whether corresponding element is zero 66 | ''' 67 | shape = weights.shape 68 | weight_arr = sess.run(weights) 69 | under_threshold = abs(weight_arr) < th 70 | weight_arr[under_threshold] = 0 71 | tmp = weight_arr 72 | #set last matrix elemet to a small number, I have to do that since the drawback of tensorflow sparse matrix support 73 | #hope it would have less impact on model 74 | for i in range(len(shape) - 1): 75 | tmp = tmp[-1] 76 | if(tmp[-1] == 0): 77 | tmp[-1] = 0.01 78 | count = np.sum(under_threshold) 79 | #print ("None-zero element: %s" % (weight_arr.size - count)) 80 | return weight_arr, ~under_threshold 81 | 82 | def get_th(weight, percentage=0.8): 83 | flat = tf.reshape(weight, [-1]) 84 | flat_list = sorted(map(abs,sess.run(flat))) 85 | return flat_list[int(len(flat_list) * percentage)] 86 | 87 | def delete_none_grads(grads): 88 | count = 0 89 | length = len(grads) 90 | while(count < length): 91 | if(grads[count][0] == None): 92 | del grads[count] 93 | length -= 1 94 | else: 95 | count += 1 96 | 97 | def transfer_to_sparse(weight): 98 | weight_arr = sess.run(weight) 99 | values = weight_arr[weight_arr != 0] 100 | indices = np.transpose(np.nonzero(weight_arr)) 101 | shape = list(weight_arr.shape) 102 | return [indices, values, shape] 103 | 104 | dense_w = { 105 | "w_conv1":tf.Variable(tf.truncated_normal([5,5,1,32], stddev=0.1), name="w_conv1"), 106 | "b_conv1":tf.Variable(tf.constant(0.1, shape=[32]), name="b_conv1"), 107 | "w_conv2":tf.Variable(tf.truncated_normal([5,5,32,64], stddev=0.1), name="w_conv2"), 108 | "b_conv2":tf.Variable(tf.constant(0.1, shape=[64]), name="b_conv2"), 109 | "w_fc1":tf.Variable(tf.truncated_normal([7*7*64,1024], stddev=0.1), name="w_fc1"), 110 | "b_fc1":tf.Variable(tf.constant(0.1, shape=[1024]), name="b_fc1"), 111 | "w_fc2":tf.Variable(tf.truncated_normal([1024,10], stddev=0.1), name="w_fc2"), 112 | "b_fc2":tf.Variable(tf.constant(0.1, shape=[10]), name="b_fc2") 113 | } 114 | 115 | if(args.train == True): 116 | x = tf.placeholder(tf.float32, [None, 784], name="x") 117 | y_ = tf.placeholder(tf.float32, [None, 10], name="y_") 118 | keep_prob = tf.placeholder(tf.float32, name="keep_prob") 119 | 120 | useless1, useless2, logit = dense_cnn_model(x, dense_w, keep_prob) 121 | 122 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y_) 123 | train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 124 | 125 | correct_prediction = tf.equal(tf.arg_max(logit,1), tf.arg_max(y_, 1)) 126 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 127 | 128 | sess.run(tf.global_variables_initializer()) 129 | 130 | for i in range(20000): 131 | batch = mnist.train.next_batch(50) 132 | if i % 100 == 0: 133 | train_acc = sess.run(accuracy, feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 134 | print("step %d, training acc %g" % (i , train_acc)) 135 | sess.run(train_step, feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 136 | 137 | start = time.clock() 138 | test_acc = test(logit) 139 | end = time.clock() 140 | print("test acc %g" % test_acc + ", inference time:" + str(end - start)) 141 | saver = tf.train.Saver() 142 | saver.save(sess, "./models/model_ckpt_dense") 143 | 144 | if(args.prune == True): 145 | print("total pruning iteration: %d. pruning percentage each iter: %g" % (int(args.iteration), float(args.percentage))) 146 | file_obj.write("total pruning iteration: %d. pruning percentage each iter: %g\n" % (int(args.iteration), float(args.percentage))) 147 | saver = tf.train.Saver() 148 | saver.restore(sess, args.checkpoint) 149 | p = 1.0 150 | for i in range(int(args.iteration)): 151 | p = p * float(args.percentage) 152 | print("\033[0;31miteration %d, p=%g\033[0m" % (i, p)) 153 | file_obj.write("iteration %d, p=%g\n" % (i, p)) 154 | file_obj.flush() 155 | th_fc1 = get_th(dense_w["w_fc1"], percentage=(1.0 - p)) 156 | th_fc2 = get_th(dense_w["w_fc2"], percentage=(1.0 - p)) 157 | sp_w_fc1, idx_fc1 = prune(dense_w["w_fc1"], th_fc1) 158 | sp_w_fc2, idx_fc2 = prune(dense_w["w_fc2"], th_fc2) 159 | sess.run(tf.assign(dense_w["w_fc1"], sp_w_fc1)) 160 | sess.run(tf.assign(dense_w["w_fc2"], sp_w_fc2)) 161 | 162 | array_wfc1 = sess.run(dense_w["w_fc1"]) 163 | array_wfc2 = sess.run(dense_w["w_fc2"]) 164 | 165 | print("none-zero in fc1 :%d" % np.sum(array_wfc1 != 0)) 166 | print("none-zero in fc2 :%d" % np.sum(array_wfc2 != 0)) 167 | 168 | for var in tf.global_variables(): 169 | if sess.run(tf.is_variable_initialized(var)) == False: 170 | sess.run(var.initializer) 171 | 172 | x = tf.placeholder(tf.float32, [None, 784], name="x") 173 | y_ = tf.placeholder(tf.float32, [None, 10], name="y_") 174 | keep_prob = tf.placeholder(tf.float32, name="keep_prob") 175 | useless1, useless2, logit = dense_cnn_model(x, dense_w, keep_prob) 176 | test_acc = test(logit) 177 | print("\033[0;31mtest acc after iteration %d pruning: %g\033[0m" % (i, test_acc)) 178 | file_obj.write("test acc after iteration %d pruning: %g\n" % (i, test_acc)) 179 | 180 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y_) 181 | trainer = tf.train.AdamOptimizer(1e-4) 182 | grads = trainer.compute_gradients(cross_entropy) 183 | delete_none_grads(grads) 184 | 185 | count = 0 186 | for grad, var in grads: 187 | if (var.name == "w_fc1:0"): 188 | idx_in1 = tf.cast(tf.constant(idx_fc1), tf.float32) 189 | grads[count] = (tf.multiply(idx_in1, grad), var) 190 | if (var.name == "w_fc2:0"): 191 | idx_in2 = tf.cast(tf.constant(idx_fc2), tf.float32) 192 | grads[count] = (tf.multiply(idx_in2, grad), var) 193 | count += 1 194 | train_step = trainer.apply_gradients(grads) 195 | 196 | correct_prediction = tf.equal(tf.argmax(logit, 1), tf.argmax(y_, 1)) 197 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 198 | 199 | for var in tf.global_variables(): 200 | if sess.run(tf.is_variable_initialized(var)) == False: 201 | sess.run(var.initializer) 202 | 203 | for j in range(10000): 204 | batch = mnist.train.next_batch(50) 205 | idx_in1_value = sess.run(idx_in1) 206 | grads_fc1_value = sess.run(grads, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 207 | if (j % 3000 == 0 or j == 100 or j == 500 or j == 1000 or j == 1500 or j == 9999): 208 | train_acc = sess.run(accuracy, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 209 | print("retraining step %d, acc %g" % (j, train_acc)) 210 | sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 211 | 212 | array_wfc1 = sess.run(dense_w["w_fc1"]) 213 | array_wfc2 = sess.run(dense_w["w_fc2"]) 214 | 215 | print("none-zero in fc1 after retrain:%d" % np.sum(array_wfc1 != 0)) 216 | print("none-zero in fc2 after retrain:%d" % np.sum(array_wfc2 != 0)) 217 | 218 | test_acc = test(logit) 219 | print("\033[0;31mtest acc after interation %d pruning and retraining: %g\033[0m" % (i, test_acc)) 220 | file_obj.write("test acc after interation %d pruning and retraining: %g\n" % (i, test_acc)) 221 | 222 | saver = tf.train.Saver(dense_w) 223 | saver.save(sess, "./models/model_ckpt_dense_retrained", global_step=i) 224 | 225 | saver = tf.train.Saver(dense_w) 226 | saver.save(sess, "./models/model_ckpt_dense_retrained") 227 | 228 | if(args.sparse == True): 229 | saver = tf.train.Saver() 230 | saver.restore(sess, "./models/model_ckpt_dense_retrained-16") 231 | 232 | sparse_w = { 233 | "w_conv1": tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1)), 234 | "b_conv1": tf.Variable(tf.constant(0.1, shape=[32])), 235 | "w_conv2": tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1)), 236 | "b_conv2": tf.Variable(tf.constant(0.1, shape=[64])), 237 | "w_fc1": tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024], stddev=0.1)), 238 | "b_fc1": tf.Variable(tf.constant(0.1, shape=[1024])), 239 | "w_fc2": tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1)), 240 | "b_fc2": tf.Variable(tf.constant(0.1, shape=[10])) 241 | } 242 | 243 | def sparse_cnn_model(image, sparse_weight): 244 | def conv2d(x, W): 245 | return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME") 246 | 247 | def max_pool_2x2(x): 248 | return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME") 249 | 250 | x_image = tf.reshape(image, [-1, 28, 28, 1]) 251 | h_conv1 = tf.nn.relu(conv2d(x_image, sparse_weight["w_conv1"]) + sparse_weight["b_conv1"]) 252 | # [-1,14,14,32] 253 | h_pool1 = max_pool_2x2(h_conv1) 254 | h_conv2 = tf.nn.relu(conv2d(h_pool1, sparse_weight["w_conv2"]) + sparse_weight["b_conv2"]) 255 | # [-1,7,7,64] 256 | h_pool2 = max_pool_2x2(h_conv2) 257 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) 258 | ndarray_w_fc1_idx = sess.run(sparse_weight["w_fc1_idx"]) 259 | ndarray_w_fc1 = sess.run(sparse_weight["w_fc1"]) 260 | ndarray_w_fc1_shape = sess.run(sparse_weight["w_fc1_shape"]) 261 | h_fc1 = tf.nn.relu(sparse_dense_matmul_b(ndarray_w_fc1_idx, ndarray_w_fc1, ndarray_w_fc1_shape, h_pool2_flat, True) + sparse_weight["b_fc1"]) 262 | ndarray_w_fc2_idx = sess.run(sparse_weight["w_fc2_idx"]) 263 | ndarray_w_fc2 = sess.run(sparse_weight["w_fc2"]) 264 | ndarray_w_fc2_shape = sess.run(sparse_weight["w_fc2_shape"]) 265 | logit = sparse_dense_matmul_b(ndarray_w_fc2_idx, ndarray_w_fc2, ndarray_w_fc2_shape, h_fc1, True) + sparse_weight["b_fc2"] 266 | return h_pool2_flat, h_fc1, logit 267 | 268 | copy_ops = [] 269 | for key, value in dense_w.items(): 270 | copy_ops.append(sparse_w[key].assign(value)) 271 | for e in copy_ops: 272 | sess.run(e) 273 | 274 | fc1_sparse_tmp = transfer_to_sparse(dense_w["w_fc1"]) 275 | sparse_w["w_fc1_idx"] = tf.Variable(tf.constant(fc1_sparse_tmp[0], dtype=tf.int64)\ 276 | , name="w_fc1_idx") 277 | sparse_w["w_fc1"] = tf.Variable(tf.constant(fc1_sparse_tmp[1], dtype=tf.float32)\ 278 | , name="w_fc1") 279 | sparse_w["w_fc1_shape"] = tf.Variable(tf.constant(fc1_sparse_tmp[2], dtype=tf.int64)\ 280 | , name="w_fc1_shape") 281 | fc2_sparse_tmp = transfer_to_sparse(dense_w["w_fc2"]) 282 | sparse_w["w_fc2_idx"] = tf.Variable(tf.constant(fc2_sparse_tmp[0], dtype=tf.int64)\ 283 | , name="w_fc2_idx") 284 | sparse_w["w_fc2"] = tf.Variable(tf.constant(fc2_sparse_tmp[1], dtype=tf.float32)\ 285 | , name="w_fc2") 286 | sparse_w["w_fc2_shape"] = tf.Variable(tf.constant(fc2_sparse_tmp[2], dtype=tf.int64)\ 287 | , name="w_fc2_shape") 288 | 289 | for var in tf.global_variables(): 290 | if sess.run(tf.is_variable_initialized(var)) == False: 291 | sess.run(var.initializer) 292 | 293 | x = tf.placeholder(tf.float32, [None, 784], name="x") 294 | y_ = tf.placeholder(tf.float32, [None, 10], name="y_") 295 | keep_prob = tf.placeholder(tf.float32, name="keep_prob") 296 | tf.add_to_collection("x_placeholder", x) 297 | 298 | #dense_prediction 299 | dense_pool2_flat, dense_h_fc1, dense_logit = dense_cnn_model(x, dense_w, keep_prob) 300 | #sparse prediction 301 | sp_pool2_flat, sp_h_fc1, sp_logit = sparse_cnn_model(x, sparse_w) 302 | 303 | start = time.clock() 304 | test_acc_dense = test(dense_logit) 305 | end = time.clock() 306 | print("dense acc:" + str(test_acc_dense) + ", inference time:" + str(end - start)) 307 | 308 | start = time.clock() 309 | test_acc_sp = test(sp_logit) 310 | end = time.clock() 311 | print("sp acc" + str(test_acc_sp) + ", inference time:" + str(end - start)) 312 | 313 | tf.add_to_collection("sp_logit", sp_logit) 314 | 315 | sparse_saver = tf.train.Saver(sparse_w) 316 | sparse_saver.save(sess, "./models/model_ckpt_sparse_retrained") 317 | 318 | file_obj.close() -------------------------------------------------------------------------------- /mnist_iterative_pruning/log.txt: -------------------------------------------------------------------------------- 1 | total pruning iteration: 20. pruning percentage each iter: 0.7 2 | iteration 0, p=0.7 3 | -------------------------------------------------------------------------------- /mnist_iterative_pruning/log_iterative.txt: -------------------------------------------------------------------------------- 1 | total pruning iteration: 100. pruning percentage each iter: 0.7 2 | iteration 0, p=0.7 3 | test acc after iteration 0 pruning: 0.9908 4 | test acc after interation 0 pruning and retraining: 0.9929 5 | iteration 1, p=0.49 6 | test acc after iteration 1 pruning: 0.9902 7 | test acc after interation 1 pruning and retraining: 0.9934 8 | iteration 2, p=0.343 9 | test acc after iteration 2 pruning: 0.9911 10 | test acc after interation 2 pruning and retraining: 0.993 11 | iteration 3, p=0.2401 12 | test acc after iteration 3 pruning: 0.9873 13 | test acc after interation 3 pruning and retraining: 0.9939 14 | iteration 4, p=0.16807 15 | test acc after iteration 4 pruning: 0.9435 16 | test acc after interation 4 pruning and retraining: 0.9937 17 | iteration 5, p=0.117649 18 | test acc after iteration 5 pruning: 0.9726 19 | test acc after interation 5 pruning and retraining: 0.9929 20 | iteration 6, p=0.0823543 21 | test acc after iteration 6 pruning: 0.9756 22 | test acc after interation 6 pruning and retraining: 0.9934 23 | iteration 7, p=0.057648 24 | test acc after iteration 7 pruning: 0.9719 25 | test acc after interation 7 pruning and retraining: 0.994 26 | iteration 8, p=0.0403536 27 | test acc after iteration 8 pruning: 0.9642 28 | test acc after interation 8 pruning and retraining: 0.9939 29 | iteration 9, p=0.0282475 30 | test acc after iteration 9 pruning: 0.9547 31 | test acc after interation 9 pruning and retraining: 0.9939 32 | iteration 10, p=0.0197733 33 | test acc after iteration 10 pruning: 0.9457 34 | test acc after interation 10 pruning and retraining: 0.9929 35 | iteration 11, p=0.0138413 36 | test acc after iteration 11 pruning: 0.9769 37 | test acc after interation 11 pruning and retraining: 0.993 38 | iteration 12, p=0.0096889 39 | test acc after iteration 12 pruning: 0.9855 40 | test acc after interation 12 pruning and retraining: 0.9918 41 | iteration 13, p=0.00678223 42 | test acc after iteration 13 pruning: 0.9731 43 | test acc after interation 13 pruning and retraining: 0.9914 44 | iteration 14, p=0.00474756 45 | test acc after iteration 14 pruning: 0.987 46 | test acc after interation 14 pruning and retraining: 0.9905 47 | iteration 15, p=0.00332329 48 | test acc after iteration 15 pruning: 0.9772 49 | test acc after interation 15 pruning and retraining: 0.9868 50 | iteration 16, p=0.00232631 51 | test acc after iteration 16 pruning: 0.9705 52 | test acc after interation 16 pruning and retraining: 0.9866 53 | iteration 17, p=0.00162841 54 | test acc after iteration 17 pruning: 0.7872 55 | test acc after interation 17 pruning and retraining: 0.8896 56 | iteration 18, p=0.00113989 57 | test acc after iteration 18 pruning: 0.7894 58 | test acc after interation 18 pruning and retraining: 0.8864 59 | iteration 19, p=0.000797923 60 | test acc after iteration 19 pruning: 0.5552 61 | test acc after interation 19 pruning and retraining: 0.6766 62 | iteration 20, p=0.000558546 63 | test acc after iteration 20 pruning: 0.3844 64 | test acc after interation 20 pruning and retraining: 0.4098 65 | iteration 21, p=0.000390982 66 | test acc after iteration 21 pruning: 0.3007 67 | test acc after interation 21 pruning and retraining: 0.3079 68 | iteration 22, p=0.000273687 69 | test acc after iteration 22 pruning: 0.3067 70 | test acc after interation 22 pruning and retraining: 0.3074 71 | iteration 23, p=0.000191581 72 | test acc after iteration 23 pruning: 0.2882 73 | test acc after interation 23 pruning and retraining: 0.3065 74 | iteration 24, p=0.000134107 75 | test acc after iteration 24 pruning: 0.2275 76 | test acc after interation 24 pruning and retraining: 0.3037 77 | iteration 25, p=9.38748e-05 78 | test acc after iteration 25 pruning: 0.1955 79 | test acc after interation 25 pruning and retraining: 0.2127 80 | iteration 26, p=6.57124e-05 81 | test acc after iteration 26 pruning: 0.2135 82 | test acc after interation 26 pruning and retraining: 0.2119 83 | iteration 27, p=4.59987e-05 84 | test acc after iteration 27 pruning: 0.2084 85 | test acc after interation 27 pruning and retraining: 0.2118 86 | iteration 28, p=3.21991e-05 87 | test acc after iteration 28 pruning: 0.213 88 | test acc after interation 28 pruning and retraining: 0.2111 89 | iteration 29, p=2.25393e-05 90 | test acc after iteration 29 pruning: 0.2129 91 | test acc after interation 29 pruning and retraining: 0.2103 92 | iteration 30, p=1.57775e-05 93 | test acc after iteration 30 pruning: 0.1414 94 | test acc after interation 30 pruning and retraining: 0.2112 95 | iteration 31, p=1.10443e-05 96 | test acc after iteration 31 pruning: 0.1041 97 | test acc after interation 31 pruning and retraining: 0.2115 98 | iteration 32, p=7.73099e-06 99 | test acc after iteration 32 pruning: 0.1009 100 | test acc after interation 32 pruning and retraining: 0.2109 101 | iteration 33, p=5.4117e-06 102 | test acc after iteration 33 pruning: 0.2102 103 | test acc after interation 33 pruning and retraining: 0.2105 104 | iteration 34, p=3.78819e-06 105 | test acc after iteration 34 pruning: 0.1657 106 | test acc after interation 34 pruning and retraining: 0.2114 107 | iteration 35, p=2.65173e-06 108 | test acc after iteration 35 pruning: 0.2038 109 | test acc after interation 35 pruning and retraining: 0.2117 110 | iteration 36, p=1.85621e-06 111 | test acc after iteration 36 pruning: 0.1135 112 | test acc after interation 36 pruning and retraining: 0.1135 113 | iteration 37, p=1.29935e-06 114 | test acc after iteration 37 pruning: 0.1135 115 | test acc after interation 37 pruning and retraining: 0.1135 116 | iteration 38, p=9.09544e-07 117 | test acc after iteration 38 pruning: 0.1135 118 | test acc after interation 38 pruning and retraining: 0.1135 119 | iteration 39, p=6.36681e-07 120 | -------------------------------------------------------------------------------- /mnist_iterative_pruning/read_image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | def read_image(path): 5 | tmp = np.array(Image.open(path).resize((28, 28), resample=2)) 6 | img = np.zeros((28, 28, 1)) 7 | img[:, :, 0] = tmp[:, :, 0] 8 | return img -------------------------------------------------------------------------------- /mnist_iterative_pruning/sparse_op.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def printtt(): 5 | pass 6 | 7 | def sparse_dense_matmul_b(sp_indices, sp_value, sp_shape, b, swap = False): 8 | ''' 9 | multiplication of a sparse matrix and a dense matrix 10 | first three params is exactly the three parameter of SparseTensor constructor 11 | a * b if swap==False else b * sp_a 12 | shape error happens in extremely sparse case, where right-bottom margin elements are all zero 13 | I do not know how to fix it, seldom happens in practice 14 | set last number of sparse matrix to a very small number is recommanded 15 | :param sp_indices: 16 | :param sp_value: 17 | :param sp_shape: 18 | :param b: 19 | :param swap: 20 | :return: 21 | ''' 22 | if(not swap): 23 | sp_a = tf.SparseTensor(sp_indices, sp_value, sp_shape) 24 | return tf.sparse_tensor_dense_matmul(sp_a, b) 25 | else: 26 | b = tf.transpose(b) 27 | sp_indices = np.array(sp_indices) 28 | internal_sp_indices = sp_indices[:,1] 29 | tmp = [] 30 | for c in internal_sp_indices: 31 | tmp.append([c, 0]) 32 | internal_sp_indices = tmp 33 | sp_indice_value = sp_indices[:, 0] 34 | sp_value = np.array(sp_value).astype(float) 35 | tmp1 = tf.sparse_reorder(tf.SparseTensor(indices=internal_sp_indices, values=sp_indice_value, dense_shape=sp_shape)) 36 | tmp2 = tf.sparse_reorder(tf.SparseTensor(indices=internal_sp_indices, values=sp_value, dense_shape=sp_shape)) 37 | y = tf.transpose(tf.nn.embedding_lookup_sparse(b, tmp1, tmp2, combiner="sum")) 38 | return y 39 | 40 | def sparse_dense_matmul(sp_a, b, swap=False): 41 | ''' 42 | multiplication of a sparse matrix and a dense matrix 43 | sp_a * b if swap==False else b * sp_a 44 | shape error happens in extremely sparse case, where right-bottom margin elements are all zero 45 | I do not know how to fix it, seldom happens in practice 46 | set last number of sparse matrix to a very small number is recommanded 47 | :param sp_a: SparseTensor 48 | :param b: 2dTensor 49 | :param swap: Boolean 50 | :return: Tensor 51 | ''' 52 | # if(type(b[0][0]) is not type(0.1)): 53 | # b = np.array(b).astype(float).tolist() 54 | if(not swap): 55 | return tf.sparse_tensor_dense_matmul(sp_a, b) 56 | else: 57 | sess = tf.Session() 58 | sess.run(tf.global_variables_initializer()) 59 | b = tf.transpose(b) 60 | sp_indices = sess.run(sp_a.indices)[:,1] 61 | tmp = [] 62 | for c in sp_indices: 63 | tmp.append([c, 0]) 64 | sp_indices = tmp 65 | sp_indice_value = sess.run(sp_a.indices)[:,0] 66 | sp_value = np.array(sess.run(sp_a.values)).astype(float) 67 | sp_shape = sess.run(sp_a.dense_shape) 68 | sess.close() 69 | tmp1 = tf.sparse_reorder(tf.SparseTensor(indices=sp_indices, values=sp_indice_value, dense_shape=sp_shape)) 70 | tmp2 = tf.sparse_reorder(tf.SparseTensor(indices=sp_indices, values=sp_value, dense_shape=sp_shape)) 71 | y = tf.transpose(tf.nn.embedding_lookup_sparse(b, tmp1, tmp2, combiner="sum")) 72 | tmp = y.shape 73 | return y 74 | 75 | if __name__ == '__main__': 76 | a = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) 77 | 78 | X = tf.placeholder(tf.float32, shape=[2,3]) 79 | 80 | x = np.array([[1,2,3],[2,4,6]], dtype=np.float32) 81 | 82 | mul = sparse_dense_matmul(a, X, True) 83 | mul2 = sparse_dense_matmul_b(sp_indices=[[0, 0], [1, 2]], sp_value=[1, 2], sp_shape=[3,4], b=X, swap=True) 84 | 85 | sess = tf.Session() 86 | print(sess.run(mul, feed_dict={X:x})) 87 | print(sess.run(mul2, feed_dict={X: x})) -------------------------------------------------------------------------------- /mnist_pruning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nephashi/DeepCompression/45f37ba0331b4d21841f4968123ae806bdd21822/mnist_pruning/__init__.py -------------------------------------------------------------------------------- /mnist_pruning/read_image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | def read_image(path): 5 | tmp = np.array(Image.open(path).resize((28, 28), resample=2)) 6 | img = np.zeros((28, 28, 1)) 7 | img[:, :, 0] = tmp[:, :, 0] 8 | return img -------------------------------------------------------------------------------- /mnist_pruning/seven.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nephashi/DeepCompression/45f37ba0331b4d21841f4968123ae806bdd21822/mnist_pruning/seven.png -------------------------------------------------------------------------------- /mnist_pruning/sparse_op.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | def printtt(): 5 | pass 6 | 7 | def sparse_dense_matmul_b(sp_indices, sp_value, sp_shape, b, swap = False): 8 | ''' 9 | first three params is exactly the three parameter of SparseTensor constructor 10 | multiplication of a sparse matrix and a dense matrix 11 | a * b if swap==False else b * sp_a 12 | shape error happens in extremely sparse case, where right-bottom margin elements are all zero 13 | I do not know how to fix it, seldom happens in practice 14 | set last number of sparse matrix to a very small number is recommanded 15 | :param sp_indices: 16 | :param sp_value: 17 | :param sp_shape: 18 | :param b: 19 | :param swap: 20 | :return: 21 | ''' 22 | if(not swap): 23 | sp_a = tf.SparseTensor(sp_indices, sp_value, sp_shape) 24 | return tf.sparse_tensor_dense_matmul(sp_a, b) 25 | else: 26 | b = tf.transpose(b) 27 | sp_indices = np.array(sp_indices) 28 | internal_sp_indices = sp_indices[:,1] 29 | tmp = [] 30 | for c in internal_sp_indices: 31 | tmp.append([c, 0]) 32 | internal_sp_indices = tmp 33 | sp_indice_value = sp_indices[:, 0] 34 | sp_value = np.array(sp_value).astype(float) 35 | tmp1 = tf.sparse_reorder(tf.SparseTensor(indices=internal_sp_indices, values=sp_indice_value, dense_shape=sp_shape)) 36 | tmp2 = tf.sparse_reorder(tf.SparseTensor(indices=internal_sp_indices, values=sp_value, dense_shape=sp_shape)) 37 | y = tf.transpose(tf.nn.embedding_lookup_sparse(b, tmp1, tmp2, combiner="sum")) 38 | return y 39 | 40 | def sparse_dense_matmul(sp_a, b, swap=False): 41 | ''' 42 | multiplication of a sparse matrix and a dense matrix 43 | sp_a * b if swap==False else b * sp_a 44 | shape error happens in extremely sparse case, where right-bottom margin elements are all zero 45 | I do not know how to fix it, seldom happens in practice 46 | set last number of sparse matrix to a very small number is recommanded 47 | :param sp_a: SparseTensor 48 | :param b: 2dTensor 49 | :param swap: Boolean 50 | :return: Tensor 51 | ''' 52 | # if(type(b[0][0]) is not type(0.1)): 53 | # b = np.array(b).astype(float).tolist() 54 | if(not swap): 55 | return tf.sparse_tensor_dense_matmul(sp_a, b) 56 | else: 57 | sess = tf.Session() 58 | sess.run(tf.global_variables_initializer()) 59 | b = tf.transpose(b) 60 | sp_indices = sess.run(sp_a.indices)[:,1] 61 | tmp = [] 62 | for c in sp_indices: 63 | tmp.append([c, 0]) 64 | sp_indices = tmp 65 | sp_indice_value = sess.run(sp_a.indices)[:,0] 66 | sp_value = np.array(sess.run(sp_a.values)).astype(float) 67 | sp_shape = sess.run(sp_a.dense_shape) 68 | sess.close() 69 | tmp1 = tf.sparse_reorder(tf.SparseTensor(indices=sp_indices, values=sp_indice_value, dense_shape=sp_shape)) 70 | tmp2 = tf.sparse_reorder(tf.SparseTensor(indices=sp_indices, values=sp_value, dense_shape=sp_shape)) 71 | y = tf.transpose(tf.nn.embedding_lookup_sparse(b, tmp1, tmp2, combiner="sum")) 72 | tmp = y.shape 73 | return y 74 | 75 | if __name__ == '__main__': 76 | a = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]) 77 | 78 | X = tf.placeholder(tf.float32, shape=[2,3]) 79 | 80 | x = np.array([[1,2,3],[2,4,6]], dtype=np.float32) 81 | 82 | mul = sparse_dense_matmul(a, X, True) 83 | mul2 = sparse_dense_matmul_b(sp_indices=[[0, 0], [1, 2]], sp_value=[1, 2], sp_shape=[3,4], b=X, swap=True) 84 | 85 | sess = tf.Session() 86 | print(sess.run(mul, feed_dict={X:x})) 87 | print(sess.run(mul2, feed_dict={X: x})) -------------------------------------------------------------------------------- /mnist_pruning/test_dense.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | sess = tf.Session() 4 | 5 | def test(predict_logit): 6 | correct_prediction = tf.equal(tf.arg_max(predict_logit,1), tf.arg_max(y_,1)) 7 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 8 | result = 0 9 | for i in range(20): 10 | batch = mnist.test.next_batch(500) 11 | result = result + sess.run(accuracy, feed_dict={x:batch[0], y_:batch[1], keep_prob : 1.0}) 12 | result = result / 20.0 13 | return result 14 | 15 | from tensorflow.examples.tutorials.mnist import input_data 16 | mnist = input_data.read_data_sets('H:/data/', one_hot=True) 17 | 18 | dense_w = { 19 | "w_conv1":tf.Variable(tf.truncated_normal([5,5,1,32], stddev=0.1), name="w_conv1"), 20 | "b_conv1":tf.Variable(tf.constant(0.1, shape=[32]), name="b_conv1"), 21 | "w_conv2":tf.Variable(tf.truncated_normal([5,5,32,64], stddev=0.1), name="w_conv2"), 22 | "b_conv2":tf.Variable(tf.constant(0.1, shape=[64]), name="b_conv2"), 23 | "w_fc1":tf.Variable(tf.truncated_normal([7*7*64,1024], stddev=0.1), name="w_fc1"), 24 | "b_fc1":tf.Variable(tf.constant(0.1, shape=[1024]), name="b_fc1"), 25 | "w_fc2":tf.Variable(tf.truncated_normal([1024,10], stddev=0.1), name="w_fc2"), 26 | "b_fc2":tf.Variable(tf.constant(0.1, shape=[10]), name="b_fc2") 27 | } 28 | 29 | def dense_cnn_model(image, weights, keep_prob): 30 | def conv2d(x, W): 31 | return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding="SAME") 32 | def max_pool_2x2(x): 33 | return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME") 34 | x_image = tf.reshape(image, [-1,28,28,1]) 35 | h_conv1 = tf.nn.relu(conv2d(x_image, weights["w_conv1"]) + weights["b_conv1"]) 36 | #[-1,14,14,32] 37 | h_pool1 = max_pool_2x2(h_conv1) 38 | h_conv2 = tf.nn.relu(conv2d(h_pool1, weights["w_conv2"]) + weights["b_conv2"]) 39 | #[-1,7,7,64] 40 | h_pool2 = max_pool_2x2(h_conv2) 41 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 42 | h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, weights["w_fc1"]) + weights["b_fc1"]) 43 | h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob=keep_prob) 44 | #[-1,10] 45 | logit = tf.matmul(h_fc1_dropout, weights["w_fc2"]) + weights["b_fc2"] 46 | return logit 47 | 48 | def dense_conv_model(image, weights): 49 | def conv2d(x, W): 50 | return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding="SAME") 51 | def max_pool_2x2(x): 52 | return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME") 53 | x_image = tf.reshape(image, [-1,28,28,1]) 54 | h_conv1 = tf.nn.relu(conv2d(x_image, weights["w_conv1"]) + weights["b_conv1"]) 55 | #[-1,14,14,32] 56 | h_pool1 = max_pool_2x2(h_conv1) 57 | h_conv2 = tf.nn.relu(conv2d(h_pool1, weights["w_conv2"]) + weights["b_conv2"]) 58 | #[-1,7,7,64] 59 | h_pool2 = max_pool_2x2(h_conv2) 60 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 61 | return h_pool2_flat 62 | 63 | saver = tf.train.import_meta_graph("./model_ckpt_dense.meta") 64 | saver.restore(sess, "./model_ckpt_dense") 65 | 66 | x = tf.placeholder(tf.float32, shape=[None, 784]) 67 | y_ = tf.placeholder(tf.float32, shape=[None, 10]) 68 | keep_prob = tf.placeholder(tf.float32) 69 | for var in tf.all_variables(): 70 | if sess.run(tf.is_variable_initialized(var)) == False: 71 | sess.run(tf.initialize_variables([var])) 72 | 73 | logit = dense_cnn_model(x, dense_w, keep_prob) 74 | 75 | test_acc = test(logit) 76 | print(test_acc) 77 | 78 | result = 0 79 | for i in range(20): 80 | batch = mnist.test.next_batch(1) 81 | correct_prediction = tf.equal(tf.arg_max(logit, 1), tf.arg_max(y_, 1)) 82 | accuracy = sess.run(correct_prediction, feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0}) 83 | print(1) 84 | 85 | 86 | 87 | # pool2_flat = dense_conv_model(x, dense_w) 88 | 89 | # for i in range(20): 90 | # batch = mnist.test.next_batch(1) 91 | # tmp_pool2_flat = sess.run(pool2_flat, feed_dict={x:batch[0]}) 92 | # tmp_pool2_flat_variable = tf.Variable(tmp_pool2_flat, dtype=tf.float32) 93 | # sess.run(tf.initialize_variables([tmp_pool2_flat_variable])) 94 | # weight_fc1 = dense_w["w_fc1"] 95 | # tmp_fc1 = sess.run(tf.matmul(tmp_pool2_flat_variable, weight_fc1) + dense_w["b_fc1"]) 96 | # tmp_fc1_variable = tf.Variable(tmp_fc1, dtype=tf.float32) 97 | # sess.run(tf.initialize_variables([tmp_fc1_variable])) 98 | # h_fc1 = tf.nn.relu(tmp_fc1_variable) 99 | # tmp_h_fc1 = sess.run(h_fc1) 100 | # tmp_h_fc1_variable = tf.Variable(tmp_h_fc1, dtype=tf.float32) 101 | # sess.run(tf.initialize_variables([tmp_h_fc1_variable])) 102 | # weight_fc2 = dense_w["w_fc2"] 103 | # tmp_fc2 = sess.run(tf.matmul(tmp_h_fc1_variable, weight_fc2) + dense_w["b_fc2"]) 104 | # tmp_fc2_variable = tf.Variable(tmp_fc2, dtype=tf.float32) 105 | # sess.run(tf.initialize_variables([tmp_fc2_variable])) 106 | # 107 | # if_correct = tf.equal(tf.argmax(tmp_fc2_variable, 1), tf.argmax(y_, 1)) 108 | # acc = sess.run(if_correct,feed_dict={y_:batch[1]}) 109 | # print(1) 110 | 111 | 112 | -------------------------------------------------------------------------------- /mnist_pruning/test_sparse.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | import argparse 4 | import config 5 | import sys 6 | 7 | from tensorflow.examples.tutorials.mnist import input_data 8 | mnist = input_data.read_data_sets('H:/data/', one_hot=True) 9 | 10 | def test(predict_logit): 11 | correct_prediction = tf.equal(tf.arg_max(predict_logit,1), tf.arg_max(y_,1)) 12 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 13 | result = 0 14 | for i in range(20): 15 | batch = mnist.test.next_batch(500) 16 | result = result + sess.run(accuracy, feed_dict={x:batch[0], y_:batch[1], keep_prob : 1.0}) 17 | result = result / 20.0 18 | return result 19 | 20 | sess = tf.Session() 21 | saver = tf.train.import_meta_graph("./model_ckpt_sparse_retrained.meta") 22 | saver.restore(sess, "./model_ckpt_sparse_retrained") 23 | 24 | x = tf.get_collection("x_placeholder")[0] 25 | y_ = tf.placeholder(tf.float32, [None, 10], name="y_") 26 | keep_prob = tf.placeholder(tf.float32, name="keep_prob") 27 | 28 | logit = tf.get_collection("sp_logit")[0] 29 | test_acc_sp = test(logit) 30 | print(test_acc_sp) 31 | -------------------------------------------------------------------------------- /mnist_pruning/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import tensorflow as tf 3 | import numpy as np 4 | import argparse 5 | from sparse_op import sparse_dense_matmul_b 6 | import read_image 7 | 8 | argparser = argparse.ArgumentParser() 9 | argparser.add_argument("-1", "--train", action="store_true", 10 | help="train dense MNIST model with 20000 iterations") 11 | argparser.add_argument("-2", "--prune", action="store_true", 12 | help="prune model and retrain") 13 | argparser.add_argument("-3", "--sparse", action="store_true", 14 | help="transform model to a sparse format and save it") 15 | argparser.add_argument("-m", "--checkpoint", default="./model_ckpt_dense", 16 | help="Target checkpoint model file for 2nd and 3rd round") 17 | args = argparser.parse_args() 18 | 19 | def dense_cnn_model(image, weights, keep_prob): 20 | def conv2d(x, W): 21 | return tf.nn.conv2d(x, W, strides=[1,1,1,1], padding="SAME") 22 | def max_pool_2x2(x): 23 | return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding="SAME") 24 | x_image = tf.reshape(image, [-1,28,28,1]) 25 | h_conv1 = tf.nn.relu(conv2d(x_image, weights["w_conv1"]) + weights["b_conv1"]) 26 | #[-1,14,14,32] 27 | h_pool1 = max_pool_2x2(h_conv1) 28 | h_conv2 = tf.nn.relu(conv2d(h_pool1, weights["w_conv2"]) + weights["b_conv2"]) 29 | #[-1,7,7,64] 30 | h_pool2 = max_pool_2x2(h_conv2) 31 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64]) 32 | h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, weights["w_fc1"]) + weights["b_fc1"]) 33 | h_fc1_dropout = tf.nn.dropout(h_fc1, keep_prob=keep_prob) 34 | #[-1,10] 35 | logit = tf.matmul(h_fc1_dropout, weights["w_fc2"]) + weights["b_fc2"] 36 | return h_pool2_flat, h_fc1, logit 37 | 38 | def test(predict_logit): 39 | correct_prediction = tf.equal(tf.argmax(predict_logit,1), tf.argmax(y_,1)) 40 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 41 | result = 0 42 | for i in range(20): 43 | batch = mnist.test.next_batch(500) 44 | result = result + sess.run(accuracy, feed_dict={x:batch[0], y_:batch[1], keep_prob : 1.0}) 45 | result = result / 20.0 46 | return result 47 | 48 | def prune(weights, th, name): 49 | shape = weights.shape 50 | weight_arr = sess.run(weights) 51 | under_threshold = abs(weight_arr) < th 52 | weight_arr[under_threshold] = 0 53 | tmp = weight_arr 54 | #set last matrix elemet to a small number, I have to do that since the drawback of tensorflow sparse matrix support 55 | #hope it would have less impact on model 56 | for i in range(len(shape) - 1): 57 | tmp = tmp[-1] 58 | if(tmp[-1] == 0): 59 | tmp[-1] = 0.01 60 | count = np.sum(under_threshold) 61 | print ("None-zero element: %s" % (weight_arr.size - count)) 62 | sparse_weight = tf.Variable(weight_arr, dtype=tf.float32, name=name) 63 | return sparse_weight, ~under_threshold 64 | 65 | 66 | def get_th(weight, percentage=0.8): 67 | flat = tf.reshape(weight, [-1]) 68 | flat_list = sorted(map(abs,sess.run(flat))) 69 | return flat_list[int(len(flat_list) * percentage)] 70 | 71 | #转换全连接层 72 | def transfer_to_sparse(weight): 73 | weight_arr = sess.run(weight) 74 | values = weight_arr[weight_arr != 0] 75 | indices = np.transpose(np.nonzero(weight_arr)) 76 | shape = list(weight_arr.shape) 77 | return [indices, values, shape] 78 | 79 | def delete_none_grads(grads): 80 | count = 0 81 | length = len(grads) 82 | while(count < length): 83 | if(grads[count][0] == None): 84 | del grads[count] 85 | length -= 1 86 | else: 87 | count += 1 88 | 89 | from tensorflow.examples.tutorials.mnist import input_data 90 | if(args.train or args.prune or args.sparse) == False: 91 | argparser.print_help() 92 | sys.exit() 93 | mnist = input_data.read_data_sets('H:/data/', one_hot=True) 94 | 95 | if((args.train or args.prune or args.sparse) == False): 96 | argparser.print_help() 97 | sys.exit(1) 98 | 99 | sess = tf.Session() 100 | 101 | dense_w = { 102 | "w_conv1":tf.Variable(tf.truncated_normal([5,5,1,32], stddev=0.1), name="w_conv1"), 103 | "b_conv1":tf.Variable(tf.constant(0.1, shape=[32]), name="b_conv1"), 104 | "w_conv2":tf.Variable(tf.truncated_normal([5,5,32,64], stddev=0.1), name="w_conv2"), 105 | "b_conv2":tf.Variable(tf.constant(0.1, shape=[64]), name="b_conv2"), 106 | "w_fc1":tf.Variable(tf.truncated_normal([7*7*64,1024], stddev=0.1), name="w_fc1"), 107 | "b_fc1":tf.Variable(tf.constant(0.1, shape=[1024]), name="b_fc1"), 108 | "w_fc2":tf.Variable(tf.truncated_normal([1024,10], stddev=0.1), name="w_fc2"), 109 | "b_fc2":tf.Variable(tf.constant(0.1, shape=[10]), name="b_fc2") 110 | } 111 | 112 | if(args.train == True): 113 | x = tf.placeholder(tf.float32, [None, 784], name="x") 114 | y_ = tf.placeholder(tf.float32, [None, 10], name="y_") 115 | keep_prob = tf.placeholder(tf.float32, name="keep_prob") 116 | 117 | useless1, useless2, logit = dense_cnn_model(x, dense_w, keep_prob) 118 | 119 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y_) 120 | train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy) 121 | 122 | correct_prediction = tf.equal(tf.arg_max(logit,1), tf.arg_max(y_, 1)) 123 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 124 | 125 | sess.run(tf.global_variables_initializer()) 126 | 127 | for i in range(20000): 128 | batch = mnist.train.next_batch(50) 129 | if i % 100 == 0: 130 | train_acc = sess.run(accuracy, feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 131 | print("step %d, training acc %g" % (i , train_acc)) 132 | sess.run(train_step, feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 133 | 134 | test_acc = test(logit) 135 | print("test acc %g" % test_acc) 136 | saver = tf.train.Saver() 137 | saver.save(sess, "./model_ckpt_dense") 138 | 139 | if(args.prune == True): 140 | saver = tf.train.Saver() 141 | saver.restore(sess, args.checkpoint) 142 | th_fc1 = get_th(dense_w["w_fc1"], percentage=0.9) 143 | th_fc2 = get_th(dense_w["w_fc2"], percentage=0.9) 144 | sp_w_fc1, idx_fc1 = prune(dense_w["w_fc1"], th_fc1, name="sp_w_fc1") 145 | sp_w_fc2, idx_fc2 = prune(dense_w["w_fc2"], th_fc2, name="sp_w_fc2") 146 | dense_w["w_fc1"] = sp_w_fc1 147 | dense_w["w_fc2"] = sp_w_fc2 148 | 149 | x = tf.placeholder(tf.float32, [None, 784], name="x") 150 | y_ = tf.placeholder(tf.float32, [None, 10], name="y_") 151 | keep_prob = tf.placeholder(tf.float32, name="keep_prob") 152 | 153 | for var in tf.all_variables(): 154 | if sess.run(tf.is_variable_initialized(var)) == False: 155 | sess.run(var.initializer) 156 | 157 | useless1, useless2, logit = dense_cnn_model(x, dense_w, keep_prob) 158 | test_acc = test(logit) 159 | print("test acc after pruning %g" % test_acc) 160 | saver.save(sess, "./model_ckpt_dense_pruned") 161 | 162 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=y_) 163 | trainer = tf.train.AdamOptimizer(1e-4) 164 | grads = trainer.compute_gradients(cross_entropy) 165 | 166 | delete_none_grads(grads) 167 | 168 | count = 0 169 | for grad, var in grads: 170 | if (var.name == "sp_w_fc1:0"): 171 | idx_in1 = tf.cast(tf.constant(idx_fc1), tf.float32) 172 | grads[count] = (tf.multiply(idx_in1, grad), var) 173 | if (var.name == "sp_w_fc2:0"): 174 | idx_in2 = tf.cast(tf.constant(idx_fc2), tf.float32) 175 | grads[count] = (tf.multiply(idx_in2, grad), var) 176 | count += 1 177 | train_step = trainer.apply_gradients(grads) 178 | 179 | correct_prediction = tf.equal(tf.argmax(logit,1), tf.argmax(y_,1)) 180 | accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 181 | 182 | for var in tf.all_variables(): 183 | if sess.run(tf.is_variable_initialized(var)) == False: 184 | sess.run(tf.initialize_variables([var])) 185 | 186 | for i in range(20000): 187 | batch = mnist.train.next_batch(50) 188 | idx_in1_value = sess.run(idx_in1) 189 | grads_fc1_value = sess.run(grads, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 190 | if i % 100 == 0: 191 | train_acc = sess.run(accuracy, feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5}) 192 | print ("retraining step %d, acc %g" % (i, train_acc)) 193 | sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) 194 | 195 | test_acc = test(logit) 196 | print("test acc after pruning and retraining%g" % test_acc) 197 | 198 | saver = tf.train.Saver(dense_w) 199 | saver.save(sess, "./model_ckpt_dense_retrained") 200 | 201 | if(args.sparse == True): 202 | if args.prune == False: 203 | saver = tf.train.Saver() 204 | saver.restore(sess, "./model_ckpt_dense_retrained") 205 | 206 | sparse_w = { 207 | "w_conv1": tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev=0.1)), 208 | "b_conv1": tf.Variable(tf.constant(0.1, shape=[32])), 209 | "w_conv2": tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev=0.1)), 210 | "b_conv2": tf.Variable(tf.constant(0.1, shape=[64])), 211 | "w_fc1": tf.Variable(tf.truncated_normal([7 * 7 * 64, 1024], stddev=0.1)), 212 | "b_fc1": tf.Variable(tf.constant(0.1, shape=[1024])), 213 | "w_fc2": tf.Variable(tf.truncated_normal([1024, 10], stddev=0.1)), 214 | "b_fc2": tf.Variable(tf.constant(0.1, shape=[10])) 215 | } 216 | 217 | def sparse_cnn_model(image, sparse_weight): 218 | def conv2d(x, W): 219 | return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding="SAME") 220 | 221 | def max_pool_2x2(x): 222 | return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME") 223 | 224 | x_image = tf.reshape(image, [-1, 28, 28, 1]) 225 | h_conv1 = tf.nn.relu(conv2d(x_image, sparse_weight["w_conv1"]) + sparse_weight["b_conv1"]) 226 | # [-1,14,14,32] 227 | h_pool1 = max_pool_2x2(h_conv1) 228 | h_conv2 = tf.nn.relu(conv2d(h_pool1, sparse_weight["w_conv2"]) + sparse_weight["b_conv2"]) 229 | # [-1,7,7,64] 230 | h_pool2 = max_pool_2x2(h_conv2) 231 | h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64]) 232 | ndarray_w_fc1_idx = sess.run(sparse_weight["w_fc1_idx"]) 233 | ndarray_w_fc1 = sess.run(sparse_weight["w_fc1"]) 234 | ndarray_w_fc1_shape = sess.run(sparse_weight["w_fc1_shape"]) 235 | h_fc1 = tf.nn.relu(sparse_dense_matmul_b(ndarray_w_fc1_idx, ndarray_w_fc1, ndarray_w_fc1_shape, h_pool2_flat, True) + sparse_weight["b_fc1"]) 236 | ndarray_w_fc2_idx = sess.run(sparse_weight["w_fc2_idx"]) 237 | ndarray_w_fc2 = sess.run(sparse_weight["w_fc2"]) 238 | ndarray_w_fc2_shape = sess.run(sparse_weight["w_fc2_shape"]) 239 | logit = sparse_dense_matmul_b(ndarray_w_fc2_idx, ndarray_w_fc2, ndarray_w_fc2_shape, h_fc1, True) + sparse_weight["b_fc2"] 240 | return h_pool2_flat, h_fc1, logit 241 | 242 | copy_ops = [] 243 | for key, value in dense_w.items(): 244 | copy_ops.append(sparse_w[key].assign(value)) 245 | for e in copy_ops: 246 | sess.run(e) 247 | 248 | fc1_sparse_tmp = transfer_to_sparse(dense_w["w_fc1"]) 249 | sparse_w["w_fc1_idx"] = tf.Variable(tf.constant(fc1_sparse_tmp[0], dtype=tf.int64)\ 250 | , name="w_fc1_idx") 251 | sparse_w["w_fc1"] = tf.Variable(tf.constant(fc1_sparse_tmp[1], dtype=tf.float32)\ 252 | , name="w_fc1") 253 | sparse_w["w_fc1_shape"] = tf.Variable(tf.constant(fc1_sparse_tmp[2], dtype=tf.int64)\ 254 | , name="w_fc1_shape") 255 | fc2_sparse_tmp = transfer_to_sparse(dense_w["w_fc2"]) 256 | sparse_w["w_fc2_idx"] = tf.Variable(tf.constant(fc2_sparse_tmp[0], dtype=tf.int64)\ 257 | , name="w_fc2_idx") 258 | sparse_w["w_fc2"] = tf.Variable(tf.constant(fc2_sparse_tmp[1], dtype=tf.float32)\ 259 | , name="w_fc2") 260 | sparse_w["w_fc2_shape"] = tf.Variable(tf.constant(fc2_sparse_tmp[2], dtype=tf.int64)\ 261 | , name="w_fc2_shape") 262 | for var in tf.all_variables(): 263 | if sess.run(tf.is_variable_initialized(var)) == False: 264 | sess.run(tf.initialize_variables([var])) 265 | 266 | for key, value in sparse_w.items(): 267 | tf.add_to_collection("sparse_" + key, value) 268 | print("sparse_" + key) 269 | 270 | x = tf.placeholder(tf.float32, [None, 784], name="x") 271 | y_ = tf.placeholder(tf.float32, [None, 10], name="y_") 272 | keep_prob = tf.placeholder(tf.float32, name="keep_prob") 273 | tf.add_to_collection("x_placeholder", x) 274 | 275 | #dense_prediction 276 | dense_pool2_flat, dense_h_fc1, dense_logit = dense_cnn_model(x, dense_w, keep_prob) 277 | #sparse prediction 278 | sp_pool2_flat, sp_h_fc1, sp_logit = sparse_cnn_model(x, sparse_w) 279 | 280 | img = read_image.read_image("./seven.png") 281 | img = np.reshape(img, (784)) 282 | 283 | rst_dense_pool2_flat = sess.run(dense_pool2_flat, feed_dict={x:[img], keep_prob:1.0}) 284 | rst_dense_h_fc1 = sess.run(dense_h_fc1, feed_dict={x:[img], keep_prob:1.0}) 285 | rst_dense_logit = sess.run(dense_logit, feed_dict={x:[img], keep_prob:1.0}) 286 | 287 | rst_sp_pool2_flat = sess.run(sp_pool2_flat, feed_dict={x:[img]}) 288 | rst_sp_h_fc1 = sess.run(sp_h_fc1, feed_dict={x:[img]}) 289 | rst_sp_logit = sess.run(sp_logit, feed_dict={x:[img]}) 290 | 291 | test_acc_dense = test(dense_logit) 292 | print("dense acc:" + str(test_acc_dense)) 293 | 294 | test_acc_sp = test(sp_logit) 295 | print("sp acc" + str(test_acc_sp)) 296 | 297 | tf.add_to_collection("sp_logit", sp_logit) 298 | 299 | sparse_saver = tf.train.Saver(sparse_w) 300 | sparse_saver.save(sess, "./model_ckpt_sparse_retrained") --------------------------------------------------------------------------------