├── .gitattributes
├── .gitignore
├── CustomStraightThroughEstimator.p
├── IdentityTrainingLayer.m
├── LICENSE.md
├── QuantizationAwareTrainingWithMobilenetv2.mlx
├── QuantizedConvolutionBatchNormTrainingLayer.m
├── QuantizedConvolutionTrainingLayer.m
├── README.md
├── SECURITY.md
├── bypassdlgradients.m
├── bypassdlgradients.p
├── foldBatchNormalizationParameters.m
├── foldBatchNormalizationParameters.p
├── images
│   ├── original_inference.png
│   ├── qat_workflow.png
│   ├── quantized_inference.png
│   ├── quantized_training.png
│   └── ste.png
├── index.html
└── quantizeToFloat.m

/.gitattributes:
--------------------------------------------------------------------------------
# gitattributes for a MATLAB repo.

# Source files
# ============
*.m text diff=matlab

# Binary files
# ============
*.mlx binary
*.p binary
*.mex* binary
*.fig binary
*.mat binary
*.mdl binary
*.slx binary
*.mdlp binary
*.slxp binary
*.sldd binary
*.mltbx binary
*.mlappinstall binary
*.mlpkginstall binary
*.mn binary
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Flower data
flower_dataset.tgz
flower_photos/

# Internal files
internal/
--------------------------------------------------------------------------------
/CustomStraightThroughEstimator.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/quantization-aware-training/d0d0038bfeb9d89fc1a12d35423ed47c523d12a1/CustomStraightThroughEstimator.p
--------------------------------------------------------------------------------
/IdentityTrainingLayer.m:
--------------------------------------------------------------------------------
classdef IdentityTrainingLayer < nnet.layer.Layer
    %% IdentityTrainingLayer Layer that returns its input as output

    % Copyright 2023 The MathWorks, Inc.

    methods

        function obj = IdentityTrainingLayer(originalLayer)
            obj.Name = originalLayer.Name;
            obj.Type = "Identity Training Layer";
            obj.Description = "No operation on the forward pass";
        end


        function X = predict(layer, X)
            % No-op: return the input as the output.
        end

    end

end
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
Copyright (c) 2022, The MathWorks, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. In all cases, the software is, and all modifications and derivatives of the software shall be, licensed to you solely for use in conjunction with MathWorks products and service offerings.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/QuantizationAwareTrainingWithMobilenetv2.mlx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/quantization-aware-training/d0d0038bfeb9d89fc1a12d35423ed47c523d12a1/QuantizationAwareTrainingWithMobilenetv2.mlx
--------------------------------------------------------------------------------
/QuantizedConvolutionBatchNormTrainingLayer.m:
--------------------------------------------------------------------------------
classdef QuantizedConvolutionBatchNormTrainingLayer < nnet.layer.Layer & nnet.layer.Formattable
    %% QuantizedConvolutionBatchNormTrainingLayer Layer for Quantization Aware Training
    % This custom layer introduces quantization error to a fused
    % convolution layer and batch normalization layer during training.

    % Copyright 2023 The MathWorks, Inc.


    properties (Learnable)
        Network
    end

    methods
        function obj = QuantizedConvolutionBatchNormTrainingLayer(cLayer, bLayer)
            % Freeze the Scale and Offset learn rate factors of the
            % BatchNormalization layer so as to use the statistics
            % collected during training of the original network.
            bLayer.ScaleLearnRateFactor = 0;
            bLayer.OffsetLearnRateFactor = 0;

            % Construct a dlnetwork as the Learnable of this custom layer.
            obj.Network = dlnetwork([cLayer bLayer], 'Initialize', false);

            obj.Name = cLayer.Name;
            obj.Description = "Quantization Aware Conv-BN Layer Group for Training";
            obj.Type = "Quantized Fused Convolution Layer";
        end

        function Z = predict(layer, X)
            % Call predict on the underlying network if the network is not
            % yet initialized to avoid errors when inspecting the
            % LayerGraph before training.
            if ~layer.Network.Initialized
                Z = predict(layer.Network, X);
                return;
            end

            % Calculate the adjusted Weights and Bias of the convolution
            % layer in the underlying network from the would-be fusion.
            [adjustedWeights, adjustedBias] = foldBatchNormalizationParameters(layer.Network);

            % Quantize the adjusted Weights to float.
            adjustedWeights = quantizeToFloat(adjustedWeights);

            % Recreate the learnables table using the adjusted Weights and
            % Bias.
            newLearnables = layer.Network.Learnables;
            newLearnables.Value{1} = adjustedWeights;
            newLearnables.Value{2} = adjustedBias;

            % Set the learnables back on the Network.
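            % Assigning the quantized values back into the Learnables
            % table keeps the forward pass differentiable: the adjusted
            % values are derived from the current Learnables, with only
            % the round operation bypassed by the straight-through
            % estimator, so gradients still reach the original Weights
            % and Bias during training.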
            layer.Network.Learnables = newLearnables;

            % Call predict on the underlying Network, tapping the
            % activations of the convolution layer only, since the batch
            % normalization has already been applied by the fusion in
            % foldBatchNormalizationParameters.
            Z = predict(layer.Network, X, 'Outputs', layer.Name);

            % Quantize the activation to float.
            Z = quantizeToFloat(Z);
        end

    end

end
--------------------------------------------------------------------------------
/QuantizedConvolutionTrainingLayer.m:
--------------------------------------------------------------------------------
classdef QuantizedConvolutionTrainingLayer < nnet.layer.Layer & nnet.layer.Formattable
    %% QuantizedConvolutionTrainingLayer Layer for Quantization Aware Training
    % This custom layer introduces quantization error to a convolution
    % layer during training.

    % Copyright 2023 The MathWorks, Inc.


    properties (Learnable)
        Network
    end

    methods
        function obj = QuantizedConvolutionTrainingLayer(cLayer)
            % Construct a dlnetwork as the Learnable of this custom layer.
            obj.Network = dlnetwork(cLayer, 'Initialize', false);

            obj.Name = cLayer.Name;
            obj.Description = "Quantization Aware Conv Layer for Training";
            obj.Type = "Quantized Convolution Layer";
        end

        function Z = predict(layer, X)
            % Call predict on the underlying network if the network is not
            % yet initialized to avoid errors when inspecting the
            % LayerGraph before training.
            if ~layer.Network.Initialized
                Z = predict(layer.Network, X);
                return;
            end

            % Capture the Weights of the convolution layer in the
            % underlying network.
            weights = layer.Network.Learnables.Value{1};

            % Quantize the Weights to float.
            weights = quantizeToFloat(weights);

            % Set the learnables back on the Network.
            layer.Network.Learnables.Value{1} = weights;

            % Call predict on the underlying Network.
            Z = predict(layer.Network, X);

            % Quantize the activation to float.
            Z = quantizeToFloat(Z);
        end

    end

end
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Quantization Aware Training with MobileNet-v2

[![Open in MATLAB Online](https://www.mathworks.com/images/responsive/global/open-in-matlab-online.svg)](https://matlab.mathworks.com/open/github/v1?repo=matlab-deep-learning/quantization-aware-training&file=QuantizationAwareTrainingWithMobilenetv2.mlx)
[![View Quantization Aware Training with MobileNet-v2 on File Exchange](https://www.mathworks.com/matlabcentral/images/matlab-file-exchange.svg)](https://www.mathworks.com/matlabcentral/fileexchange/125420-quantization-aware-training-with-mobilenet-v2)

This example shows how to perform quantization aware training as a way to prepare a network for quantization. Quantization aware training is a method that can help recover accuracy lost by quantizing a network to use 8-bit scaled integer weights and biases. Networks like MobileNet-v2 are especially sensitive to quantization because the ranges of values in the weight tensors of the convolution and grouped convolution layers vary significantly.

This example shows how pre-processing a network with quantization aware training can produce a quantized network with accuracy on par with the original unquantized network. Note that the values in this table may differ slightly between training runs.

| Network | Accuracy |
| ----------- | ----------- |
| Original network | **0.9101** |
| `int8` network via post-training quantization | 0.2452 |
| `int8` network via quantization aware training | **0.8937** |

## **Running the Example**

Open and run the live script `QuantizationAwareTrainingWithMobilenetv2.mlx`.

Additional files:

- `QuantizedConvolutionBatchNormTrainingLayer`: Custom layer that implements a quantization aware fused convolution-batch normalization layer.
- `QuantizedConvolutionTrainingLayer`: Custom layer that is not used in this example but can be applied to networks with convolution layers that are not followed by batch normalization.
- `IdentityTrainingLayer`: No-op layer that acts as a placeholder for batch normalization layers.
- `quantizeToFloat`: Function to quantize values to a floating-point representation.
- `bypassdlgradients`: Function to perform straight-through estimation for a given operation. The source of this function is obfuscated because it uses internal packages.
- `foldBatchNormalizationParameters`: Function to calculate the adjusted weights and bias for a dlnetwork that contains a convolution layer followed by a batch normalization layer. The source of this function is obfuscated because it uses internal packages.
- `CustomStraightThroughEstimator`: Helper class used by `bypassdlgradients`; it should not be used directly.

### Requirements

- [MATLAB®](https://www.mathworks.com/products/matlab.html) version R2022b or later
- [Deep Learning Toolbox™](https://www.mathworks.com/products/deep-learning.html)
- [Deep Learning Toolbox Model Quantization Library](https://www.mathworks.com/matlabcentral/fileexchange/74614-deep-learning-toolbox-model-quantization-library)

## About Quantization Aware Training

This example focuses on the steps of a quantization workflow:

- Replace quantizable layers in a floating-point network with quantization aware training layers.
- Train the network with the quantization aware training layers until it converges.
- Replace the quantization aware training layers with the original layer types, carrying over the updated learnables, which are now more robust to quantization.
- Perform post-training quantization on this network to produce a quantized `int8` network.

![Quantization Aware Workflow Steps](./images/qat_workflow.png)

During training, the quantization aware convolution layers quantize the weights and activations of the layer at each forward pass. The function `quantizeToFloat` quantizes the values to a floating-point representation using the `single` type. This operation is akin to quantizing a value to an integer and then immediately rescaling it back to the real-world representation.

As an example, `quantizeToFloat` takes an input value of `365.247` and calculates a scaling factor that is used to scale the value to an integer representation of `91`. The integer value `91` is then rescaled back to `364`, introducing an absolute error of `-1.247`.
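As a concrete check, the following minimal sketch mirrors the arithmetic in `quantizeToFloat.m` for this input. It reproduces only the scaling, saturation, rounding, and rescaling steps, and omits the `dlarray` handling and gradient bypass that the real function performs:

```matlab
% Mirror the quantizeToFloat steps for x = 365.247 (plain single values).
x = single(365.247);

% Best-precision scaling: 8-bit word length - 1 sign bit - 1 for floor.
scalingFactor = floor(log2(max(abs(x)))) - 6;    % 8 - 6 = 2

xScaled    = x * 2^(-scalingFactor);             % 91.3118, on the integer grid
xSaturated = min(max(xScaled, -128), 127);       % saturate to the int8 range
xInteger   = round(xSaturated);                  % 91
xHat       = xInteger * 2^scalingFactor;         % 364

quantizationError = xHat - x                     % -1.2470
```

Formally, the operation is the composition of quantization and unquantization: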

$$
\begin{align}
\hat{x} &= \mathrm{quantizeToFloat}(x) \\
&= \mathrm{unquantize}(\mathrm{quantize}(x)) \\
&= \mathrm{scale} \cdot \mathrm{saturate}\left(\mathrm{round}\left(\frac{x}{\mathrm{scale}}\right)\right)
\end{align}
$$

The quantization step uses the non-differentiable operation `round`, which would normally break the training workflow by zeroing out the gradients. During quantization aware training, bypass the gradient calculations for non-differentiable operations using an identity function. The diagram below \[2\] shows how the custom layer calculates the gradients for non-differentiable operations with the identity function via straight-through estimation.

![Straight Through Estimation](./images/ste.png)

After training, the network returned from the `trainNetwork` function still has the quantization aware training layers. Replace the quantization aware training operators with operators that are specific to inference. Whereas the training graph operates on pseudo-quantized 32-bit floating-point values, in the inference graph the network applies the convolution using `int8` inputs and weights.

| Convolution Operation Graph at Training | Convolution Operation Graph at Inference |
| ----------- | ----------- |
| ![Quantized operators during training](./images/quantized_training.png) | ![Quantized operators during inference](./images/quantized_inference.png)|

## **References**

1. The TensorFlow Team. Flowers. [http://download.tensorflow.org/example_images/flower_photos.tgz](http://download.tensorflow.org/example_images/flower_photos.tgz)
2. Gholami, A., Kim, S., Dong, Z., Mahoney, M., & Keutzer, K. (2021). A Survey of Quantization Methods for Efficient Neural Network Inference. Retrieved from [https://arxiv.org/abs/2103.13630](https://arxiv.org/abs/2103.13630)
3. Jacob, B., Kligys, S., Chen, B., Zhu, M., Tang, M., Howard, A., Adam, H., & Kalenichenko, D. (2017). Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. Retrieved from [https://arxiv.org/abs/1712.05877](https://arxiv.org/abs/1712.05877)

Copyright 2023 The MathWorks, Inc.
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
# Reporting Security Vulnerabilities

If you believe you have discovered a security vulnerability, please report it to
[security@mathworks.com](mailto:security@mathworks.com). Please see
[MathWorks Vulnerability Disclosure Policy for Security Researchers](https://www.mathworks.com/company/aboutus/policies_statements/vulnerability-disclosure-policy.html)
for additional information.
--------------------------------------------------------------------------------
/bypassdlgradients.m:
--------------------------------------------------------------------------------
% BYPASSDLGRADIENTS Bypass gradient for non-differentiable operations
%   Y = bypassdlgradients(FUN, X) evaluates FUN(X) while overriding the
%   derivative calculation used during backward propagation with an
%   identity function instead.
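%
%   This implements the straight-through estimator used in quantization
%   aware training: the forward result is FUN(X), but the backward pass
%   treats FUN as the identity, so non-differentiable steps such as ROUND
%   no longer zero out the gradients (compare the two examples below).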
%
% Examples:
%     a = dlarray([1.0 2.5]); % point at which to evaluate gradient
%
%     % non-differentiable gradient calculation
%     function [y,grad] = objectiveAndGradient(x)
%         y = round(x(1) + x(2));
%         grad = dlgradient(y,x);
%     end
%     [val,grad] = dlfeval(@objectiveAndGradient,a);
%     % val is dlarray(4)
%     % grad is dlarray([0 0])
%
%     % non-differentiable gradient calculation with a straight-through
%     % estimator for the 'round' function
%     function [y,grad] = steObjectiveAndGradient(x)
%         y = BYPASSDLGRADIENTS(@round, x(1) + x(2));
%         grad = dlgradient(y,x);
%     end
%     [val,grad] = dlfeval(@steObjectiveAndGradient,a);
%     % val is dlarray(4)
%     % grad is dlarray([1 1])
%
% See also: DLARRAY, DLACCELERATE, EXTRACTDATA
%

% Copyright 2023 The MathWorks, Inc.
--------------------------------------------------------------------------------
/bypassdlgradients.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/quantization-aware-training/d0d0038bfeb9d89fc1a12d35423ed47c523d12a1/bypassdlgradients.p
--------------------------------------------------------------------------------
/foldBatchNormalizationParameters.m:
--------------------------------------------------------------------------------
% foldBatchNormalizationParameters Adjusts Convolution Learnables for Fusion
%   Calculates the adjusted learnables of a convolution layer from a
%   would-be fusion with a batch normalization layer.
%
%   [ADJUSTEDWEIGHTS, ADJUSTEDBIAS] = foldBatchNormalizationParameters(NET),
%   where NET is a dlnetwork with a Convolution2D or GroupedConvolution2D
%   layer as the first layer and a BatchNormalization layer as the second
%   layer, returns the adjusted weights and bias of the convolution layer.
%

% Copyright 2023 The MathWorks, Inc.
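
% Example (an illustrative sketch; the layer names and input size are
% arbitrary, and the dlnetwork is initialized with example input data so
% that the batch normalization statistics exist):
%
%   layers = [convolution2dLayer(3, 16, 'Name', 'conv')
%             batchNormalizationLayer('Name', 'bn')];
%   X = dlarray(zeros(32, 32, 3, 1, 'single'), 'SSCB');
%   net = dlnetwork(layers, X);
%   [adjustedWeights, adjustedBias] = foldBatchNormalizationParameters(net);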
--------------------------------------------------------------------------------
/foldBatchNormalizationParameters.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/quantization-aware-training/d0d0038bfeb9d89fc1a12d35423ed47c523d12a1/foldBatchNormalizationParameters.p
--------------------------------------------------------------------------------
/images/original_inference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/quantization-aware-training/d0d0038bfeb9d89fc1a12d35423ed47c523d12a1/images/original_inference.png
--------------------------------------------------------------------------------
/images/qat_workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/quantization-aware-training/d0d0038bfeb9d89fc1a12d35423ed47c523d12a1/images/qat_workflow.png
--------------------------------------------------------------------------------
/images/quantized_inference.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/quantization-aware-training/d0d0038bfeb9d89fc1a12d35423ed47c523d12a1/images/quantized_inference.png
--------------------------------------------------------------------------------
/images/quantized_training.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/quantization-aware-training/d0d0038bfeb9d89fc1a12d35423ed47c523d12a1/images/quantized_training.png
--------------------------------------------------------------------------------
/images/ste.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/quantization-aware-training/d0d0038bfeb9d89fc1a12d35423ed47c523d12a1/images/ste.png
--------------------------------------------------------------------------------
/quantizeToFloat.m:
--------------------------------------------------------------------------------
function value = quantizeToFloat(value)
%% quantizeToFloat Quantizes a value and rescales back to floating point
%
%   quantizedValue = quantizeToFloat(value) returns a floating-point
%   value that has been quantized using best-precision scaling.
%
%   Example:
%       quantizedValue = quantizeToFloat(dlarray(single(365.247)))

% Copyright 2023 The MathWorks, Inc.

% Calculate the ideal scaling factor using the input range.
m = extractdata(gather(max(abs(value(:)))));
scalingFactor = double(floor(log2(m)));
% Adjust the scaling factor by 6: 8-bit word length - 1 bit for the sign
% - 1 bit to account for the floor.
scalingFactor = scalingFactor - 6;

% Scale the value using the calculated scaling factor.
value = scaleValue(value, scalingFactor);

% Saturate to the int8 range.
value = saturateValue(value);

% Round values while bypassing the dlgradient calculation.
value = bypassdlgradients(@round, value);

% Rescale values to the single range.
value = rescaleValue(value, scalingFactor);
end

function value = scaleValue(value, scalingFactor)
% Scale the value using the calculated scaling factor.
value = value*single(2^(-1*scalingFactor));
end

function value = saturateValue(value)
% Saturate to the int8 range.
value = max(value, -128); % intmin('int8')
value = min(value, 127);  % intmax('int8')
end

function value = rescaleValue(value, scalingFactor)
% Rescale values to the single range.
value = value*single(2^scalingFactor);
end
--------------------------------------------------------------------------------