├── README.md ├── post_quantization.py └── quantization_aware_training.py /README.md: -------------------------------------------------------------------------------- 1 | # model_quantization 2 | 3 | - post_quantization.py 训练后量化,原理见[深度学习算法优化系列五 | 使用TensorFlow-Lite对LeNet进行训练后量化](https://mp.weixin.qq.com/s/MSnkltSHGsBF9ddwN4wZKQ) 4 | - quantization_aware_training.py 训练时量化,原理见[深度学习算法优化系列六 | 使用TensorFlow-Lite对LeNet进行训练时量化](https://mp.weixin.qq.com/s/vbFegPQg5omMlwNn6tSmhQ) 5 | 6 | -------------------------------------------------------------------------------- /post_quantization.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import re 3 | import time 4 | import numpy as np 5 | import tensorflow as tf 6 | import tensorflow.contrib.slim as slim 7 | from tensorflow.contrib.slim import get_variables_to_restore 8 | import tensorflow.examples.tutorials.mnist.input_data as input_data 9 | 10 | # 参数设置 11 | KEEP_PROB = 0.5 12 | LEARNING_RATE = 1e-5 13 | BATCH_SIZE = 30 14 | PARAMETER_FILE = "./checkpoint/variable.ckpt-100000" 15 | MAX_ITER = 100000 16 | 17 | # Build LeNet 18 | class Lenet: 19 | def __init__(self, is_train=True): 20 | self.raw_input_image = tf.placeholder(tf.float32, [None, 784], "inputs") 21 | self.input_images = tf.reshape(self.raw_input_image, [-1, 28, 28, 1]) 22 | self.raw_input_label = tf.placeholder("float", [None, 10], "labels") 23 | self.input_labels = tf.cast(self.raw_input_label, tf.int32) 24 | self.dropout = KEEP_PROB 25 | self.is_train = is_train 26 | 27 | with tf.variable_scope("Lenet") as scope: 28 | self.train_digits = self.build(True) 29 | scope.reuse_variables() 30 | self.pred_digits = self.build(False) 31 | 32 | self.loss = slim.losses.softmax_cross_entropy(self.train_digits, self.input_labels) 33 | self.lr = LEARNING_RATE 34 | self.train_op = tf.train.AdamOptimizer(self.lr).minimize(self.loss) 35 | 36 | self.predictions = tf.arg_max(self.pred_digits, 1, name="predictions") 37 | self.correct_prediction = tf.equal(tf.argmax(self.pred_digits, 1), tf.argmax(self.input_labels, 1)) 38 | self.train_accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, "float")) 39 | 40 | def build(self, is_trained=True): 41 | with slim.arg_scope([slim.conv2d], padding='VALID', 42 | weights_initializer=tf.truncated_normal_initializer(stddev=0.01), 43 | weights_regularizer=slim.l2_regularizer(0.0005)): 44 | net = slim.conv2d(self.input_images, 6, [5, 5], 1, padding='SAME', scope='conv1') 45 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 46 | net = slim.conv2d(net, 16, [5, 5], 1, scope='conv3') 47 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 48 | net = slim.conv2d(net, 120, [5, 5], 1, scope='conv5') 49 | net = slim.flatten(net, scope='flat6') 50 | net = slim.fully_connected(net, 84, scope='fc7') 51 | net = slim.dropout(net, self.dropout, is_training=is_trained, scope='dropout8') 52 | digits = slim.fully_connected(net, 10, scope='fc9') 53 | return digits 54 | 55 | # 将Saved_Model转为tflite,调用的tf.lite.TFLiteConverter 56 | def convert_to_tflite(): 57 | saved_model_dir = "./pb_model" 58 | converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir, 59 | input_arrays=["inputs"], 60 | input_shapes={"inputs": [1, 784]}, 61 | output_arrays=["predictions"]) 62 | converter.post_training_quantize = True 63 | tflite_model = converter.convert() 64 | open("tflite_model/eval_graph.tflite", "wb").write(tflite_model) 65 | 66 | # 使用原始的checkpoint进行预测 67 | def origin_predict(): 68 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 69 | sess = tf.Session() 70 | saver = tf.train.import_meta_graph("./checkpoint/variable.ckpt-100000.meta") 71 | saver.restore(sess, "./checkpoint/variable.ckpt-100000") 72 | 73 | input_node = sess.graph.get_tensor_by_name('inputs:0') 74 | pred = sess.graph.get_tensor_by_name('predictions:0') 75 | labels = [label.index(1) for label in mnist.test.labels.tolist()] 76 | predictions = [] 77 | start_time = time.time() 78 | for i in range(10): 79 | for image in mnist.test.images: 80 | prediction = sess.run(pred, feed_dict={input_node: [image]}).tolist()[0] 81 | predictions.append(prediction) 82 | end_time = time.time() 83 | correct = 0 84 | for prediction, label in zip(predictions, labels): 85 | if prediction == label: 86 | correct += 1 87 | print(correct / len(labels)) 88 | print((end_time - start_time)) 89 | 90 | # 使用tflite进行预测 91 | def tflite_predict(): 92 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 93 | labels = [label.index(1) for label in mnist.test.labels.tolist()] 94 | images = mnist.test.images 95 | #images = np.array(images, dtype="uint8") 96 | # 根据tflite文件生成解析器 97 | interpreter = tf.contrib.lite.Interpreter(model_path="tflite_model/eval_graph.tflite") 98 | # 用allocate_tensors()分配内存 99 | interpreter.allocate_tensors() 100 | # 获取输入输出tensor 101 | input_details = interpreter.get_input_details() 102 | output_details = interpreter.get_output_details() 103 | 104 | predictions = [] 105 | start_time = time.time() 106 | for i in range(10): 107 | for image in images: 108 | # 填充输入tensor 109 | interpreter.set_tensor(input_details[0]['index'], [image]) 110 | # 前向推理 111 | interpreter.invoke() 112 | # 获取输出tensor 113 | score = interpreter.get_tensor(output_details[0]['index'])[0][0] 114 | # # 结果去掉无用的维度 115 | # result = np.squeeze(score) 116 | # #print('result:{}'.format(result)) 117 | # # 输出结果是长度为10(对应0-9)的一维数据,最大值的下标就是预测的数字 118 | predictions.append(score) 119 | end_time = time.time() 120 | correct = 0 121 | for prediction, label in zip(predictions, labels): 122 | if prediction == label: 123 | correct += 1 124 | print((end_time - start_time)) 125 | print(correct / len(labels)) 126 | 127 | 128 | def train(): 129 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 130 | test_images = mnist.test.images 131 | test_labels = mnist.test.labels 132 | sess = tf.Session() 133 | batch_size = BATCH_SIZE 134 | paramter_path = PARAMETER_FILE 135 | max_iter = MAX_ITER 136 | 137 | lenet = Lenet() 138 | variables = get_variables_to_restore() 139 | save_vars = [variable for variable in variables if not re.search("Adam", variable.name)] 140 | 141 | saver = tf.train.Saver(save_vars) 142 | sess.run(tf.initialize_all_variables()) 143 | # 用来显示标量信息 144 | tf.summary.scalar("loss", lenet.loss) 145 | # merge_all 可以将所有summary全部保存到磁盘,以便tensorboard显示。如果没有特殊要求, 146 | # 一般用这一句就可一显示训练时的各种信息了。 147 | summary_op = tf.summary.merge_all() 148 | # 指定一个文件用来保存图 149 | train_summary_writer = tf.summary.FileWriter("logs", sess.graph) 150 | 151 | for i in range(max_iter): 152 | batch = mnist.train.next_batch(batch_size) 153 | if i % 100 == 0: 154 | train_accuracy, summary = sess.run([lenet.train_accuracy, summary_op], feed_dict={ 155 | lenet.raw_input_image: batch[0], 156 | lenet.raw_input_label: batch[1] 157 | }) 158 | train_summary_writer.add_summary(summary) 159 | print("step %d, training accuracy %g" % (i, train_accuracy)) 160 | 161 | if i % 500 == 0: 162 | test_accuracy = sess.run(lenet.train_accuracy, feed_dict={lenet.raw_input_image: test_images, 163 | lenet.raw_input_label: test_labels}) 164 | print("\n") 165 | print("step %d, test accuracy %g" % (i, test_accuracy)) 166 | print("\n") 167 | sess.run(lenet.train_op, feed_dict={lenet.raw_input_image: batch[0], 168 | lenet.raw_input_label: batch[1]}) 169 | saver.save(sess, paramter_path) 170 | print("saved model") 171 | 172 | # 保存为saved_model 173 | builder = tf.saved_model.builder.SavedModelBuilder("pb_model") 174 | inputs = {"inputs": tf.saved_model.utils.build_tensor_info(lenet.raw_input_image)} 175 | outputs = {"predictions": tf.saved_model.utils.build_tensor_info(lenet.predictions)} 176 | prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(inputs=inputs, outputs=outputs, 177 | method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME) 178 | 179 | legacy_init_op = tf.group(tf.tables_initializer(), name="legacy_init_op") 180 | builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING], 181 | signature_def_map={"serving_default": prediction_signature}, 182 | legacy_init_op=legacy_init_op, saver=saver) 183 | builder.save() 184 | 185 | 186 | if __name__ == '__main__': 187 | #train() 188 | convert_to_tflite() 189 | origin_predict() 190 | tflite_predict() 191 | 192 | -------------------------------------------------------------------------------- /quantization_aware_training.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import re 3 | import time 4 | import numpy as np 5 | import tensorflow as tf 6 | import tensorflow.contrib.slim as slim 7 | from tensorflow.contrib.slim import get_variables_to_restore 8 | from tensorflow.python.framework import graph_util 9 | import tensorflow.examples.tutorials.mnist.input_data as input_data 10 | 11 | # 参数设置 12 | KEEP_PROB = 0.5 13 | LEARNING_RATE = 1e-5 14 | BATCH_SIZE = 30 15 | PARAMETER_FILE = "./checkpoint/variable.ckpt-100000" 16 | MAX_ITER = 100000 17 | 18 | # Build LeNet 19 | class Lenet: 20 | def __init__(self, is_train=True): 21 | self.raw_input_image = tf.placeholder(tf.float32, [None, 784], "inputs") 22 | self.input_images = tf.reshape(self.raw_input_image, [-1, 28, 28, 1]) 23 | self.raw_input_label = tf.placeholder("float", [None, 10], "labels") 24 | self.input_labels = tf.cast(self.raw_input_label, tf.int32) 25 | self.dropout = KEEP_PROB 26 | self.is_train = is_train 27 | 28 | with tf.variable_scope("Lenet") as scope: 29 | self.train_digits = self.build(True) 30 | scope.reuse_variables() 31 | self.pred_digits = self.build(False) 32 | 33 | self.loss = slim.losses.softmax_cross_entropy(self.train_digits, self.input_labels) 34 | # 获取当前的计算图,用于后续的量化 35 | self.g = tf.get_default_graph() 36 | 37 | if self.is_train: 38 | # 在损失函数之后,优化器定义之前,在这里会自动选择计算图中的一些operation和activation做伪量化 39 | tf.contrib.quantize.create_training_graph(self.g, 80000) 40 | self.lr = LEARNING_RATE 41 | self.train_op = tf.train.AdamOptimizer(self.lr).minimize(self.loss) 42 | else: 43 | # 用于预测时,将之前训练时构造的伪量化的operation和activation实际量化,用于后续的推断 44 | tf.contrib.quantize.create_eval_graph(self.g) 45 | self.predictions = tf.arg_max(self.pred_digits, 1, name="predictions") 46 | self.correct_prediction = tf.equal(tf.argmax(self.pred_digits, 1), tf.argmax(self.input_labels, 1)) 47 | self.train_accuracy = tf.reduce_mean(tf.cast(self.correct_prediction, "float")) 48 | 49 | def build(self, is_trained=True): 50 | with slim.arg_scope([slim.conv2d], padding='VALID', 51 | weights_initializer=tf.truncated_normal_initializer(stddev=0.01), 52 | weights_regularizer=slim.l2_regularizer(0.0005)): 53 | net = slim.conv2d(self.input_images, 6, [5, 5], 1, padding='SAME', scope='conv1') 54 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 55 | net = slim.conv2d(net, 16, [5, 5], 1, scope='conv3') 56 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 57 | net = slim.conv2d(net, 120, [5, 5], 1, scope='conv5') 58 | net = slim.flatten(net, scope='flat6') 59 | net = slim.fully_connected(net, 84, scope='fc7') 60 | net = slim.dropout(net, self.dropout, is_training=is_trained, scope='dropout8') 61 | digits = slim.fully_connected(net, 10, scope='fc9') 62 | return digits 63 | 64 | # 将保存了训练时伪量化信息的checkpoint文件转换成freeze pb文件 65 | def freeze(): 66 | with tf.Session() as sess: 67 | sess.run(tf.global_variables_initializer()) 68 | le_net = Lenet(False) 69 | saver = tf.train.Saver() 70 | # 导出当前计算图的GraphDef部分 71 | saver.restore(sess, "checkpoint/variable.ckpt-100000") 72 | # 保存指定的节点,并将节点值保存为常数 73 | frozen_graph_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['predictions']) 74 | # 将计算图写入到模型文件中 75 | tf.io.write_graph(frozen_graph_def, "pb_model", "freeze_eval_graph.pb", as_text=False) 76 | 77 | # 将存储了伪量化信息的freeze pb文件转换成完全量化的tflite文件,可以看见量化完之后文件内存基本减小到1/4 78 | def convert_to_tflite(): 79 | converter = tf.lite.TFLiteConverter.from_frozen_graph("pb_model/freeze_eval_graph.pb", ["inputs"], ["predictions"]) 80 | converter.inference_type = tf.lite.constants.QUANTIZED_UINT8 81 | converter.quantized_input_stats = {"inputs":(0., 1.)} # mean, std_dev,需要自己从训练集(增强后,输入网络之前的)统计出来 82 | converter.allow_custom_ops = True 83 | converter.default_ranges_stats = (0, 255) 84 | converter.post_training_quantize = True 85 | tflite_model = converter.convert() 86 | open("tflite_model/eval_graph.tflite", "wb").write(tflite_model) 87 | 88 | # 使用原始的checkpoint进行预测 89 | def origin_predict(): 90 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 91 | sess = tf.Session() 92 | saver = tf.train.import_meta_graph("./checkpoint/variable.ckpt-100000.meta") 93 | saver.restore(sess, "./checkpoint/variable.ckpt-100000") 94 | 95 | input_node = sess.graph.get_tensor_by_name('inputs:0') 96 | pred = sess.graph.get_tensor_by_name('predictions:0') 97 | labels = [label.index(1) for label in mnist.test.labels.tolist()] 98 | predictions = [] 99 | start_time = time.time() 100 | for i in range(10): 101 | for image in mnist.test.images: 102 | prediction = sess.run(pred, feed_dict={input_node: [image]}).tolist()[0] 103 | predictions.append(prediction) 104 | end_time = time.time() 105 | correct = 0 106 | for prediction, label in zip(predictions, labels): 107 | if prediction == label: 108 | correct += 1 109 | print(correct / len(labels)) 110 | print((end_time - start_time)) 111 | sess.close() 112 | 113 | # 使用freeze pb文件进行预测 114 | def freeze_pb_predict(): 115 | mnist = input_data.read_data_sets('MNIST_data/', one_hot=True) 116 | with tf.Session() as sess: 117 | with tf.gfile.FastGFile("pb_model/freeze_eval_graph.pb", 'rb') as f: 118 | # 使用tf.GraphDef()定义一个空Graph 119 | graph_def = tf.GraphDef() 120 | graph_def.ParseFromString(f.read()) 121 | # 返回一个上下文管理器,使得这个Graph对象成为当前默认的graph.当你想在一个进程里面创建 122 | # 多个图的时候,就应该使用这个函数.为了方便起见,一个全局的图对象被默认提供,要是你没有 123 | # 显式创建一个新的图的话,所有的操作(ops)都会被添加到这个默认的图里面来. 124 | #sess.graph.as_default() 125 | # 导入Graph 126 | tf.import_graph_def(graph_def, name='') 127 | # tf.global_variables_initializer()添加节点用于初始化所有的变量(GraphKeys.VARIABLES)。 128 | # 返回一个初始化所有全局变量的操作(Op)。在你构建完整个模型并在会话中加载模型后,运行这个节点。 129 | sess.run(tf.global_variables_initializer()) 130 | # 获取输出Tensor和输出Tensor 131 | input_node = sess.graph.get_tensor_by_name('inputs:0') 132 | pred = sess.graph.get_tensor_by_name('predictions:0') 133 | 134 | labels = [label.index(1) for label in mnist.test.labels.tolist()] 135 | predictions = [] 136 | start_time = time.time() 137 | for i in range(10): 138 | for image in mnist.test.images: 139 | prediction = sess.run(pred, feed_dict={input_node: [image]}).tolist()[0] 140 | predictions.append(prediction) 141 | end_time = time.time() 142 | correct = 0 143 | for prediction, label in zip(predictions, labels): 144 | if prediction == label: 145 | correct += 1 146 | print(correct / len(labels)) 147 | print((end_time - start_time)) 148 | 149 | # 使用tflite进行预测 150 | def tflite_predict(): 151 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 152 | labels = [label.index(1) for label in mnist.test.labels.tolist()] 153 | images = mnist.test.images 154 | 155 | means = np.mean(images, axis=1).reshape([10000, 1]) 156 | std = np.std(images, axis=1, ddof=1).reshape([10000, 1]) 157 | images = (images - means) / std 158 | 159 | images = np.array(images, dtype="uint8") 160 | # 根据tflite文件生成解析器 161 | interpreter = tf.contrib.lite.Interpreter(model_path="./tflite_model/eval_graph.tflite") 162 | # 用allocate_tensors()分配内存 163 | interpreter.allocate_tensors() 164 | # 获取输入输出tensor 165 | input_details = interpreter.get_input_details() 166 | output_details = interpreter.get_output_details() 167 | 168 | predictions = [] 169 | start_time = time.time() 170 | for i in range(10): 171 | for image in images: 172 | # 填充输入tensor 173 | interpreter.set_tensor(input_details[0]['index'], [image]) 174 | # 前向推理 175 | interpreter.invoke() 176 | # 获取输出tensor 177 | score = interpreter.get_tensor(output_details[0]['index'])[0][0] 178 | # # 结果去掉无用的维度 179 | # result = np.squeeze(score) 180 | # #print('result:{}'.format(result)) 181 | # # 输出结果是长度为10(对应0-9)的一维数据,最大值的下标就是预测的数字 182 | predictions.append(score) 183 | end_time = time.time() 184 | correct = 0 185 | for prediction, label in zip(predictions, labels): 186 | if prediction == label: 187 | correct += 1 188 | print((end_time - start_time)) 189 | print(correct / len(labels)) 190 | 191 | 192 | def train(): 193 | mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) 194 | test_images = mnist.test.images 195 | test_labels = mnist.test.labels 196 | sess = tf.Session() 197 | batch_size = BATCH_SIZE 198 | paramter_path = PARAMETER_FILE 199 | max_iter = MAX_ITER 200 | 201 | lenet = Lenet() 202 | variables = get_variables_to_restore() 203 | save_vars = [variable for variable in variables if not re.search("Adam", variable.name)] 204 | 205 | saver = tf.train.Saver(save_vars) 206 | sess.run(tf.initialize_all_variables()) 207 | # 用来显示标量信息 208 | tf.summary.scalar("loss", lenet.loss) 209 | # merge_all 可以将所有summary全部保存到磁盘,以便tensorboard显示。如果没有特殊要求, 210 | # 一般用这一句就可一显示训练时的各种信息了。 211 | summary_op = tf.summary.merge_all() 212 | # 指定一个文件用来保存图 213 | train_summary_writer = tf.summary.FileWriter("logs", sess.graph) 214 | 215 | for i in range(max_iter): 216 | batch = mnist.train.next_batch(batch_size) 217 | if i % 100 == 0: 218 | train_accuracy, summary = sess.run([lenet.train_accuracy, summary_op], feed_dict={ 219 | lenet.raw_input_image: batch[0], 220 | lenet.raw_input_label: batch[1] 221 | }) 222 | train_summary_writer.add_summary(summary) 223 | print("step %d, training accuracy %g" % (i, train_accuracy)) 224 | 225 | if i % 500 == 0: 226 | test_accuracy = sess.run(lenet.train_accuracy, feed_dict={lenet.raw_input_image: test_images, 227 | lenet.raw_input_label: test_labels}) 228 | print("\n") 229 | print("step %d, test accuracy %g" % (i, test_accuracy)) 230 | print("\n") 231 | sess.run(lenet.train_op, feed_dict={lenet.raw_input_image: batch[0], 232 | lenet.raw_input_label: batch[1]}) 233 | saver.save(sess, paramter_path) 234 | print("saved model") 235 | 236 | if __name__ == '__main__': 237 | #train() 238 | #freeze() 239 | convert_to_tflite() 240 | #origin_predict() 241 | #freeze_pb_predict() 242 | tflite_predict() 243 | 244 | --------------------------------------------------------------------------------