├── .gitattributes ├── .gitignore ├── README.md ├── other ├── fused.png └── unfused.png ├── post_training_integer_quantization ├── README.md └── post_training_integer_quantization.py ├── post_training_weight_quantization ├── README.md └── post_training_weight_quantization.py └── quantization_aware_training ├── README.md ├── quantization_aware_training.md └── quantization_aware_training.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | 62 | # Flask stuff: 63 | instance/ 64 | .webassets-cache 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # Jupyter Notebook 76 | .ipynb_checkpoints 77 | 78 | # IPython 79 | profile_default/ 80 | ipython_config.py 81 | 82 | # pyenv 83 | .python-version 84 | 85 | # pipenv 86 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 87 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 88 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 89 | # install all needed dependencies. 90 | #Pipfile.lock 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # Environments 99 | .env 100 | .venv 101 | env/ 102 | venv/ 103 | ENV/ 104 | env.bak/ 105 | venv.bak/ 106 | 107 | # Spyder project settings 108 | .spyderproject 109 | .spyproject 110 | 111 | # Rope project settings 112 | .ropeproject 113 | 114 | # mkdocs documentation 115 | /site 116 | 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # This Repository 2 | 3 | When using the Tensorflow Python API methods above to implement model quantization, I encountered several problems with both post-training quantization and quantization-aware training. I also found online that some of the problems are very common. A lot of people were seeing the exact same errors as I was. 
After finally getting the code to run correctly, I decided to open this repository to share what I have learned. 4 | 5 | **This is a tutorial on model quantization using TensorFlow, with suggestions based on personal experience.** The Tensorflow version used in the examples is **1.14**, and models are built with **`tf.keras`**. Since the Tensorflow team said they are working on a new package, [`Tensorflow Model Optimization`](https://www.tensorflow.org/model_optimization), which includes some new implementations of model quantization per their [roadmap](https://www.tensorflow.org/model_optimization/guide/roadmap), I will keep looking for their updates and merge them into this repository if possible. 6 | 7 | The last modification here was on **10/30/2020**, by which time the roadmap items [Post training quantization for dynamic-range kernels](https://blog.tensorflow.org/2018/09/introducing-model-optimization-toolkit.html) (post-training weight quantization), [Post training quantization for (8b) fixed-point kernels](https://blog.tensorflow.org/2019/06/tensorflow-integer-quantization.html) (post-training integer quantization), and [Quantization aware training for (8b) fixed-point kernels and experimentation for <8b](https://blog.tensorflow.org/2020/04/quantization-aware-training-with-tensorflow-model-optimization-toolkit.html) had been launched. 8 | 9 | **You are welcome to comment with any issues, concerns, and suggestions, as well as anything regarding Tensorflow updates. If you find this repository useful, I would appreciate your generosity in starring** :star2: **this repository.** 10 | 11 | ## Update 10/30/2020 12 | 13 | There is another quantization tool, [Qkeras](https://github.com/google/qkeras), launched by a team from Google. It has a Keras-like interface and supports both common CNNs and RNNs. It is a good tool for advanced quantization research, but it is not compatible with `tf.lite` for deployment. 14 | 15 | ## Update 6/8/2020 16 | 17 | Post Training Quantization for Hybrid Kernels now has a new official name: Post training quantization for dynamic-range kernels. 18 | 19 | The Tensorflow Model Optimization package now contains a new tool to perform **quantization-aware training**, and here is the [guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_example) (a minimal sketch of the new tool's workflow appears at the end of this update note). By default, this new tool produces a quantization-aware trained model with hybrid kernels, where only the weights are stored as fixed-point values; the biases are stored in float and the model takes float input. Based on my experience, it only supports a subset of layers/operations. For example, the batch normalization layer (fused & unfused) is not supported. However, I think (though I did not try it) one could write a customized layer that performs a simplified BN operation and integrate it with the model based on this [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide), which also includes some experimental features of the new tool. 20 | 21 | To train a fixed-point model, consider the [quantization-aware training method here in this repo](/quantization_aware_training/README.md). It still works with **Tensorflow 1.14**, but will not work with newer versions, due to the removal of `tf.contrib` and some changes to `tf.lite.TFLiteConverter`.
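For reference, here is a minimal sketch of the new tool's basic workflow, following the guide linked above. This is an illustration only: it assumes a TensorFlow 2.x environment with the `tensorflow-model-optimization` package installed (not the 1.14 setup used elsewhere in this repository), and the tiny model and training settings are my own placeholders, restricted to layers the tool supports (no batch normalization, per the note above).

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

(train_data, train_labels), _ = tf.keras.datasets.mnist.load_data()
train_data = (train_data.astype('float32') / 255.0).reshape(-1, 28, 28, 1)

# A tiny illustrative model built only from supported layers.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(8, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPool2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# Rewrite the model with fake-quantization wrappers for quantization-aware training.
q_aware_model = tfmot.quantization.keras.quantize_model(model)

q_aware_model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])

# Fine-tune; the fake-quant wrappers calibrate the quantization ranges during training.
q_aware_model.fit(train_data[:1000], train_labels[:1000], epochs=1)
```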
22 | 23 | ## TensorFlow Model Quantization (modified on 10/24/2019) 24 | 25 | Efficiency at inference time is critical when deploying machine learning models to devices with limited resources, such as IoT edge nodes and mobile devices. **Model quantization** is a tool to improve **inference efficiency**, by converting the variable data types inside a model (usually float32) into data types with fewer bits (uint8, int8, float16, etc.), to overcome constraints such as energy consumption, storage capacity, and computation power. 26 | 27 | TensorFlow supports two levels of model quantization in general (see [this link](https://www.tensorflow.org/lite/performance/model_optimization)): 28 | 29 | - [Post-training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization) 30 | - [Post-training weight quantization (hybrid)](https://www.tensorflow.org/lite/performance/post_training_quant) 31 | - [Post-training integer quantization (full)](https://www.tensorflow.org/lite/performance/post_training_integer_quant) 32 | 33 | - [Quantization-aware training](https://github.com/tensorflow/tensorflow/tree/r1.14/tensorflow/contrib/quantize) (training with quantization) 34 | 35 | In general, quantization in Tensorflow uses [`tf.lite.TFLiteConverter`](https://github.com/tensorflow/docs/blob/r1.14/site/en/api_docs/python/tf/lite/TFLiteConverter.md) to convert a floating-point model to a [`tf.lite`](https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/lite) model. A `tf.lite` model contains some, or only, fixed-point values. Parameters of `tf.lite.TFLiteConverter` can be tuned to indicate the expected quantization method. 36 | 37 | **Post-training quantization** directly converts a trained model into a hybrid or fully-quantized `tf.lite` model using `tf.lite.TFLiteConverter`, with some degradation in model accuracy. Post-training weight quantization only quantizes the model weights (convolution layer kernels, dense layer weights, etc.) to reduce model size, and speeds up computation by allowing hybrid operations (a mix of fixed- and floating-point math). Post-training integer quantization fully quantizes the model to support fixed-point-only hardware accelerators. 38 | 39 | **Quantization-aware training** trains a model that can be fully quantized by `tf.lite.TFLiteConverter` with minimal accuracy loss. It uses [`tf.contrib.quantize`](https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/contrib/quantize) to rewrite the training/eval graph to add fake quantization nodes. Fake quantization nodes simulate the errors introduced by quantization during training, so that these errors can be calibrated over the remaining training process. They also record the min/max values that are required by the `tf.lite.TFLiteConverter`. (Why do min/max matter? See [here](/post_training_integer_quantization/README.md).) 40 | 41 | Compared with quantization-aware training, post-training quantization is simpler to use, and it only requires an already-trained floating-point model. Based on the [roadmap](https://www.tensorflow.org/model_optimization/guide/roadmap) linked above, while quantization-aware training is still expected for some models with strict accuracy requirements, the Tensorflow team expects such cases to become rare as they improve the post-training quantization tools toward negligible accuracy loss. 42 | 43 | Quoting from Tensorflow: 44 | 45 | > In summary, a user should use “hybrid” post training quantization when targeting simple CPU size and latency improvements.
When targeting greater CPU improvements or fixed-point accelerators, they should use this integer post training quantization tool, potentially using quantization-aware training if accuracy of a model suffers. 46 | 47 | **Please go to each directory for details about the three model quantization tools.** 48 | -------------------------------------------------------------------------------- /other/fused.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoranREN/TensorFlow_Model_Quantization/543a2c0f57babfb5d37bb490c93b040516702ebe/other/fused.png -------------------------------------------------------------------------------- /other/unfused.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HaoranREN/TensorFlow_Model_Quantization/543a2c0f57babfb5d37bb490c93b040516702ebe/other/unfused.png -------------------------------------------------------------------------------- /post_training_integer_quantization/README.md: -------------------------------------------------------------------------------- 1 | # [Post-training Integer Quantization](https://www.tensorflow.org/lite/performance/post_training_integer_quant) 2 | 3 | ###### Information below is version sensitive, time sensitive, and empirical; check the main [README.md](https://github.com/HaoranREN/TensorFlow_Model_Quantization) for details 4 | ###### See this [Google Colab ipynb](https://colab.research.google.com/drive/12tUYhjb8MbczoSgj2kjH5V2UYHrr7780) for sample output 5 | ###### Sample code is available [here](post_training_integer_quantization.py) 6 | 7 | Setting `tf.lite.TFLiteConverter.optimizations = [tf.lite.Optimize.DEFAULT]` directs the `tf.lite.TFLiteConverter` to perform a post-training integer quantization. With this setting, `tf.lite.TFLiteConverter.representative_dataset` requires a generator function that provides some sample data for calibration. The behavior of the `tf.lite.TFLiteConverter` can be specified with parameter settings; see the Inference Specifications below for details. 8 | 9 | To quantize a tensor, the main task is to calculate the two parameters `scalar` and `displacement` for value range mapping, by solving the equation set: 10 | - float_min = (fixed_min - **displacement**) / **scalar** 11 | - float_max = (fixed_max - **displacement**) / **scalar** 12 | 13 | The fixed_min/max are known from the nature of the target data type, but the float_min/max are still needed to solve for `scalar` and `displacement`. 14 | After regular training, the trained model only contains values in its weight tensors, i.e. it only has min/max data for those tensors, so it can only be post-training weight quantized. However, by providing the `tf.lite.TFLiteConverter.representative_dataset`, sample data can flow through the entire model as it does during the train/eval processes, so that each of the other tensors can record a sample value set. With the min/max data provided by these sample sets, the model can be fully quantized. Similarly, the fake quantization nodes track the min/max data during quantization-aware training. 15 | 16 | ## Inference Specifications 17 | 18 | - **Inputs:** defined by parameter settings, mapped to the range of the target data type 19 | - **Activations:** same as the data type in each of the operations 20 | - **Outputs:** defined by parameter settings 21 | - **Computation:** defined by parameter settings, check `tf.lite.TFLiteConverter.target_ops` below.
If input or output data type does not match the target operation data type, an operation will be added after input or before output to cast data type and map data range. 22 | - **`tf.lite.TFLiteConverter` Parameters:** 23 | - `tf.lite.TFLiteConverter.optimizations = [tf.lite.Optimize.DEFAULT]` 24 | - `tf.lite.TFLiteConverter.representative_dataset` required for calibration 25 | - `tf.lite.TFLiteConverter.inference_input_type` and `tf.lite.TFLiteConverter.inference_output_type` set to target data type, default to `tf.float32`, supports `tf.float32, tf.uint8, tf.int8` 26 | - `tf.lite.TFLiteConverter.target_ops` default to `[tf.lite.OpsSet.TFLITE_BUILTINS]`, supports `SELECT_TF_OPS, TFLITE_BUILTINS, TFLITE_BUILTINS_INT8` 27 | - `[tf.lite.OpsSet.TFLITE_BUILTINS]`: supported operations are quantized, others remain in float-point 28 | - `[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]`: aim for an `int8` fully quantized model, operations cannot be quantized throw errors 29 | - `[tf.lite.OpsSet.SELECT_TF_OPS]`: to avoid the limitation of operations are partially supported by TensorFlow Lite (not recommended) 30 | 31 | ## Keynotes 32 | 33 | To make the `tf.lite.TFLiteConverter.representative_dataset` working, `tf.enable_eager_execution()` must be called immediate after importing Tensorflow. However, I found that there might be some unexpected behaviors during a regular training and evaluation process with eager execution enabled. The `tf.lite.Interpreter` seems working good. Although there should be some workarounds, I would suggest to reset the Python runtime both before the regular train/eval part and `tf.lite.TFLiteConverter` part, and be clear to enable eager execution or not. 34 | -------------------------------------------------------------------------------- /post_training_integer_quantization/post_training_integer_quantization.py: -------------------------------------------------------------------------------- 1 | # imports and load the MNIST data 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | (train_data, train_labels), (eval_data, eval_labels) = tf.keras.datasets.mnist.load_data() 7 | 8 | train_data = (train_data.astype('float32') / 255.0).reshape(-1,28,28,1) 9 | eval_data = (eval_data.astype('float32') / 255.0).reshape(-1,28,28,1) 10 | 11 | # build tf.keras model 12 | 13 | def build_keras_model(): 14 | return tf.keras.models.Sequential([ 15 | 16 | tf.keras.layers.Conv2D(filters = 32, kernel_size=(3,3), activation=tf.nn.relu, padding='same', input_shape=(28,28,1)), 17 | tf.keras.layers.BatchNormalization(), 18 | 19 | tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2), 20 | 21 | tf.keras.layers.Conv2D(filters = 64, kernel_size=(3,3), activation=tf.nn.relu, padding='same'), 22 | tf.keras.layers.BatchNormalization(), 23 | 24 | tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2), 25 | 26 | tf.keras.layers.Conv2D(filters = 64, kernel_size=(3,3), activation=tf.nn.relu, padding='same'), 27 | tf.keras.layers.BatchNormalization(), 28 | 29 | tf.keras.layers.Flatten(), 30 | 31 | tf.keras.layers.Dense(64, activation=tf.nn.relu), 32 | 33 | tf.keras.layers.Dense(10, activation=tf.nn.softmax) 34 | ]) 35 | 36 | # train the model as normal, without tf.enable_eager_execution() 37 | 38 | train_batch_size = 50 39 | train_epoch = 2 40 | 41 | train_model = build_keras_model() 42 | 43 | train_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) 44 | 45 | print('\n------ Train ------\n') 46 | train_model.fit(train_data, train_labels, batch_size = 
train_batch_size, epochs=train_epoch) 47 | 48 | print('\n------ Test ------\n') 49 | loss, acc = train_model.evaluate(eval_data, eval_labels) 50 | 51 | train_model.save('/content/drive/My Drive/Colab Notebooks/quantization_github/post_training_integer_quantization_model/trained_model.h5') 52 | 53 | ''' 54 | Run the code below after reseting the runtime. 55 | The main goal here is to enable `tf.enable_eager_execution()` immediate after importing Tensorflow, 56 | to generate the representative_dataset 57 | ''' 58 | 59 | # imports and load the MNIST data 60 | 61 | import numpy as np 62 | import tensorflow as tf 63 | 64 | tf.enable_eager_execution() 65 | 66 | (train_data, train_labels), (eval_data, eval_labels) = tf.keras.datasets.mnist.load_data() 67 | 68 | train_data = (train_data.astype('float32') / 255.0).reshape(-1,28,28,1) 69 | eval_data = (eval_data.astype('float32') / 255.0).reshape(-1,28,28,1) 70 | 71 | # convert to quantized tf.lite model 72 | 73 | images = tf.cast(train_data, tf.float32) 74 | mnist_ds = tf.data.Dataset.from_tensor_slices(images).batch(1) 75 | def representative_data_gen(): 76 | for input_value in mnist_ds.take(100): 77 | yield [input_value] 78 | 79 | converter = tf.lite.TFLiteConverter.from_keras_model_file('/content/drive/My Drive/Colab Notebooks/quantization_github/post_training_integer_quantization_model/trained_model.h5') 80 | 81 | converter.optimizations = [tf.lite.Optimize.DEFAULT] 82 | converter.representative_dataset = tf.lite.RepresentativeDataset(representative_data_gen) 83 | 84 | converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] 85 | converter.inference_input_type = tf.int8 86 | converter.inference_output_type = tf.int8 87 | 88 | tflite_model = converter.convert() 89 | open('/content/drive/My Drive/Colab Notebooks/quantization_github/post_training_integer_quantization_model/quantized_model.tflite', 'wb').write(tflite_model) 90 | 91 | # load the quantized tf.lite model and test 92 | 93 | interpreter = tf.lite.Interpreter(model_path='/content/drive/My Drive/Colab Notebooks/quantization_github/post_training_integer_quantization_model/quantized_model.tflite') 94 | interpreter.allocate_tensors() 95 | 96 | input_details = interpreter.get_input_details() 97 | output_details = interpreter.get_output_details() 98 | 99 | acc = 0 100 | 101 | eval_data = np.array(eval_data * 255 - 128, dtype = np.int8) 102 | 103 | for i in range(eval_data.shape[0]): 104 | image = eval_data[i].reshape(1,28,28,1) 105 | 106 | interpreter.set_tensor(input_details[0]['index'], image) 107 | interpreter.invoke() 108 | prediction = interpreter.get_tensor(output_details[0]['index']) 109 | 110 | if (eval_labels[i]) == np.argmax(prediction): 111 | acc += 1 112 | 113 | print('Post-training integer quantization accuracy: ' + str(acc / len(eval_data))) 114 | 115 | ''' 116 | # check the tensor data type 117 | 118 | tensor_details = interpreter.get_tensor_details() 119 | 120 | for i in tensor_details: 121 | print(i['dtype'], i['name'], i['index']) 122 | ''' -------------------------------------------------------------------------------- /post_training_weight_quantization/README.md: -------------------------------------------------------------------------------- 1 | # [Post-training Weight Quantization](https://www.tensorflow.org/lite/performance/post_training_quant) 2 | 3 | ###### Information below is version sensitive, time sensitive, and empirical, check the main [README.md](https://github.com/HaoranREN/TensorFlow_Model_Quantization) for details 4 | ###### See this [Google Colab 
ipynb](https://colab.research.google.com/drive/119GkmswoaO4GZV5rQ5W9q8W2BlPeedYr) for sample output 5 | ###### Sample code is available [here](post_training_weight_quantization.py) 6 | 7 | Setting `tf.lite.TFLiteConverter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]` indicates the `tf.lite.TFLiteConverter` to perform a post-training weight quantization. In the resulting `tf.lite` model, only weights of supported operations are quantized, such as convolution layer kernels and dense layer weights. However, the first layer and the last layer of the model will not be touched since they are very sensitive for accuracy. 8 | 9 | At inference time, the input should be in float-point type and should be normalized to the same range as the training dataset. The inference computation is in a hybrid manner. For hybrid-supported operations with quantized weights, the input tensor will be quantized and perform fixed-point computation and convert the fixed-point output tensor back to float-point values. For all the other operations, it performs as a normal float-point model. 10 | 11 | ## Inference Specifications 12 | 13 | - **Inputs:** float-point type, in the same range as the training dataset 14 | - **Activations:** float-point type 15 | - **Outputs:** float-point type 16 | - **Computation:** hybrid, fixed-point for hybrid operations, float-point for others 17 | - **`tf.lite.TFLiteConverter` Parameters:** `tf.lite.TFLiteConverter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]` 18 | -------------------------------------------------------------------------------- /post_training_weight_quantization/post_training_weight_quantization.py: -------------------------------------------------------------------------------- 1 | # imports and load the MNIST data 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | (train_data, train_labels), (eval_data, eval_labels) = tf.keras.datasets.mnist.load_data() 7 | 8 | train_data = (train_data.astype('float32') / 255.0).reshape(-1,28,28,1) 9 | eval_data = (eval_data.astype('float32') / 255.0).reshape(-1,28,28,1) 10 | 11 | # build tf.keras model 12 | 13 | def build_keras_model(): 14 | return tf.keras.models.Sequential([ 15 | 16 | tf.keras.layers.Conv2D(filters = 32, kernel_size=(3,3), activation=tf.nn.relu, padding='same', input_shape=(28,28,1)), 17 | tf.keras.layers.BatchNormalization(), 18 | 19 | tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2), 20 | 21 | tf.keras.layers.Conv2D(filters = 64, kernel_size=(3,3), activation=tf.nn.relu, padding='same'), 22 | tf.keras.layers.BatchNormalization(), 23 | 24 | tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2), 25 | 26 | tf.keras.layers.Conv2D(filters = 64, kernel_size=(3,3), activation=tf.nn.relu, padding='same'), 27 | tf.keras.layers.BatchNormalization(), 28 | 29 | tf.keras.layers.Flatten(), 30 | 31 | tf.keras.layers.Dense(64, activation=tf.nn.relu), 32 | 33 | tf.keras.layers.Dense(10, activation=tf.nn.softmax) 34 | ]) 35 | 36 | # train the model as normal 37 | 38 | train_batch_size = 50 39 | train_epoch = 2 40 | 41 | train_model = build_keras_model() 42 | 43 | train_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) 44 | 45 | print('\n------ Train ------\n') 46 | train_model.fit(train_data, train_labels, batch_size = train_batch_size, epochs=train_epoch) 47 | 48 | print('\n------ Test ------\n') 49 | loss, acc = train_model.evaluate(eval_data, eval_labels) 50 | 51 | train_model.save('path_to_trained_model.h5') 52 | 53 | # convert to quantized tf.lite 
model 54 | 55 | converter = tf.lite.TFLiteConverter.from_keras_model_file('path_to_trained_model.h5') 56 | 57 | converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE] 58 | 59 | tflite_model = converter.convert() 60 | open('path_to_quantized_model.tflite', 'wb').write(tflite_model) 61 | 62 | # load the quantized tf.lite model and test 63 | 64 | interpreter = tf.lite.Interpreter(model_path='path_to_quantized_model.tflite') 65 | interpreter.allocate_tensors() 66 | 67 | input_details = interpreter.get_input_details() 68 | output_details = interpreter.get_output_details() 69 | 70 | acc = 0 71 | 72 | for i in range(eval_data.shape[0]): 73 | image = eval_data[i].reshape(1,28,28,1) 74 | 75 | interpreter.set_tensor(input_details[0]['index'], image) 76 | interpreter.invoke() 77 | prediction = interpreter.get_tensor(output_details[0]['index']) 78 | 79 | if (eval_labels[i]) == np.argmax(prediction): 80 | acc += 1 81 | 82 | print('Post-training weight quantization accuracy: ' + str(acc / len(eval_data))) 83 | 84 | ''' 85 | # check the tensor data type 86 | 87 | tensor_details = interpreter.get_tensor_details() 88 | 89 | for i in tensor_details: 90 | print(i['dtype'], i['name'], i['index']) 91 | ''' 92 | -------------------------------------------------------------------------------- /quantization_aware_training/README.md: -------------------------------------------------------------------------------- 1 | # [Quantization-aware Training](https://github.com/tensorflow/tensorflow/tree/r1.14/tensorflow/contrib/quantize) 2 | 3 | ###### Information below is version sensitive, time sensitive, and empirical, check the main [README.md](https://github.com/HaoranREN/TensorFlow_Model_Quantization) for details 4 | ###### See [quantization_aware_training.md](quantization_aware_training.md) for some code-side comments 5 | ###### See this [Google Colab ipynb](https://colab.research.google.com/drive/1hD_G2qD3ptlH9zrpT4GtDCD0GwXjt7K-) for sample output 6 | ###### Sample code is available [here](quantization_aware_training.py) 7 | 8 | Setting `tf.lite.TFLiteConverter.inference_type` to `tf.uint8` signals conversion to a fully quantized model, **only** from a quantization-aware trained input model. A quantization-aware trained model contains fake quantization nodes added by [`tf.contrib.quantize`](https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/contrib/quantize). Since `tf.contrib.quantize` rewrites the training/eval graph, the `tf.lite.TFLiteConverter` should be constructed by `tf.lite.TFLiteConverter.from_frozen_graph()`. This setup also requires `tf.lite.TFLiteConverter.quantized_input_stats` to be set. This parameter contains a **scalar** value and a **displacement** value of how to map the input data to values in the range of the inference data type (i.e. 0 - 255 for `uint8`), with the equation `real_input_value = (quantized_input_value - **displacement**) / **scalar**`. See the description of [`tf.lite.TFLiteConverter`](https://github.com/tensorflow/docs/blob/r1.14/site/en/api_docs/python/tf/lite/TFLiteConverter.md) for more information. 9 | 10 | This conversion aims for a fully quantized model, but any operations that do not have quantized implementations will throw errors. 
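To make the `quantized_input_stats` mapping above concrete, here is a small worked example. It assumes inputs that were normalized to the range [0, 1] during training, which is the convention used by the sample code in this directory; other input ranges give different values.

```python
# real_input_value = (quantized_input_value - displacement) / scalar
float_min, float_max = 0.0, 1.0          # range of the training inputs (assumed)

scalar = 255 / (float_max - float_min)   # 255.0, maps the full uint8 range
displacement = -float_min * scalar       # 0.0 for inputs that start at zero

# For the model in this directory the input layer is named 'conv2d_input', so:
# converter.quantized_input_stats = {'conv2d_input': (displacement, scalar)}
print(displacement, scalar)              # -> 0.0 255.0
```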
11 | 12 | ## Inference Specifications 13 | 14 | - **Inputs:** `uint8` type, map the original data to the range of 0 - 255 15 | - **Activations:** `uint8` type 16 | - **Outputs:** `uint8` type 17 | - **Computation:** all in fixed-point type 18 | - **`tf.lite.TFLiteConverter` Parameters:** 19 | - Construct by `tf.lite.TFLiteConverter.from_frozen_graph()` 20 | - `tf.lite.TFLiteConverter.inference_type = tf.uint8` 21 | - `tf.lite.TFLiteConverter.quantized_input_stats = {'input_layer_name': (displacement, scalar)}` 22 | 23 | ## Keynotes 24 | 25 | Quantization-aware training uses the `tf.contrib` module, so it is, in a sense, an 'experimental' feature of Tensorflow. This Tensorflow team [webpage](https://www.tensorflow.org/lite/performance/model_optimization) says it is only available for a subset of convolutional neural network architectures. An unsupported architecture usually involves a tensor for which `tf.lite.TFLiteConverter` requires range information for the conversion, but for which `tf.contrib.quantize` does not have a fake quantization implementation, so there is no min/max value associated with that tensor. In this case, an error message about lacking min/max data will be thrown, like: 26 | 27 | ``` 28 | F tensorflow/lite/toco/tooling_util.cc:1728] Array batch_normalization/FusedBatchNormV3_mul_0, which is an input to the Add operator producing the output array batch_normalization/FusedBatchNormV3, is lacking min/max data, which is necessary for quantization. If accuracy matters, either target a non-quantized output format, or run quantized training with your model from a floating point checkpoint to change the input graph to contain min/max information. If you don't care about accuracy, you can pass --default_ranges_min= and --default_ranges_max= for easy experimentation. 29 | Aborted (core dumped) 30 | ``` 31 | 32 | For some of the unsupported architectures, there are some 'tricks' I found based on my experience to work around common circumstances (listed below). In general, as prompted in the error message, the min/max-lacking problem can be bypassed by setting the default range parameter `tf.lite.TFLiteConverter.default_ranges_stats`, but with an accuracy loss (a short sketch of this fallback appears just before the workarounds below). To achieve better accuracy, my suggestion is to make sure the default range covers all the values inside all the min/max-lacking tensors. 33 | 34 | Also, not all Tensorflow operations are supported by `tf.lite`. The compatibility is listed on this [webpage](https://www.tensorflow.org/lite/guide/ops_compatibility). According to this page, operations may be elided or fused before the supported operations are mapped to their TensorFlow Lite counterparts. This means the operation sequence or layer order matters. I did encounter some cases where supported operations became unsupported in certain operation combinations. 35 | 36 | The supportability issues are often version sensitive; even though some online resources are very helpful, some can also be outdated. Thus, when facing such a supportability issue, my suggestion would be to experiment with whatever can be done to modify the model, even with some minor behavior changes if the accuracy is acceptable, such as skipping a layer or changing the layer order. Below are some 'tricks' that I found, which worked well in my experiments. 37 | 38 | Some ideas/code are retrieved from these [discussions](https://github.com/tensorflow/tensorflow/issues/27880).
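As mentioned above, here is a minimal sketch of the `default_ranges_stats` fallback. The file name and tensor names are the placeholders used by the sample code in this directory, and the (0, 6) range is an assumption for ReLU-bounded activations; adjust it to cover the actual values of your min/max-lacking tensors.

```python
import tensorflow as tf

# Assumes a frozen graph produced by the quantization-aware training flow below.
converter = tf.lite.TFLiteConverter.from_frozen_graph('path_to_frozen_graph.pb',
                                                      ['conv2d_input'],
                                                      ['dense_1/Softmax'])
converter.inference_type = tf.uint8
converter.quantized_input_stats = {'conv2d_input': (0., 255.)}

# Last-resort fallback: any tensor lacking recorded min/max data will use this
# default range, trading accuracy for a successful conversion.
converter.default_ranges_stats = (0, 6)

tflite_model = converter.convert()
```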
39 | 40 | ###### Batch Normalization Layers 41 | 42 | For the best compatibility, when using batch normalization and convolution layer combinations, the batch normalization layer should come after a convolution layer. If it throws an error message similar to the one above, saying that `FusedBatchNormV?` is lacking min/max data, try using an unfused batch normalization layer. For example, in `tf.keras`, use 43 | 44 | ```python 45 | tf.keras.layers.BatchNormalization(fused=False) 46 | ``` 47 | 48 | The difference is that a fused batch normalization layer wraps several batch normalization operations into a single operation, while an unfused batch normalization layer leaves those operations as individual ones. There is no fake quantization implementation for the wrapper, but there is for the individual operations. 49 | 50 | | Fused | Unfused | 51 | | --- | --- | 52 | | ![Fused](/other/fused.png) | ![Unfused](/other/unfused.png) | 53 | 54 | ###### Activation Layers 55 | 56 | The best practice for an activation layer is to combine it with the preceding layer, since some standalone activation layers are not supported. For example, the code below is part of ResNetV1. The activation layer on the last line throws a min/max-lacking error. I tried skipping it as well as combining the activations into the two layers in line 6 and line 9. Both gave acceptable accuracies. 57 | 58 | ```python 59 | 1 for res_block in range(num_res_blocks): 60 | 2 strides = 1 61 | 3 if stack > 0 and res_block == 0: # first layer but not first stack 62 | 4 strides = 2 63 | 5 y = resnet_layer(inputs=x, num_filters=num_filters, strides=strides) 64 | 6 y = resnet_layer(inputs=y, num_filters=num_filters, activation=None) 65 | 7 66 | 8 if stack > 0 and res_block == 0: # first layer but not first stack 67 | 9 x = resnet_layer(inputs=x, num_filters=num_filters, kernel_size=1, 68 | 10 strides=strides, activation=None, batch_normalization=False) 69 | 11 x = tf.keras.layers.add([x, y]) 70 | 12 x = tf.keras.layers.Activation('relu')(x) 71 | ``` 72 | 73 | Some activation functions are not supported. For example, I also tried using `softplus` in the two layers, but with no success. 74 | 75 | ###### Convolution Layer Bias 76 | 77 | Bias tensors are mostly supported. However, in some models, for example a `uint8` target model, if the biases are greater than 255 and therefore cannot be represented by the `uint8` data type, they will be converted to the `int32` type. Even though this is still fine for fixed-point-only inference computation, if something unexpected happens with the biases, try setting the layer parameter `use_bias = False`. 78 | -------------------------------------------------------------------------------- /quantization_aware_training/quantization_aware_training.md: -------------------------------------------------------------------------------- 1 | ## Imports and load the MNIST data 2 | 3 | ```python 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | (train_data, train_labels), (eval_data, eval_labels) = tf.keras.datasets.mnist.load_data() 8 | 9 | train_data = (train_data.astype('float32') / 255.0).reshape(-1,28,28,1) 10 | eval_data = (eval_data.astype('float32') / 255.0).reshape(-1,28,28,1) 11 | ``` 12 | 13 | ## Build tf.keras model 14 | 15 | Concerning the unfused batch normalization layers and details on model supportability, check the [README.md](https://github.com/HaoranREN/TensorFlow_Model_Quantization/tree/master/quantization_aware_training).
16 | 17 | ```python 18 | def build_keras_model(): 19 | return tf.keras.models.Sequential([ 20 | 21 | tf.keras.layers.Conv2D(filters = 32, kernel_size=(3,3), activation=tf.nn.relu, padding='same', input_shape=(28,28,1)), 22 | tf.keras.layers.BatchNormalization(fused=False), 23 | 24 | tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2), 25 | 26 | tf.keras.layers.Conv2D(filters = 64, kernel_size=(3,3), activation=tf.nn.relu, padding='same'), 27 | tf.keras.layers.BatchNormalization(fused=False), 28 | 29 | tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2), 30 | 31 | tf.keras.layers.Conv2D(filters = 64, kernel_size=(3,3), activation=tf.nn.relu, padding='same'), 32 | tf.keras.layers.BatchNormalization(fused=False), 33 | 34 | tf.keras.layers.Flatten(), 35 | 36 | tf.keras.layers.Dense(64, activation=tf.nn.relu), 37 | 38 | tf.keras.layers.Dense(10, activation=tf.nn.softmax) 39 | ]) 40 | ``` 41 | 42 | ## Train the model, quantization-aware training after $[quant_delay] steps 43 | 44 | Since quantization-aware training required fake quantization nodes are added in training/eval graphs, we should gain access to them by defining graphs manully before the training/eval session. Also, we are targeting a quantized model for inference, so the final conversion should be done with an eval graph. Thus, we save the model to checkpoint at the end of the training session, and load it again for saving to a frozen graph in an eval session. 45 | 46 | **The `quant_delay` parameter is the step to start quantization-aware training. In practice, letting the model to be trained with float-point first for a number of steps helps model convergency. The unit is step, so be clear with the math.** 47 | 48 | ```python 49 | train_batch_size = 50 50 | train_batch_number = train_data.shape[0] 51 | quant_delay_epoch = 1 52 | 53 | train_graph = tf.Graph() 54 | train_sess = tf.Session(graph=train_graph) 55 | 56 | tf.keras.backend.set_session(train_sess) 57 | with train_graph.as_default(): 58 | train_model = build_keras_model() 59 | 60 | tf.contrib.quantize.create_training_graph(input_graph=train_graph, quant_delay=int(train_batch_number / train_batch_size * quant_delay_epoch)) 61 | 62 | train_sess.run(tf.global_variables_initializer()) 63 | 64 | train_model.compile( 65 | optimizer='adam', 66 | loss='sparse_categorical_crossentropy', 67 | metrics=['accuracy'] 68 | ) 69 | 70 | print('\n------ Train ------\n') 71 | train_model.fit(train_data, train_labels, batch_size = train_batch_size, epochs=quant_delay_epoch * 2) 72 | 73 | print('\n------ Test ------\n') 74 | loss, acc = train_model.evaluate(eval_data, eval_labels) 75 | 76 | saver = tf.train.Saver() 77 | saver.save(train_sess, 'path_to_checkpoints') 78 | ``` 79 | 80 | ## Save the frozen graph 81 | 82 | Load the model in an eval session. Be sure to use `tf.contrib.quantize.create_eval_graph()` for here. 83 | 84 | **Remember to set `tf.keras.backend.set_learning_phase(0)`, indicates testing time. 
This is important.** 85 | 86 | ```python 87 | eval_graph = tf.Graph() 88 | eval_sess = tf.Session(graph=eval_graph) 89 | 90 | tf.keras.backend.set_session(eval_sess) 91 | 92 | with eval_graph.as_default(): 93 | tf.keras.backend.set_learning_phase(0) 94 | eval_model = build_keras_model() 95 | tf.contrib.quantize.create_eval_graph(input_graph=eval_graph) 96 | eval_graph_def = eval_graph.as_graph_def() 97 | saver = tf.train.Saver() 98 | saver.restore(eval_sess, 'path_to_checkpoints') 99 | 100 | frozen_graph_def = tf.graph_util.convert_variables_to_constants( 101 | eval_sess, 102 | eval_graph_def, 103 | [eval_model.output.op.name] 104 | ) 105 | 106 | with open('path_to_frozen_graph.pb', 'wb') as f: 107 | f.write(frozen_graph_def.SerializeToString()) 108 | ``` 109 | 110 | ## Convert to quantized tf.lite model 111 | 112 | Fake quantization nodes are added in training/eval graphs, so we should use `tf.lite.TFLiteConverter.from_frozen_graph()` to construct the `tf.lite.TFLiteConverter`. The parameters of the function are path to the frozen graph file, the input tensors names of the frozen graph as a list of strings, and the output tensors names of the frozen graph as a list of strings. 113 | 114 | See the [README.md](https://github.com/HaoranREN/TensorFlow_Model_Quantization/tree/master/quantization_aware_training) and the description of [`tf.lite.TFLiteConverter`](https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/lite/TFLiteConverter) for more information about parameter setup. 115 | 116 | ```python 117 | input_max = np.max(train_data) 118 | input_min = np.min(train_data) 119 | converter_std = 255 / (input_max - input_min) 120 | converter_mean = -(input_min * converter_std) 121 | 122 | converter = tf.lite.TFLiteConverter.from_frozen_graph('path_to_frozen_graph.pb', 123 | ['conv2d_input'], 124 | ['dense_1/Softmax']) 125 | converter.inference_type = tf.uint8 126 | converter.quantized_input_stats = {'conv2d_input':(converter_mean, converter_std)} 127 | #converter.default_ranges_stats = (0,1) 128 | tflite_model = converter.convert() 129 | open('path_to_quantized_model.tflite', 'wb').write(tflite_model) 130 | ``` 131 | 132 | ## Load the quantized tf.lite model and test 133 | 134 | Construct a `tf.lite.Interpreter` to use the quantized model for inference. After setting the input tensor and invoking the `tf.lite.Interpreter`, we can have all the output tensors inside the model inference. 135 | 136 | **For a `tf.lite` model from quantization-aware training, be sure to map the values in your image for inference to the range of `uint8` (i.e. 0 - 255).** 137 | 138 | Check the description of [`tf.lite.Interpreter`](https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/lite/Interpreter) for more details about function behavior. 
139 | 140 | ```python 141 | interpreter = tf.lite.Interpreter(model_path='path_to_quantized_model.tflite') 142 | interpreter.allocate_tensors() 143 | 144 | input_details = interpreter.get_input_details() 145 | output_details = interpreter.get_output_details() 146 | 147 | quantize_eval_data = np.array(eval_data * 255, dtype = np.uint8) 148 | acc = 0 149 | 150 | for i in range(quantize_eval_data.shape[0]): 151 | quantize_image = quantize_eval_data[i] 152 | quantize_image = quantize_image.reshape(1,28,28,1) 153 | 154 | interpreter.set_tensor(input_details[0]['index'], quantize_image) 155 | interpreter.invoke() 156 | prediction = interpreter.get_tensor(output_details[0]['index']) 157 | 158 | if (eval_labels[i]) == np.argmax(prediction): 159 | acc += 1 160 | 161 | print('Quantization-aware training accuracy: ' + str(acc / len(eval_data))) 162 | ``` 163 | 164 | ## Check the tensor data type 165 | 166 | We can also check the information of all the tensors inside the `tf.lite` model. 167 | 168 | ```python 169 | tensor_details = interpreter.get_tensor_details() 170 | for i in tensor_details: 171 | print(i['dtype'], i['name'], i['index']) 172 | ``` 173 | -------------------------------------------------------------------------------- /quantization_aware_training/quantization_aware_training.py: -------------------------------------------------------------------------------- 1 | # imports and load the MNIST data 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | (train_data, train_labels), (eval_data, eval_labels) = tf.keras.datasets.mnist.load_data() 7 | 8 | train_data = (train_data.astype('float32') / 255.0).reshape(-1,28,28,1) 9 | eval_data = (eval_data.astype('float32') / 255.0).reshape(-1,28,28,1) 10 | 11 | # build tf.keras model 12 | 13 | def build_keras_model(): 14 | return tf.keras.models.Sequential([ 15 | 16 | tf.keras.layers.Conv2D(filters = 32, kernel_size=(3,3), activation=tf.nn.relu, padding='same', input_shape=(28,28,1)), 17 | tf.keras.layers.BatchNormalization(fused=False), 18 | 19 | tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2), 20 | 21 | tf.keras.layers.Conv2D(filters = 64, kernel_size=(3,3), activation=tf.nn.relu, padding='same'), 22 | tf.keras.layers.BatchNormalization(fused=False), 23 | 24 | tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=2), 25 | 26 | tf.keras.layers.Conv2D(filters = 64, kernel_size=(3,3), activation=tf.nn.relu, padding='same'), 27 | tf.keras.layers.BatchNormalization(fused=False), 28 | 29 | tf.keras.layers.Flatten(), 30 | 31 | tf.keras.layers.Dense(64, activation=tf.nn.relu), 32 | 33 | tf.keras.layers.Dense(10, activation=tf.nn.softmax) 34 | ]) 35 | 36 | # train the model, quantization-aware training after $[quant_delay] steps 37 | 38 | train_batch_size = 50 39 | train_batch_number = train_data.shape[0] 40 | quant_delay_epoch = 1 41 | 42 | train_graph = tf.Graph() 43 | train_sess = tf.Session(graph=train_graph) 44 | 45 | tf.keras.backend.set_session(train_sess) 46 | with train_graph.as_default(): 47 | train_model = build_keras_model() 48 | 49 | tf.contrib.quantize.create_training_graph(input_graph=train_graph, quant_delay=int(train_batch_number / train_batch_size * quant_delay_epoch)) 50 | 51 | train_sess.run(tf.global_variables_initializer()) 52 | 53 | train_model.compile( 54 | optimizer='adam', 55 | loss='sparse_categorical_crossentropy', 56 | metrics=['accuracy'] 57 | ) 58 | 59 | print('\n------ Train ------\n') 60 | train_model.fit(train_data, train_labels, batch_size = train_batch_size, epochs=quant_delay_epoch * 2) 61 | 62 | 
print('\n------ Test ------\n') 63 | loss, acc = train_model.evaluate(eval_data, eval_labels) 64 | 65 | saver = tf.train.Saver() 66 | saver.save(train_sess, 'path_to_checkpoints') 67 | 68 | # save the frozen graph 69 | 70 | eval_graph = tf.Graph() 71 | eval_sess = tf.Session(graph=eval_graph) 72 | 73 | tf.keras.backend.set_session(eval_sess) 74 | 75 | with eval_graph.as_default(): 76 | tf.keras.backend.set_learning_phase(0) 77 | eval_model = build_keras_model() 78 | tf.contrib.quantize.create_eval_graph(input_graph=eval_graph) 79 | eval_graph_def = eval_graph.as_graph_def() 80 | saver = tf.train.Saver() 81 | saver.restore(eval_sess, 'path_to_checkpoints') 82 | 83 | frozen_graph_def = tf.graph_util.convert_variables_to_constants( 84 | eval_sess, 85 | eval_graph_def, 86 | [eval_model.output.op.name] 87 | ) 88 | 89 | with open('path_to_frozen_graph.pb', 'wb') as f: 90 | f.write(frozen_graph_def.SerializeToString()) 91 | 92 | # convert to quantized tf.lite model 93 | 94 | input_max = np.max(train_data) 95 | input_min = np.min(train_data) 96 | converter_std = 255 / (input_max - input_min) 97 | converter_mean = -(input_min * converter_std) 98 | 99 | converter = tf.lite.TFLiteConverter.from_frozen_graph('path_to_frozen_graph.pb', 100 | ['conv2d_input'], 101 | ['dense_1/Softmax']) 102 | converter.inference_type = tf.uint8 103 | converter.quantized_input_stats = {'conv2d_input':(converter_mean, converter_std)} 104 | #converter.default_ranges_stats = (0,1) 105 | tflite_model = converter.convert() 106 | open('path_to_quantized_model.tflite', 'wb').write(tflite_model) 107 | 108 | # load the quantized tf.lite model and test 109 | 110 | interpreter = tf.lite.Interpreter(model_path='path_to_quantized_model.tflite') 111 | interpreter.allocate_tensors() 112 | 113 | input_details = interpreter.get_input_details() 114 | output_details = interpreter.get_output_details() 115 | 116 | quantize_eval_data = np.array(eval_data * 255, dtype = np.uint8) 117 | acc = 0 118 | 119 | for i in range(quantize_eval_data.shape[0]): 120 | quantize_image = quantize_eval_data[i] 121 | quantize_image = quantize_image.reshape(1,28,28,1) 122 | 123 | interpreter.set_tensor(input_details[0]['index'], quantize_image) 124 | interpreter.invoke() 125 | prediction = interpreter.get_tensor(output_details[0]['index']) 126 | 127 | if (eval_labels[i]) == np.argmax(prediction): 128 | acc += 1 129 | 130 | print('Quantization-aware training accuracy: ' + str(acc / len(eval_data))) 131 | 132 | ''' 133 | # check the tensor data type 134 | 135 | tensor_details = interpreter.get_tensor_details() 136 | 137 | for i in tensor_details: 138 | print(i['dtype'], i['name'], i['index']) 139 | ''' 140 | --------------------------------------------------------------------------------