├── README.md ├── part1-export-onnx.py ├── part2-Symmetric_quantization.py ├── part3-Asymmetric_dequantize.py ├── part4-Kl.py ├── part5-TensorRT-KL.py ├── part6.py ├── ~$量化PPT.pptx └── 量化PPT.pptx /README.md: -------------------------------------------------------------------------------- 1 | 2 | # TensorRT Quantization Tutorial 3 | ***if you feel the tutorial is good! why not give us a star!⭐⭐⭐*** 4 | 5 | This repository relates two main sections: **Fundamentals** and **Practical Application**, aiming to provide a comprehensive guide on model quantization in TensorRT. 6 | 7 | ## Fundamentals 8 | 9 | Both the video and code for this section are completely open-source. 10 | 11 | - **Video Tutorial**: [Bilibili Link](https://www.bilibili.com/video/BV18L41197Uz/?spm_id_from=333.788&vd_source=eefa4b6e337f16d87d87c2c357db8ca7) 12 | - **Code Repository**: [GitHub Link](https://github.com/shouxieai/tensorRT_quantization) 13 | 14 | ### Table of Contents 15 | 16 | 1. **Principles of Model Quantization** 17 | - 1.1 Definition and Significance of Quantization 18 | - 1.1.1 Model Weight Analysis 19 | - 1.1.2 Importance of Quantization 20 | - 1.2 Symmetric vs Asymmetric Quantization 21 | - 1.2.1 Definition of Symmetric Quantization 22 | - 1.2.2 Handwritten Code for Symmetric Quantization 23 | - 1.2.3 Definition of Asymmetric Quantization 24 | - 1.2.4 Handwritten Code for Asymmetric Quantization 25 | - 1.3 Common Methods for Dynamic Range Calculation 26 | - 1.3.1 Max 27 | - 1.3.2 Histogram 28 | - 1.3.3 Entropy 29 | - 1.4 Introduction to PTQ and QAT 30 | - 1.5 Handwriting a Quantized Program with Ops 31 | 2. **TensorRT Quantization Library** 32 | - 2.1 Understanding Quantizer 33 | - 2.2 Understanding InputQuant/MixQuant 34 | - 2.3 Automatic Insertion of QDQ Nodes 35 | - 2.4 Manual Insertion of QDQ Nodes 36 | - 2.5 How to Quantize a Custom Layer 37 | - 2.6 Sensitivity Layer Analysis 38 | - 2.7 Pitfalls and Lessons Learned 39 | 40 | ## Practical Application 41 | 42 | The practical application section is paid content. Please visit the link below to purchase: 43 | 44 | [Buy Now](#) 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | # TensorRT 量化教程 53 | 54 | ***如果你觉得这个教程很赞,欢迎给星星哟!⭐⭐⭐*** 55 | 56 | 本仓库分为涉及两个部分:**基础知识**和**实战应用**,旨在全面讲解 TensorRT 下的模型量化。 57 | 58 | 59 | ## 基础知识 60 | 61 | 该部分的视频和代码完全开源。 62 | 63 | - **视频教程**:[B站链接](https://www.bilibili.com/video/BV18L41197Uz/?spm_id_from=333.788&vd_source=eefa4b6e337f16d87d87c2c357db8ca7) 64 | - **代码仓库**:[GitHub链接](https://github.com/shouxieai/tensorRT_quantization) 65 | 66 | ### 目录 67 | 68 | 1. **模型量化原理** 69 | - 1.1 量化的定义及意义 70 | - 1.1.1 模型权重分析 71 | - 1.1.2 量化的意义 72 | - 1.2 对称量化与非对称量化 73 | - 1.2.1 对称量化的定义 74 | - 1.2.2 对称量化代码手写 75 | - 1.2.3 非对称量化的定义 76 | - 1.2.4 非对称量化代码手写 77 | - 1.3 动态范围的常用计算方法 78 | - 1.3.1 Max 79 | - 1.3.2 Histgram 80 | - 1.3.3 Entropy 81 | - 1.4 PTQ 与 QAT 介绍 82 | - 1.5 手写一个带 op 的量化程序 83 | 2. **TensorRT Quantization Library** 84 | - 2.1 Quantizer 的理解 85 | - 2.2 InputQuant/MixQuant 的理解 86 | - 2.3 自动插入 QDQ 节点 87 | - 2.4 手动插入 QDQ 节点 88 | - 2.5 如何量化一个自定义层 89 | - 2.6 敏感层分析 90 | - 2.7 踩坑实录 91 | 92 | ## 实战应用 93 | 94 | 实战部分内容需要付费购买,请访问以下链接: 95 | 96 | [购买链接](#) 97 | 98 | -------------------------------------------------------------------------------- /part1-export-onnx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models as models 3 | 4 | model = models.resnet50(pretrained=True) 5 | 6 | input = torch.randn(1, 3, 224, 224) 7 | torch.onnx.export(model, input, "resnet50-1.onnx") -------------------------------------------------------------------------------- /part2-Symmetric_quantization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def saturate(x, int_max, int_min): 4 | return np.clip(x,int_min, int_max) 5 | 6 | def scale_z_cal(x, int_max, int_min): 7 | scale = (x.max() - x.min())/(int_max - int_min) 8 | z = int_max - np.round((x.max()/scale)) 9 | return scale, z 10 | 11 | def quant_float_data(x, scale, z, int_max, int_min): 12 | xq = saturate( np.round(x/scale + z), int_max, int_min) 13 | return xq 14 | 15 | def dequant_data(xq, scale, z): 16 | x = ((xq - z)*scale).astype('float32') 17 | return x 18 | 19 | 20 | if __name__ == '__main__': 21 | np.random.seed(1) 22 | data_float32 = np.random.randn(3).astype('float32') 23 | data_float32[0] = -0.61 24 | data_float32[1] = -0.52 25 | data_float32[2] = 1.62 26 | print("input",data_float32) 27 | int_max = 255 28 | int_min = 0 29 | 30 | scale, z = scale_z_cal(data_float32, int_max, int_min) 31 | print("scale and z ",scale, z) 32 | data_int8 = quant_float_data(data_float32, scale, z, int_max, int_min) 33 | print("quant result ",data_int8) 34 | data_dequnat_float = dequant_data(data_int8, scale, z) 35 | print("dequant result ",data_dequnat_float) 36 | 37 | print('diff',data_dequnat_float- data_float32 ) -------------------------------------------------------------------------------- /part3-Asymmetric_dequantize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def saturate(x): 4 | return np.clip(x, -127, 127) 5 | 6 | def scale_cal(x): 7 | max_val = np.max(np.abs(x)) 8 | return max_val/127 9 | 10 | def quant_float_data(x, scale): 11 | xq = np.round(x / scale) 12 | return saturate(xq) 13 | 14 | def dequant_data(xq, scale): 15 | x = (xq*scale).astype('float32') 16 | return x 17 | 18 | def histgram_range(x): 19 | hist, range = np.histogram(x, 100) 20 | total = len(x) 21 | left = 0 22 | right = len(hist) -1 23 | limit = 0.99 24 | while True: 25 | cover_percent = hist[left:right].sum()/total 26 | if cover_percent<=limit: 27 | break 28 | 29 | if hist[left] < hist[right]: 30 | left+=1 31 | else: 32 | right -=1 33 | 34 | left_val = range[left] 35 | right_val = range[right] 36 | dynamic_range = max(abs(left_val), abs(right_val)) 37 | return dynamic_range/127. 38 | 39 | 40 | if __name__ == '__main__': 41 | np.random.seed(1) 42 | 43 | data_float32 = np.random.randn(1000).astype('float32') 44 | print('input ',data_float32) 45 | scale = scale_cal(data_float32) 46 | scale2 = histgram_range(data_float32) 47 | print(scale,scale2 ) 48 | exit(1) 49 | xq= quant_float_data(data_float32, scale) 50 | print('quant result ',xq) 51 | xdq=dequant_data(xq, scale) 52 | print('dequant result ',xdq) 53 | print('diff ',xdq-data_float32) -------------------------------------------------------------------------------- /part4-Kl.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | def smooth_data(p, eps = 0.0001): 5 | is_zeros = (p==0).astype(np.float32) 6 | is_nonzeros = (p!=0).astype(np.float32) 7 | n_zeros = is_zeros.sum() 8 | n_nonzeros = p.size - n_zeros 9 | 10 | eps1 = eps*n_zeros/n_nonzeros 11 | hist = p.astype(np.float32) 12 | hist += eps*is_zeros + (-eps1)*is_nonzeros 13 | return hist 14 | 15 | 16 | def cal_kl(p, q): 17 | KL = 0. 18 | for i in range(len(p)): 19 | KL += p[i]* np.log(p[i]/(q[i])) 20 | return KL 21 | 22 | def kl_test(x, kl_threshold = 0.01 ,size =10): 23 | y_out = [] 24 | while True: 25 | y = [ np.random.uniform(1, size+1) for i in range(size)] 26 | y /= np.sum(y) 27 | kl_result = cal_kl(x, y) 28 | if kl_result < kl_threshold: 29 | print(kl_result) 30 | y_out = y 31 | plt.plot(x) 32 | plt.plot(y) 33 | break 34 | return y_out 35 | 36 | def KL_main(): 37 | np.random.seed(1) 38 | size = 10 39 | x = [ np.random.uniform(1, size+1) for i in range(size)] 40 | x = x / np.sum(x) 41 | y_out = kl_test(x,kl_threshold = 0.01) 42 | plt.show() 43 | print(x, y_out) 44 | 45 | if __name__ == '__main__': 46 | p = [1, 0, 2, 3, 5, 3, 1, 7] 47 | bin = 4 48 | split_p = np.array_split(p, bin) 49 | q = [] 50 | for arr in split_p: 51 | avg = np.sum(arr)/ np.count_nonzero(arr) 52 | for item in arr: 53 | if item !=0: 54 | q.append(avg) 55 | continue 56 | q.append(0) 57 | print(q) 58 | p /= np.sum(p) 59 | q /= np.sum(q) 60 | print(p) 61 | print(q) 62 | p = smooth_data(p) 63 | q = smooth_data(q) 64 | print(p) 65 | print(q) 66 | #cal kl 67 | print(cal_kl(p, q)) 68 | 69 | -------------------------------------------------------------------------------- /part5-TensorRT-KL.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | def generator_P(size): 5 | walk = [] 6 | avg = random.uniform(3.000, 600.999) 7 | std = random.uniform(500.000, 1024.959) 8 | for _ in range(size): 9 | walk.append(random.gauss(avg, std)) 10 | return walk 11 | 12 | def smooth_distribution(p, eps=0.0001): 13 | is_zeros = (p == 0).astype(np.float32) 14 | is_nonzeros = (p != 0).astype(np.float32) 15 | n_zeros = is_zeros.sum() 16 | n_nonzeros = p.size - n_zeros 17 | if not n_nonzeros: 18 | raise ValueError('The discrete probability distribution is malformed. All entries are 0.') 19 | eps1 = eps * float(n_zeros) / float(n_nonzeros) 20 | assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1) 21 | hist = p.astype(np.float32) 22 | hist += eps * is_zeros + (-eps1) * is_nonzeros 23 | assert (hist <= 0).sum() == 0 24 | return hist 25 | 26 | import copy 27 | import scipy.stats as stats 28 | def threshold_distribution(distribution, target_bin=128): 29 | distribution = distribution[1:] 30 | length = distribution.size 31 | threshold_sum = sum(distribution[target_bin:]) 32 | kl_divergence = np.zeros(length - target_bin) 33 | 34 | for threshold in range(target_bin, length): 35 | sliced_nd_hist = copy.deepcopy(distribution[:threshold]) 36 | 37 | # generate reference distribution p 38 | p = sliced_nd_hist.copy() 39 | p[threshold - 1] += threshold_sum 40 | threshold_sum = threshold_sum - distribution[threshold] 41 | 42 | # is_nonzeros[k] indicates whether hist[k] is nonzero 43 | is_nonzeros = (p != 0).astype(np.int64) 44 | 45 | quantized_bins = np.zeros(target_bin, dtype=np.int64) 46 | # calculate how many bins should be merged to generate 47 | # quantized distribution q 48 | num_merged_bins = sliced_nd_hist.size // target_bin 49 | 50 | # merge hist into num_quantized_bins bins 51 | for j in range(target_bin): 52 | start = j * num_merged_bins 53 | stop = start + num_merged_bins 54 | quantized_bins[j] = sliced_nd_hist[start:stop].sum() 55 | quantized_bins[-1] += sliced_nd_hist[target_bin * num_merged_bins:].sum() 56 | 57 | # expand quantized_bins into p.size bins 58 | q = np.zeros(sliced_nd_hist.size, dtype=np.float64) 59 | for j in range(target_bin): 60 | start = j * num_merged_bins 61 | if j == target_bin - 1: 62 | stop = -1 63 | else: 64 | stop = start + num_merged_bins 65 | norm = is_nonzeros[start:stop].sum() 66 | if norm != 0: 67 | q[start:stop] = float(quantized_bins[j]) / float(norm) 68 | 69 | p = smooth_distribution(p) 70 | q = smooth_distribution(q) 71 | 72 | # calculate kl_divergence between q and p 73 | kl_divergence[threshold - target_bin] = stats.entropy(p, q) 74 | 75 | min_kl_divergence = np.argmin(kl_divergence) 76 | threshold_value = min_kl_divergence + target_bin 77 | 78 | return threshold_value 79 | 80 | if __name__ == '__main__': 81 | 82 | size = 20480 83 | P = generator_P(size) 84 | P = np.array(P) 85 | P = P[P>0] 86 | print("最大的激活值", max(np.absolute(P))) 87 | 88 | hist, bins = np.histogram(P, bins =2048) 89 | threshold = threshold_distribution(hist, target_bin=128) 90 | print("threshold 所在组:", threshold) 91 | print("threshold 所在组的区间范围:", bins[threshold]) 92 | # 分成split_zie组, density表示是否要normed 93 | plt.title("Relu activation value Histogram") 94 | plt.xlabel("Activation values") 95 | plt.ylabel("Normalized number of Counts") 96 | plt.hist(P, bins=2047) 97 | plt.vlines(bins[threshold], 0, 30, colors = "r", linestyles = "dashed") 98 | plt.show() -------------------------------------------------------------------------------- /part6.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from pytorch_quantization import tensor_quant 4 | from pytorch_quantization import quant_modules 5 | from pytorch_quantization import nn as quant_nn 6 | from pytorch_quantization.nn.modules import _utils as quant_nn_utils 7 | from pytorch_quantization import calib 8 | from typing import List, Callable, Union, Dict 9 | 10 | class disable_quantization: 11 | def __init__(self, model): 12 | self.model = model 13 | 14 | def apply(self, disabled=True): 15 | for name, module in self.model.named_modules(): 16 | if isinstance(module, quant_nn.TensorQuantizer): 17 | module._disabled = disabled 18 | 19 | def __enter__(self): 20 | self.apply(True) 21 | 22 | def __exit__(self, *args, **kwargs): 23 | self.apply(False) 24 | 25 | 26 | class enable_quantization: 27 | def __init__(self, model): 28 | self.model = model 29 | 30 | def apply(self, enabled=True): 31 | for name, module in self.model.named_modules(): 32 | if isinstance(module, quant_nn.TensorQuantizer): 33 | module._disabled = not enabled 34 | 35 | def __enter__(self): 36 | self.apply(True) 37 | return self 38 | 39 | def __exit__(self, *args, **kwargs): 40 | self.apply(False) 41 | 42 | 43 | def quantizer_state(module): 44 | for name, module in module.named_modules(): 45 | if isinstance(module, quant_nn.TensorQuantizer): 46 | print(name, module) 47 | 48 | 49 | def transfer_torch_to_quantization(nninstance : torch.nn.Module, quantmodule): 50 | 51 | quant_instance = quantmodule.__new__(quantmodule) 52 | for k, val in vars(nninstance).items(): 53 | setattr(quant_instance, k, val) 54 | 55 | def __init__(self): 56 | 57 | if isinstance(self, quant_nn_utils.QuantInputMixin): 58 | quant_desc_input = quant_nn_utils.pop_quant_desc_in_kwargs(self.__class__,input_only=True) 59 | self.init_quantizer(quant_desc_input) 60 | 61 | # Turn on torch_hist to enable higher calibration speeds 62 | if isinstance(self._input_quantizer._calibrator, calib.HistogramCalibrator): 63 | self._input_quantizer._calibrator._torch_hist = True 64 | else: 65 | quant_desc_input, quant_desc_weight = quant_nn_utils.pop_quant_desc_in_kwargs(self.__class__) 66 | self.init_quantizer(quant_desc_input, quant_desc_weight) 67 | 68 | # Turn on torch_hist to enable higher calibration speeds 69 | if isinstance(self._input_quantizer._calibrator, calib.HistogramCalibrator): 70 | self._input_quantizer._calibrator._torch_hist = True 71 | self._weight_quantizer._calibrator._torch_hist = True 72 | 73 | __init__(quant_instance) 74 | return quant_instance 75 | 76 | 77 | def replace_to_quantization_module(model : torch.nn.Module, ignore_policy : Union[str, List[str], Callable] = None): 78 | 79 | module_dict = {} 80 | for entry in quant_modules._DEFAULT_QUANT_MAP: 81 | module = getattr(entry.orig_mod, entry.mod_name) 82 | module_dict[id(module)] = entry.replace_mod 83 | 84 | def recursive_and_replace_module(module, prefix=""): 85 | for name in module._modules: 86 | submodule = module._modules[name] 87 | path = name if prefix == "" else prefix + "." + name 88 | recursive_and_replace_module(submodule, path) 89 | 90 | submodule_id = id(type(submodule)) 91 | if submodule_id in module_dict: 92 | module._modules[name] = transfer_torch_to_quantization(submodule, module_dict[submodule_id]) 93 | 94 | recursive_and_replace_module(model) 95 | 96 | 97 | #quant_modules.initialize() 98 | model = torchvision.models.resnet50() 99 | model.cuda() 100 | # disable_quantization(model.conv1).apply() 101 | # quantizer_state(model) 102 | replace_to_quantization_module(model) 103 | inputs = torch.randn(1, 3, 224, 224, device='cuda') 104 | quant_nn.TensorQuantizer.use_fb_fake_quant =True 105 | torch.onnx.export(model, inputs, 'quant_resnet50_replace_to_quantization.onnx',opset_version=13) 106 | -------------------------------------------------------------------------------- /~$量化PPT.pptx: -------------------------------------------------------------------------------- 1 | Peter Huang Peter Huang -------------------------------------------------------------------------------- /量化PPT.pptx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shouxieai/tensorRT_quantization/22dce4a860241b8e1e1d4d96d56bdd1f7f9a9235/量化PPT.pptx --------------------------------------------------------------------------------