├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── freeze.sh ├── load_tflite.py ├── model.py ├── quantization.sh ├── requirements.txt ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | MNIST-data 3 | *.pb 4 | *.ckpt.* 5 | *.tflite 6 | checkpoint 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # celery beat schedule file 86 | celerybeat-schedule 87 | 88 | # SageMath parsed files 89 | *.sage.py 90 | 91 | # Environments 92 | .env 93 | .venv 94 | env/ 95 | venv/ 96 | ENV/ 97 | env.bak/ 98 | venv.bak/ 99 | 100 | # Spyder project settings 101 | .spyderproject 102 | .spyproject 103 | 104 | # Rope project settings 105 | .ropeproject 106 | 107 | # mkdocs documentation 108 | /site 109 | 110 | # mypy 111 | .mypy_cache/ 112 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | python: 4 | - '3.6' 5 | 6 | install: 7 | - pip install -r requirements.txt 8 | 9 | script: 10 | - python train.py 11 | - python test.py 12 | - sh ./freeze.sh 13 | - sh ./quantization.sh 14 | - python ./load_tflite.py 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow Quantization Example 2 | 3 | [![Build Status](https://travis-ci.com/tutorials-with-ci/tensorflow-quantization-example.svg?branch=master)](https://travis-ci.com/tutorials-with-ci/tensorflow-quantization-example) 4 | 5 | TensorFlow Quantization Example, for TensorFlow Lite 6 | 7 | (There are still problems and we are looking for a solution.) 8 | 9 | ## Steps 10 | 11 | Same as the steps in the configuration file: 12 | 13 | ```yml 14 | language: python 15 | 16 | python: 17 | - '3.6' 18 | 19 | install: 20 | - pip install -r requirements.txt 21 | 22 | script: 23 | - python train.py 24 | - python test.py 25 | - sh ./freeze.sh 26 | - sh ./quantization.sh 27 | - python ./load_tflite.py 28 | ``` 29 | -------------------------------------------------------------------------------- /freeze.sh: -------------------------------------------------------------------------------- 1 | freeze_graph \ 2 | --input_graph=./eval.pb \ 3 | --input_checkpoint=./local.ckpt \ 4 | --output_graph=./frozen_eval.pb \ 5 | --output_node_names=prob 6 | -------------------------------------------------------------------------------- /load_tflite.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tensorflow.lite.python.interpreter import Interpreter 3 | from tensorflow.examples.tutorials.mnist import input_data 4 | 5 | 6 | class TfLiteModel: 7 | def __init__(self, model_content): 8 | self.model_content = bytes(model_content) 9 | self.interpreter = Interpreter(model_content=self.model_content) 10 | input_details = self.interpreter.get_input_details() 11 | output_details = self.interpreter.get_output_details() 12 | print(input_details) 13 | self.input_index = input_details[0]['index'] 14 | self.output_index = output_details[0]['index'] 15 | 16 | self.input_scale, self.input_zero_point = input_details[0]['quantization'] 17 | self.output_scale, self.output_zero_point = output_details[0]['quantization'] 18 | 19 | self.interpreter.allocate_tensors() 20 | 21 | def forward(self, data_in): 22 | test_input = np.array(data_in / self.input_scale + self.input_zero_point, dtype=np.uint8).reshape(1, -1) 23 | self.interpreter.set_tensor(self.input_index, test_input) 24 | self.interpreter.invoke() 25 | 26 | output_data = self.interpreter.get_tensor(self.output_index)[0] 27 | return (np.array(output_data, dtype=np.float32) - self.output_zero_point) * self.output_scale 28 | 29 | 30 | mnist = input_data.read_data_sets('MNIST-data', one_hot=True) 31 | batch = mnist.train.next_batch(1) 32 | image, label = batch[0], batch[1] 33 | 34 | model_path = './final.tflite' 35 | with open(model_path, 'rb') as f: 36 | model_content = f.read() 37 | 38 | model = TfLiteModel(model_content) 39 | predict = model.forward(image) 40 | 41 | print("TF-Lite Output: {}".format(predict)) 42 | print("Ground Truth: {}".format(label)) 43 | print("Right? {}".format(np.argmax(predict) == np.argmax(label))) 44 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def simple_conv_net(x, is_training: bool = False): 5 | x = tf.reshape(x, shape=[-1, 28, 28, 1]) 6 | 7 | x = tf.layers.conv2d(x, 32, 5, activation=tf.nn.relu) 8 | x = tf.layers.max_pooling2d(x, 2, 2) 9 | 10 | x = tf.layers.conv2d(x, 64, 3, activation=tf.nn.relu) 11 | x = tf.layers.max_pooling2d(x, 2, 2) 12 | 13 | x = tf.contrib.layers.flatten(x) 14 | x = tf.layers.dense(x, 1024) 15 | x = tf.layers.dropout(x, rate=0.5, training=is_training) 16 | x = tf.layers.dense(x, 10) 17 | 18 | return x 19 | -------------------------------------------------------------------------------- /quantization.sh: -------------------------------------------------------------------------------- 1 | tflite_convert \ 2 | --output_file=./final.tflite \ 3 | --graph_def_file=./frozen_eval.pb \ 4 | --inference_type=QUANTIZED_UINT8 \ 5 | --input_type=QUANTIZED_UINT8 \ 6 | --input_arrays=input \ 7 | --output_arrays=prob \ 8 | --mean_values=128 \ 9 | --std_dev_values=127 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tensorflow>=1.12.1 3 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from model import simple_conv_net 4 | from tensorflow.examples.tutorials.mnist import input_data 5 | 6 | 7 | mnist = input_data.read_data_sets('MNIST-data', one_hot=True) 8 | 9 | sess = tf.Session() 10 | 11 | x = tf.placeholder(tf.float32, shape=[None, 784], name='input') 12 | logits = simple_conv_net(x, is_training=False) 13 | y = tf.nn.softmax(logits, name='prob') 14 | 15 | tf.contrib.quantize.create_eval_graph() 16 | 17 | saver = tf.train.Saver() 18 | saver.restore(sess, './local.ckpt') 19 | 20 | with open('eval.pb', 'w') as f: 21 | g = tf.get_default_graph() 22 | f.write(str(g.as_graph_def())) 23 | 24 | batch = mnist.train.next_batch(100) 25 | results = sess.run(y, feed_dict={x: batch[0]}) 26 | 27 | truth = np.argmax(batch[1], -1) 28 | predict = np.argmax(results, -1) 29 | print(np.mean(truth == predict)) 30 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from model import simple_conv_net 3 | from tensorflow.examples.tutorials.mnist import input_data 4 | 5 | 6 | mnist = input_data.read_data_sets('MNIST-data', one_hot=True) 7 | sess = tf.Session() 8 | 9 | x = tf.placeholder(tf.float32, shape=[None, 784], name='input') 10 | y_ = tf.placeholder(tf.float32, shape=[None, 10], name='label') 11 | 12 | logits = simple_conv_net(x, is_training=True) 13 | y = tf.nn.softmax(logits, name='prob') 14 | 15 | cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=y)) 16 | 17 | tf.contrib.quantize.create_training_graph() 18 | 19 | sess.run(tf.global_variables_initializer()) 20 | train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) 21 | 22 | for i in range(3000): 23 | batch = mnist.train.next_batch(128) 24 | sess.run(train_step, feed_dict={x: batch[0], y_: batch[1]}) 25 | 26 | if (i + 1) % 100 == 0: 27 | print('Iteration: {: 4d}'.format(i + 1)) 28 | 29 | saver = tf.train.Saver() 30 | saver.save(sess, './local.ckpt') 31 | --------------------------------------------------------------------------------