├── .github └── workflows │ ├── lint.yml │ └── ppq_simulator.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── MANIFEST.in ├── ProgramEntrance_1.py ├── ProgramEntrance_2.py ├── README.md ├── assets ├── OpenPPL.jpg ├── QQGroup.jpg └── logo.png ├── md_doc ├── Passes │ ├── BiasCorrectionPass.md │ ├── IsotoneCalibrationPass.md │ ├── LayerSpilit.md │ ├── LayerwiseEqualization.md │ ├── LearnedStepSizePass.md │ ├── QuantAlignment.md │ ├── QuantFusion.md │ ├── QuantSimplify.md │ └── RuntimeCalibrationPass.md ├── deploy_for_mnn.md ├── deploy_trt_by_OnnxParser.md ├── deploy_trt_by_api.md ├── how_to_use.md ├── inference_with_ncnn.md ├── inference_with_ppl_cuda.md └── inference_with_snpe_dsp.md ├── ppq ├── IR │ ├── README.md │ ├── __init__.py │ ├── base │ │ ├── __init__.py │ │ ├── command.py │ │ ├── graph.py │ │ └── opdef.py │ ├── deploy.py │ ├── morph.py │ ├── processer.py │ ├── quantize.py │ ├── search.py │ └── training.py ├── README.md ├── __init__.py ├── api │ ├── __init__.py │ ├── fsys.py │ ├── interface.py │ └── setting.py ├── core │ ├── README.md │ ├── __init__.py │ ├── common.py │ ├── config.py │ ├── data.py │ ├── defs.py │ ├── ffi.py │ ├── quant.py │ └── storage.py ├── csrc │ ├── build │ │ ├── .gitignore │ │ └── readme.md │ ├── cpu │ │ ├── common.h │ │ ├── hist_mse.cc │ │ └── hist_mse.h │ ├── cuda │ │ ├── PPQ.h │ │ ├── common.cuh │ │ ├── common.h │ │ ├── floating.cu │ │ ├── floating.h │ │ ├── isotone.cc │ │ ├── isotone.h │ │ ├── linear.cu │ │ ├── linear.h │ │ ├── sort.cu │ │ ├── sort.h │ │ ├── test.cu │ │ ├── test.h │ │ ├── train.cu │ │ └── train.h │ └── export.cc ├── executor │ ├── README.md │ ├── __init__.py │ ├── base.py │ ├── op │ │ ├── __init__.py │ │ ├── fp32 │ │ │ └── fp32_backend.py │ │ ├── ppl_dsp │ │ │ └── ppl_dsp_backend.py │ │ ├── ppl_trt │ │ │ └── ppl_gpu_backend.py │ │ └── torch │ │ │ ├── __init__.py │ │ │ ├── academic.py │ │ │ ├── base.py │ │ │ ├── cuda.py │ │ │ ├── default.py │ │ │ ├── dsp.py │ │ │ ├── extension.py │ │ │ ├── nxp.py │ │ │ └── onnx.py │ └── torch.py ├── lib │ ├── README.md │ ├── __init__.py │ ├── common.py │ ├── extension.py │ └── quant.py ├── log │ ├── __init__.py │ └── logger.py ├── parser │ ├── __init__.py │ ├── ascend_export.py │ ├── caffe │ │ ├── __init__.py │ │ ├── caffe.proto │ │ ├── caffe_export_utils.py │ │ ├── caffe_graph_optim.py │ │ ├── caffe_import_utils.py │ │ └── ppl_caffe_pb2.py │ ├── caffe_exporter.py │ ├── caffe_parser.py │ ├── extension.py │ ├── mnn_exporter.py │ ├── native.py │ ├── ncnn_exporter.py │ ├── nxp_exporter.py │ ├── onnx_exporter.py │ ├── onnx_parser.py │ ├── onnxruntime_exporter.py │ ├── openvino_exporter.py │ ├── ppl.py │ ├── qnn_exporter.py │ ├── tengine_exporter.py │ ├── tensorRT.py │ └── util.py ├── qat │ └── core.py ├── quantization │ ├── __init__.py │ ├── algorithm │ │ ├── __init__.py │ │ ├── equalization.py │ │ ├── exprimental.py │ │ └── training.py │ ├── analyse │ │ ├── __init__.py │ │ ├── graphwise.py │ │ ├── layerwise.py │ │ └── util │ │ │ └── __init__.py │ ├── measure │ │ ├── __init__.py │ │ ├── cosine.py │ │ ├── norm.py │ │ └── statistic.py │ ├── observer │ │ ├── __init__.py │ │ ├── base.py │ │ ├── floating.py │ │ ├── order.py │ │ └── range.py │ ├── optim │ │ ├── README.md │ │ ├── __init__.py │ │ ├── baking.py │ │ ├── base.py │ │ ├── calibration.py │ │ ├── equalization.py │ │ ├── exprimental.py │ │ ├── extension.py │ │ ├── legacy.py │ │ ├── morph.py │ │ ├── parameters.py │ │ ├── refine.py │ │ ├── ssd.py │ │ └── training.py │ ├── qfunction │ │ ├── README.md │ │ ├── __init__.py │ │ ├── base.py │ │ ├── 
floating.py │ │ └── linear.py │ └── quantizer │ │ ├── AscendQuantizer.py │ │ ├── DSPQuantizer.py │ │ ├── FP8Quantizer.py │ │ ├── FPGAQuantizer.py │ │ ├── MNNQuantizer.py │ │ ├── MetaxQuantizer.py │ │ ├── MyQuantizer.py │ │ ├── NCNNQuantizer.py │ │ ├── NXPQuantizer.py │ │ ├── ORTQuantizer.py │ │ ├── OpenvinoQuantizer.py │ │ ├── PPLQuantizer.py │ │ ├── README.md │ │ ├── RKNNQuantizer.py │ │ ├── TengineQuantizer.py │ │ ├── TensorRTQuantizer.py │ │ ├── __init__.py │ │ └── base.py ├── samples │ ├── FP8 │ │ └── fp8_sample.py │ ├── Imagenet │ │ ├── Utilities │ │ │ ├── Imagenet │ │ │ │ ├── __init__.py │ │ │ │ └── imagenet_util.py │ │ │ └── __init__.py │ │ └── evaluation_with_imagenet.py │ ├── Onnxruntime │ │ ├── Example_Benchmark.py │ │ ├── Example_Fp32.py │ │ └── Example_PTQ.py │ ├── Openvino │ │ ├── Example_Benchmark.py │ │ ├── Example_Fp32.py │ │ ├── Example_PTQ.py │ │ └── Example_QAT.py │ ├── QAT │ │ ├── imagenet.py │ │ ├── myquantizer.py │ │ └── trainer.py │ ├── QuantZoo │ │ ├── QuantZoo_Imagenet.py │ │ ├── QuantZoo_OCR.py │ │ ├── QuantZoo_Pose.py │ │ ├── QuantZoo_Segmentation.py │ │ ├── QuantZoo_SuperRes.py │ │ ├── QuantZoo_Yolo.py │ │ └── Readme.md │ ├── RKNN │ │ └── Example_PTQ.py │ ├── TensorRT │ │ ├── Benchmark_with_onnx.py │ │ ├── Example_Benchmark.py │ │ ├── Example_Fp32.py │ │ ├── Example_PTQ.py │ │ ├── Example_Profiling.py │ │ ├── Example_QAT.py │ │ ├── Example_Torch2trt.py │ │ ├── create_engine.py │ │ ├── lenet_demo │ │ │ ├── CMakeLists.txt │ │ │ ├── common │ │ │ │ ├── ErrorRecorder.h │ │ │ │ ├── common.h │ │ │ │ ├── logger.cpp │ │ │ │ ├── logger.h │ │ │ │ ├── logging.h │ │ │ │ └── macros.h │ │ │ ├── generate_onnx.py │ │ │ ├── lenet_int8.cpp │ │ │ └── lenet_int8.py │ │ └── trt_infer.py │ ├── Tutorial │ │ ├── analyse.py │ │ ├── bestPractice.py │ │ ├── calibration.py │ │ ├── dequantize.py │ │ ├── dispatch.py │ │ ├── execute.py │ │ ├── finetune.py │ │ ├── fusion.py │ │ ├── optimization.py │ │ ├── quantize.py │ │ └── targetPlatform.py │ ├── Yolo │ │ ├── 00_FloatModel.py │ │ ├── 01_Quantization.py │ │ ├── 02_Quantization.py │ │ ├── yolo_5.py │ │ └── yolo_x.py │ ├── bert_sample.py │ ├── bypass_nms.py │ ├── custimize_quant_func.py │ ├── custimized_quant.py │ ├── dynamic_shape.py │ ├── enable_cuda_kernel.py │ ├── fp8_sample.py │ ├── onnx_converter.py │ ├── quantize_caffe_model.py │ ├── quantize_dsp.py │ ├── quantize_onnx_model.py │ ├── quantize_torch_model.py │ └── yolo6_sample.py ├── scheduler │ ├── __init__.py │ ├── allin.py │ ├── base.py │ ├── dispatchers.py │ └── perseus.py └── utils │ ├── OnnxruntimeUtil.py │ ├── OpenvinoUtil.py │ ├── TensorRTUtil.py │ ├── __init__.py │ ├── attribute.py │ ├── ema.py │ ├── fetch.py │ ├── graph_editor.py │ ├── round.py │ ├── write_qparams_caffe2trt.py │ ├── write_qparams_onnx2trt.py │ └── write_qparams_to_snpe_dlc.py ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── testBnToConv.py ├── testFuseBias.py ├── test_activation_fusion.py ├── test_block.py ├── test_block_split.py ├── test_cuda_kernel.py ├── test_gemm_fusion.py ├── test_gemm_split.py ├── test_graph_api.py ├── test_isotone.py ├── test_layerwise_equalization.py ├── test_onnxruntime.py ├── test_persus.py ├── test_rounding.py ├── test_system.py ├── tmodel ├── __init__.py ├── base.py ├── testblocks │ ├── __init__.py │ └── blocks.py └── torchmodels │ └── __init__.py ├── tscheme ├── __init__.py └── base.py └── tworkingspace └── placeholder.py /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: lint 2 | 3 | on: [push, 
pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - name: Set up Python 3.7 11 | uses: actions/setup-python@v2 12 | with: 13 | python-version: 3.7 14 | - name: Install pre-commit hook 15 | run: | 16 | pwd&&ls 17 | pip install pre-commit 18 | pre-commit --version 19 | - name: pre-commit checking 20 | run: | 21 | git fetch 22 | git branch -a 23 | git diff --name-status remotes/origin/master HEAD 24 | updated_files=`git diff --name-status remotes/origin/master HEAD|awk '{print$ 2}'` 25 | echo $updated_files 26 | for file in $updated_files 27 | do 28 | if [ "${file##*.}" = "py" ]; then 29 | echo $file 30 | pre-commit run --files $file --show-diff-on-failure 31 | fi 32 | done 33 | 34 | -------------------------------------------------------------------------------- /.github/workflows/ppq_simulator.yml: -------------------------------------------------------------------------------- 1 | name: linux-x86-64 2 | 3 | on: 4 | #push: 5 | # branches: [ master ] 6 | # paths-ignore: ['.**', 'docker/**', 'docs/**', 'samples/**', README.md] 7 | pull_request: 8 | branches: [ master ] 9 | paths-ignore: ['.**', 'docker/**', 'docs/**', 'samples/**', README.md ] 10 | workflow_dispatch: 11 | 12 | concurrency: 13 | group: ${{ github.workflow }}--${{ github.head_ref || github.run_id }}--${{ github.ref }}--${{ github.event_name }} 14 | cancel-in-progress: true 15 | 16 | jobs: 17 | build_and_test: 18 | runs-on: [self-hosted, linux, X64] 19 | 20 | steps: 21 | - name: Create Checkout Directory 22 | run: | 23 | echo "Create Checkout Directory: ${{ github.run_id }}." 24 | [ -z "${{ github.run_id }}" ] || rm -rf ${{ github.run_id }} 25 | mkdir ${{ github.run_id }} 26 | - name: Checkout 27 | uses: actions/checkout@v2 28 | with: 29 | path: ${{ github.run_id }} 30 | 31 | - name: Test 32 | run: | 33 | cd ../../ && ./test_ppq.sh ${{ github.run_id }} 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/*.pyc 2 | **/*.pth 3 | **/*.onnx 4 | **/*.out 5 | **/*.csv 6 | **/*.model 7 | __pycache__/ 8 | /build 9 | /dist 10 | /ppq.egg-info 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.1.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: check-yaml 7 | - id: end-of-file-fixer 8 | - id: requirements-txt-fixer 9 | - id: double-quote-string-fixer 10 | - id: check-merge-conflict 11 | - id: fix-encoding-pragma 12 | args: ["--remove"] 13 | - id: mixed-line-ending 14 | args: ["--fix=lf"] 15 | - repo: https://github.com/markdownlint/markdownlint 16 | rev: v0.11.0 17 | hooks: 18 | - id: markdownlint 19 | args: 20 | [ 21 | "-r", 22 | "~MD002,~MD013,~MD029,~MD033,~MD034", 23 | "-t", 24 | "allow_different_nesting", 25 | ] 26 | - repo: https://github.com/myint/docformatter 27 | rev: v1.4 28 | hooks: 29 | - id: docformatter 30 | args: ["--in-place", "--wrap-descriptions", "79"] 31 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include ppq/parser/caffe/* 2 | include ppq/csrc/cuda/* 3 | include ppq/csrc/build/readme.md 4 | include ppq/csrc/cpu/* 5 | include ppq/csrc/export.cc 6 | include 
requirements.txt
--------------------------------------------------------------------------------
/assets/OpenPPL.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenPPL/ppq/e39eecb9f7e5f017c28f180cb423f8a685c3db48/assets/OpenPPL.jpg
--------------------------------------------------------------------------------
/assets/QQGroup.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenPPL/ppq/e39eecb9f7e5f017c28f180cb423f8a685c3db48/assets/QQGroup.jpg
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenPPL/ppq/e39eecb9f7e5f017c28f180cb423f8a685c3db48/assets/logo.png
--------------------------------------------------------------------------------
/md_doc/Passes/BiasCorrectionPass.md:
--------------------------------------------------------------------------------
1 | ## Bias Correction Optimization Pass(Bias 校准过程)
2 |
3 | Bias correction is the process of shifting quantized model outputs to account for their statistical errors.
4 |
5 | Network quantization will bring some error (noise) to the result. To improve the accuracy of a quantized model, we can correct the network by adding an extra term to the bias so that the output error has zero expectation.
6 |
7 | Bias correction is used to eliminate bias error; generally it takes a few minutes to correct all bias terms.
8 |
9 | For layers that have no bias, Bias Correction Optimization will skip them directly.
10 |
11 | let: Y = WX + b
12 |
13 | Quant(Y) = Quant(W) Quant(X) + b
14 |
15 | bias_error = reduce_mean(Y - Quant(Y))
16 |
17 | This pass will correct bias with: b = b + bias_error
18 |
19 | ### Parameters:
20 |
21 | * interested_layers(List[str]):
22 |
23 | A list of operation names; only the layers listed in this parameter will be processed.
24 |
25 | If interested_layers is None, all layers will be processed.
26 |
27 | * steps(int)
28 |
29 | Forward steps for collecting bias error. A larger value of this parameter means more data will be collected and the bias error will be estimated better, while it takes more time.
30 |
31 | Usually 8 ~ 32 steps are enough in most cases.
32 |
33 | * block_size(int)
34 |
35 | Bias Correction Optimization will split your graph into blocks; bias error will be collected and corrected block by block.
36 |
37 | A large block size will greatly reduce the running time of this optimization, while it might give an unstable result when block_size is too large.
38 |
39 | By default this value is set to 4; to have the best optimization result, you are recommended to set block_size = 1.
40 |
41 | * loss_fn(Callable)
42 |
43 | A function used to measure the loss after optimization.
44 |
45 | Bias Correction Optimization is a training-based pass; we will check the loss at the end of each block optimization.
46 |
47 | If the optimization produces a worse result, the optimization result will be dropped.
48 |
49 | ### Usage:
50 |
51 | Bias Correction Optimization Pass should be invoked after Runtime Calibration Pass.
52 |
53 | This pass is included in PPQ Quantization Setting, you can invoke this optimization by:
54 |
55 | setting = QuantizationSettingFactory.default_setting()
56 |
57 | setting.bias_correct = True
58 |
59 | # calling ppq.api.quantize_onnx_model function with this setting.
60 | ir = quantize_torch_model( 61 | model=model, calib_dataloader=load_calibration_dataset(), setting=setting, 62 | platform=TargetPlatform.PPL_CUDA_INT8, calib_steps=8, input_shape=INPUT_SHAPE, 63 | collate_fn=collate_fn) 64 | 65 | You can manully create this optimization by: 66 | 67 | from ppq import BiasCorrectionPass 68 | 69 | optim = BiasCorrectionPass() 70 | 71 | ### Version: 72 | 73 | Require PPQ 0.5.2 + 74 | 75 | Interface changed since PPQ 0.6.5 76 | -------------------------------------------------------------------------------- /md_doc/Passes/IsotoneCalibrationPass.md: -------------------------------------------------------------------------------- 1 | ## Isotone Calibration Pass(保序量化校准过程) 2 | 3 | 在神经网络中,一些算子的输出并不需要保证总体的精确性,而只关注于最大最小值所在的位置, 4 | 例如图像分类网络中,网络的输出通常是一个1000维的向量,用于表达图像属于特定类别的概率。 5 | 为了保证分类的正确性,我们并不需要这个1000维的向量在量化后是整体准确的,只需要其中的最大值出现在正确的位置上。 6 | 因此我们希望最大值与次大值之间相差至少半个 scale,并且次大值能够不被截断。 7 | 8 | 因此传统的 min-max, percentile, kl 方法在这一情景中并不能得到最高的分类精度, 9 | 保序量化是为了解决这一问题而设计的,在这一校准过程中,程序将网络输出变量的校准方式改写为 Isotone(保序校准)。 10 | 默认设置下,该过程只对 softmax 算子的输出进行保序校准。对于其他情况,用户需要手动指定需要进行保序校准的变量名。 11 | 12 | 保序量化需要设定一个分类轴,同样地以分类网络为例,其输出形为 [Batch, 1000]。 13 | 分类操作将在数据的最后一维展开,因此需要设置保序轴为 -1。 14 | 15 | Algorithm: 16 | 17 | For softmax or sigmoid activations, usually we just need 18 | argmax(softmax(x)) == argmax(softmax(quant(x))) 19 | 20 | Inspired by this Property, Isotone Observer is designed to provide an order-preserving calibration method, 21 | which cares only about argmax(x) [or argmin(x)] 22 | 23 | To keep argmax(x) == argmax(quant(x)), we only need to 24 | distinguish the largest element and the second largert element with quantization 25 | 26 | let L1 represents the largest element of x, 27 | while L2 represents the second largest. 28 | 29 | For Symmetrical Quantization, We want: 30 | 31 | 1. round(L1 / scale) - round(L2 / scale) > 0 32 | 33 | 2. round(L2 / scale) < quant_max 34 | 35 | Hence, we will have: 36 | 37 | 1. scale < 2 * (L1 - L2) 38 | 39 | 2. scale > L2 / (self._quant_cfg.quant_max - .5) 40 | 41 | For Asymmetircal Quantization, We want: 42 | 43 | 1. round(L1 / scale) + offset - round(L2 / scale) - offset > 0 44 | 45 | 2. round(L2 / scale) + offset < quant_max 46 | 47 | Hence, we will have: 48 | 49 | 1. scale < 2 * (L1 - L2) 50 | 51 | 2. scale > L2 / (self._quant_cfg.quant_max - offset - .5) 52 | 53 | The best setting of scale, offset can be solved by PPQ Isotone observer. 54 | 55 | Time Complexity: O(nlogn) -------------------------------------------------------------------------------- /md_doc/Passes/LayerSpilit.md: -------------------------------------------------------------------------------- 1 | ## Horizontal Layer Split Pass(算子分裂过程) 2 | 3 | Split convolution layers or GEMM layers for better performance. 4 | 5 | Formula: 6 | 7 | Y = W * X + b 8 | 9 | where W can be divided into W_1 + W_2 10 | 11 | Y = (W_1 * X + b) + (W_2 * X) 12 | 13 | By splitting W like this, we are able to represent W more accurately. 14 | In the case where one channel has weights in the range [-32, 32] and another channel has weights in the range [-0.5, 0.5]. 15 | the large channel will be divided so the range will come to [-16, 16], which leads us to use scale = 0.125 for representing 16 | the weight tensor rather than 0.25. 17 | 18 | The Estimation of Quantization Error is shown as a quadratic function of scale: 19 | 20 | E(Quantization Error) = scale ^ 2 / 12 21 | 22 | This Formula is proved by Bernard Widrow, according to the formula, a scale = 0.125 will decrease the quantization error by 75%. 
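As a quick sanity check of that estimate, the following standalone sketch (plain PyTorch, not PPQ code; the fake_quant helper and the tensor values are made up for illustration) reproduces the [-32, 32] / [-0.5, 0.5] example above and compares the empirical quantization error before and after splitting the large values in half:

    import torch

    def fake_quant(w, num_bits=8):
        # symmetric per-tensor quantization: the scale is driven by the largest magnitude
        scale = w.abs().max() / (2 ** (num_bits - 1) - 1)
        return torch.round(w / scale).clamp(-128, 127) * scale

    w = torch.empty(2, 1024).uniform_(-0.5, 0.5)
    w[0] *= 64                                        # channel 0 now lies in [-32, 32]
    err_before = torch.mean((w - fake_quant(w)) ** 2)

    # split: every value above the threshold is halved and the other half moves to w2,
    # so both tensors can be represented with (roughly) half of the original scale
    w1 = torch.where(w.abs() > 16, w * 0.5, w)
    w2 = torch.where(w.abs() > 16, w * 0.5, torch.zeros_like(w))
    err_after = torch.mean((w - (fake_quant(w1) + fake_quant(w2))) ** 2)

    print(err_before.item(), err_after.item())        # err_after is several times smaller

For the untouched values the halved scale cuts the expected error by about 75%, as the scale ^ 2 / 12 formula predicts, while each split value pays two smaller rounding errors instead of one.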
23 |
24 | Every value larger than value_threshold will be divided into two parts via this function, thus the layer itself will be
25 | split, and a new Add operation is going to be created.
26 |
27 | ### Parameters:
28 | self.interested_layers = interested_layers
29 | self.value_threshold = value_threshold
30 | self.method = str(method).lower()
31 | self.verbose = verbose
32 |
33 | * interested_layers(List[str])
34 |
35 | Only layers listed in interested_layers will be processed by this pass.
36 |
37 | If interested_layers is None or an empty list, NO layer will be processed.
38 |
39 | * value_threshold(float)
40 |
41 | This pass splits a value only when it is larger than value_threshold.
42 |
43 | If there is no value large enough to be processed, the corresponding layer will be skipped.
44 |
45 | * method(str)
46 |
47 | Splitting method, 'balance' or 'random'.
48 |
49 | With the balance method, W_1 and W_2 will be evenly divided.
50 |
51 | With the random method, W_1 and W_2 will be randomly divided.
52 |
53 | ### Warning:
54 |
55 | Creating new operations in your network probably slows down the execution.
56 |
57 | Thus horizontal splitting is in a sense a trade-off between speed and accuracy.
58 |
59 | ### Usage
60 |
61 | You can create this optimization manually:
62 |
63 |     from ppq import HorizontalLayerSplitPass
64 |
65 |     optim = HorizontalLayerSplitPass()
66 |
--------------------------------------------------------------------------------
/md_doc/Passes/QuantAlignment.md:
--------------------------------------------------------------------------------
1 | ## PPQ Quant Alignment Pass(通用量化对齐过程)
2 |
3 | When deploying on real hardware and inference frameworks,
4 | we will find that there are various restrictions or rules that we have to follow.
5 |
6 | * AVERAGE_POOL_2D: Input and outputs must all have same scale/zero_point
7 |
8 | * CONCATENATION: Input and outputs must all have same scale/zero_point
9 |
10 | * SLICE: Input and outputs must all have same scale/zero_point
11 |
12 | For more detailed restrictions please refer to: https://www.tensorflow.org/lite/performance/quantization_spec
13 |
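To see what goes wrong when such a rule is ignored, here is a tiny standalone illustration (plain PyTorch, not PPQ code; the scales are arbitrary) of a Concat whose two int8 inputs were quantized with different scales:

    import torch

    def quant(x, scale):                   # float -> int8 code
        return torch.round(x / scale).clamp(-128, 127)

    def dequant(q, scale):                 # int8 code -> float
        return q * scale

    a, b = torch.tensor([0.9, -0.4]), torch.tensor([12.0, -7.5])
    sa, sb = 0.01, 0.1                     # per-input scales

    # a hardware Concat only copies the integer codes, it cannot rescale them element-wise,
    # so the backend has to interpret the whole result with a single scale/zero_point
    codes = torch.cat([quant(a, sa), quant(b, sb)])
    print(dequant(codes, sb))              # the first half is now 10x too large: alignment is required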
14 | Those restrictions can be summarized as: some quantizations should share
15 | the same quantization parameters with others. PPQ Quant Alignment Pass is designed
16 | for dealing with problems like this.
17 |
18 | PPQ uses Tensor Quantization Config (a data structure defined in ppq.core) to control the
19 | quantization logic, so if we want to align quantization parameters, we in fact align
20 | their TQCs.
21 |
22 | The way to align TQCs is simple, code like:
23 |     tqc1.set_master(master=tqc2)
24 | will make tqc1 and tqc2 share the same quantization parameters (those of tqc1), and change the
25 | state of tqc2 to QuantizationState.SLAVE.
26 |
27 | If we access the scale of tqc2, PPQ will return its master TQC's scale instead, and the same goes for the offset.
28 |
29 | That is, tqc1 and tqc2 are bound together by the statement "tqc1.set_master(master=tqc2)".
30 |
31 | ### Parameters:
32 |
33 | * elementwise_merge_method(Set[str]):
34 |
35 | Alignment method for elementwise ops.
36 |
37 | All elementwise ops are listed in ppq.core.common.py
38 |
39 | * concat_merge_method(bool)
40 |
41 | Alignment method for concat-like ops.
42 |
43 | All concat-like ops are listed in ppq.core.common.py
44 |
45 | * averagepool_method(bool)
46 |
47 | Alignment method for pooling-like ops.
48 |
49 | All pooling-like ops are listed in ppq.core.common.py
50 |
51 | * force_overlap(bool)
52 |
53 | TQC alignment might cause a serious cascade effect.
54 |
55 | For a subgraph like this:
56 |
57 |     Conv1 ---
58 |              + --- Add1
59 |     Conv2 ---
60 |              + --- Conv3
61 |
62 | If we demand Add1 to have the same input scales, this alignment will affect Conv3 also,
63 | because Conv2's output is fed to both Add1 and Conv3.
64 |
65 | If force_overlap = False, the PPQ alignment procedure will leave the output scale of
66 | Conv2 unchanged, and only align the input scale of Add1.
67 |
68 | If force_overlap = True, the input of Add1, Conv3 and the output of Conv2 will all
69 | be aligned to the same scale.
70 |
71 | ### Usage
72 | This pass is included in PPQ Quantization Setting, you can invoke this optimization by:
73 |
74 |     setting = QuantizationSettingFactory.default_setting()
75 |
76 |     setting.fusion = True
77 |     setting.fusion_setting.force_alignment_overlap = True
78 |
79 |     # calling ppq.api.quantize_onnx_model function with this setting.
80 |     ir = quantize_torch_model(
81 |         model=model, calib_dataloader=load_calibration_dataset(), setting=setting,
82 |         platform=TargetPlatform.PPL_CUDA_INT8, calib_steps=8, input_shape=INPUT_SHAPE,
83 |         collate_fn=collate_fn)
84 |
--------------------------------------------------------------------------------
/md_doc/Passes/QuantFusion.md:
--------------------------------------------------------------------------------
1 | ## PPQ Quantize Fusion Pass(通用量化图融合过程)
2 |
3 | Operation fusion (or kernel/layer fusion) is a key optimization in many state-of-the-art execution frameworks.
4 |
5 | Graph fusion can combine operations into a single op to obtain higher accuracy and performance;
6 | a pattern like Conv + Relu can be reduced to ConvRelu. This fusion will reduce memory accesses,
7 | and the quantization point after the conv can also be removed.
8 |
9 | Technically we could fuse those layers before quantization, but fused layers are not supported by the onnx standard.
10 | That is to say, ConvRelu is not a valid onnx operation, and no execution framework can parse it.
11 |
12 | Therefore, PPQ will simulate the graph fusion by adjusting quantization config: if PPQ finds there is a
13 | pattern like Conv + Relu, the output quantization of Conv will be forbidden, pretending that the Conv + Relu
14 | fusion has happened.
15 |
16 | This Pass is designed for two types of fusion:
17 |
18 | * activation fusion
19 |
20 | For activation fusion, PPQ will identify the pattern: Computing Op + Activation Op from your network. The output
21 | quantization of the computing op will be disabled, with its state being set to QuantizationState.OVERLAPPED (a standalone sketch of this idea is given right after this list).
22 |
23 | Activation fusion here supports only simple activation patterns;
24 | complex activation functions like mish and swish
25 | are represented as mish = tanh + mul + softplus, swish = sigmoid + mul in onnx,
26 | because onnx does not have an op definition for them.
27 | Identifying those complex patterns requires pattern matching, which is implemented in ppq.IR.search.py
28 |
29 | Complex quantization fusions must be invoked manually; PPQ implements mish & swish fusion functions in
30 |     ppq.quantization.optim.refine.MishFusionPass
31 |     ppq.quantization.optim.refine.SwishFusionPass
32 |
33 | * passive operation fusion
34 |
35 | For passive operation fusion, PPQ will make the input and the output variables share the same scale for passive operations.
36 | An operation is identified as a passive op only if its attribute "is_active_quant_op" = False; this
37 | attribute is initialized by the quantizer.
38 |
39 | If there is a passive operation having multiple inputs and outputs, the fusion procedure will make its
40 | FIRST input variable and ALL output variables share the same scale (the same scale as its first input).
41 | The quantization states of all output variables will be set to QuantizationState.OVERLAPPED.
42 |
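The standalone sketch below shows the activation-fusion idea on a toy graph representation; it deliberately does not use PPQ's real IR classes (see ppq.IR and ppq.quantization.optim.refine for the actual implementation) and only illustrates one reasonable form of the rule "if a computing op feeds exactly one activation op, drop the quantization point in between":

    # toy graph: op name -> (op type, list of consumer op names); this is NOT PPQ's real IR
    graph = {
        'Conv1': ('Conv', ['Relu1']),
        'Relu1': ('Relu', ['Conv2']),
        'Conv2': ('Conv', ['Relu2', 'Conv3']),   # two consumers -> no fusion
        'Relu2': ('Relu', []),
        'Conv3': ('Conv', []),
    }
    COMPUTING, ACTIVATIONS = {'Conv', 'Gemm', 'ConvTranspose'}, {'Relu', 'Clip'}

    disabled_output_quant = set()
    for name, (op_type, consumers) in graph.items():
        if op_type in COMPUTING and len(consumers) == 1:
            if graph[consumers[0]][0] in ACTIVATIONS:
                # pretend the fusion happened: the tensor between the two ops stays in float,
                # which is what PPQ models by marking the corresponding TQC as OVERLAPPED
                disabled_output_quant.add(name)

    print(disabled_output_quant)   # {'Conv1'}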
43 | ### Parameters:
44 |
45 | * activation_type(Set[str]):
46 |
47 | A collection that contains all activation types.
48 |
49 | The pattern will be recognized as [Computing Op -> Activation Op].
50 |
51 | By graph fusion, the output quantization of the Computing Op and
52 | the input quantization of the activation op will be disabled.
53 |
54 | * fuse_activation(bool)
55 |
56 | Whether to fuse the activation op with the computing op.
57 |
58 | * fuse_passive_op(bool)
59 |
60 | Whether to fuse passive ops so that the input variable and output variable will share the same scale.
61 |
62 | ### Usage
63 | This pass is included in PPQ Quantization Setting, you can invoke this optimization by:
64 |
65 |     setting = QuantizationSettingFactory.default_setting()
66 |
67 |     setting.fusion = True
68 |
69 |     # calling ppq.api.quantize_onnx_model function with this setting.
70 |     ir = quantize_torch_model(
71 |         model=model, calib_dataloader=load_calibration_dataset(), setting=setting,
72 |         platform=TargetPlatform.PPL_CUDA_INT8, calib_steps=8, input_shape=INPUT_SHAPE,
73 |         collate_fn=collate_fn)
74 |
--------------------------------------------------------------------------------
/md_doc/Passes/QuantSimplify.md:
--------------------------------------------------------------------------------
1 | ## PPQ Quantize Simplify Pass(通用量化精简过程)
2 |
3 | PPQ uses Tensor Quantization Configuration (a data structure defined in ppq.core) to
4 | control quantization. Each quantable op will have a list of TQCs as its quantization config,
5 | which contains the necessary quantization parameters (scale, offset) in order to quantize its input(s) and output(s).
6 |
7 | While TQC is a powerful tool for describing quantization, it introduces some undesirable features:
8 |
9 | For a subgraph like:
10 |
11 |     Relu1 - Relu2
12 |
13 | PPQ will create at least 4 TQCs here, namely the input TQCs of Relu1 and Relu2, and the output TQCs of Relu1 and Relu2.
14 | The problem here is that the output TQC of Relu1 and the input TQC of Relu2 are actually duplicated; the output variable
15 | should not be quantized twice.
16 |
17 | This Simplify Pass will detect all the duplicated TQCs in your network, disable them and create a link with their
18 | dominating TQCs. Disabled TQCs will have an inactive state (QuantizationState.OVERLAPPED), so the PPQ executor will
19 | simply ignore them when executing.
20 |
21 | A duplicated TQC is an input TQC (A) whose binding variable has already been quantized by another output TQC (B),
22 | where the input TQC (A) has the same bit-width as the output TQC (B).
23 |
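The reason a duplicated TQC can be disabled safely is that quantizing an already-quantized tensor again with the same parameters changes nothing; a standalone check (plain PyTorch, not PPQ code):

    import torch

    def fake_quant(x, scale):
        return torch.round(x / scale).clamp(-128, 127) * scale

    x = torch.randn(8)
    once  = fake_quant(x, 0.05)        # output TQC of Relu1
    twice = fake_quant(once, 0.05)     # input TQC of Relu2 with identical parameters
    print(torch.equal(once, twice))    # True: the second quantization point is redundant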
24 | ### Parameters:
25 |
26 | * No Parameter
27 |
28 | ### Usage
29 | This pass is included in PPQ Quantization Setting, you can invoke this optimization by:
30 |
31 |     setting = QuantizationSettingFactory.default_setting()
32 |
33 |     setting.fusion = True
34 |     setting.fusion_setting.remove_useless_quantization = True
35 |
36 |     # calling ppq.api.quantize_onnx_model function with this setting.
37 |     ir = quantize_torch_model(
38 |         model=model, calib_dataloader=load_calibration_dataset(), setting=setting,
39 |         platform=TargetPlatform.PPL_CUDA_INT8, calib_steps=8, input_shape=INPUT_SHAPE,
40 |         collate_fn=collate_fn)
41 |
--------------------------------------------------------------------------------
/md_doc/Passes/RuntimeCalibrationPass.md:
--------------------------------------------------------------------------------
1 | ## Runtime Calibration Pass(量化参数校准过程)
2 |
3 | For integer quantization, you need to calibrate or estimate the scale of all floating-point tensors in the model.
4 |
5 | Formula:
6 |
7 | Quant(Y, scale_Y) = Clip(Round(Y / scale_Y))
8 |
9 | Dequant(Y, scale_Y) = Y * scale_Y
10 |
11 | Only activations that have quantization state = INITIAL are going to be calibrated via this optimization pass.
12 | If the parameter "override" is set to True, activations with quantization state = ACTIVATED will also be re-calibrated.
13 |
14 | Runtime Calibration Pass will write estimated scales and offsets to tensor quantization configs, and set their state to ACTIVATED.
15 |
16 | Unlike constant tensors such as weights and biases, variable tensors such as model input, activations (outputs of intermediate layers) and model output cannot be calibrated unless we run a few inference cycles.
17 |
18 | As a result, PPQ Runtime Calibration Pass requires a representative dataset to calibrate them.
19 |
20 | This dataset is supposed to be a small subset (around ~100-500 samples) of the training or validation data.
21 |
22 | ### Parameters:
23 |
24 | * method(str):
25 |
26 | String representing the algorithm used to estimate scales and offsets for activations.
27 |
28 | Can be mse, kl, percentile or minmax; this parameter is case insensitive.
29 |
30 | You can register your own calibration method through functions in ppq.api.
31 |
32 | * override(bool)
33 |
34 | If this parameter is set to True, activations with quantization state = ACTIVATED will also be re-calibrated;
35 | the runtime calibration pass will overwrite their scales and offsets.
36 |
37 | This parameter was introduced in ppq 0.6.4.
38 |
39 | ### Observer Support Matrix:
40 | | observer | Symmetrical | Asymmetrical | Per-channel | Per-tensor | Cuda Acceleration |
41 | | --- | --- | --- | --- | --- | --- |
42 | | minmax | ✔ | ✔ | ✔ | ✔ | |
43 | | mse | ✔ | ✔ | | ✔ | ✔ |
44 | | percentile | ✔ | ✔ | ✔ | ✔ | ✔ |
45 | | kl | ✔ | | | ✔ | ✔ |
46 | | isotone | ✔ | ✔ | | ✔ | |
47 |
48 | If possible, using the Cuda kernel can speed up observers by 10~100x.
49 |
50 | ### Usage:
51 |
52 | Runtime Calibration Pass should be invoked before Passive Parameter Quantize Pass.
53 |
54 | This pass is included in PPQ Quantization Setting, you can invoke this optimization by:
55 |
56 |     setting = QuantizationSettingFactory.default_setting()
57 |
58 |     setting.quantize_activation = True
59 |
60 |     # calling ppq.api.quantize_onnx_model function with this setting.
61 |     ir = quantize_torch_model(
62 |         model=model, calib_dataloader=load_calibration_dataset(), setting=setting,
63 |         platform=TargetPlatform.PPL_CUDA_INT8, calib_steps=8, input_shape=INPUT_SHAPE,
64 |         collate_fn=collate_fn)
65 |
66 | You can manually create this optimization by:
67 |
68 |     from ppq import RuntimeCalibrationPass
69 |
70 |     optim = RuntimeCalibrationPass()
71 |
72 | ### Register Calibration Method:
73 |
74 | Use the api function register_calibration_observer to register a new observer algorithm to the PPQ system.
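Whatever the registration mechanics look like, the core of such an observer is small; a standalone sketch of the min-max flavour is given below (plain PyTorch, not the actual PPQ observer interface, whose base class and method names may differ):

    import torch

    class MinMaxObserverSketch:
        def __init__(self):
            self._min, self._max = float('inf'), float('-inf')

        def observe(self, x: torch.Tensor):
            # called once per calibration batch to collect statistics
            self._min = min(self._min, x.min().item())
            self._max = max(self._max, x.max().item())

        def compute_scale(self, quant_max: int = 127) -> float:
            # symmetric per-tensor scale derived from the observed range
            return max(abs(self._min), abs(self._max)) / quant_max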
75 | Once Algorithm is registered, Runtime Calibration Pass will automatically calling them by name. 76 | 77 | This feature requires PPQ > 0.6.5 78 | -------------------------------------------------------------------------------- /ppq/IR/__init__.py: -------------------------------------------------------------------------------- 1 | from .base.command import GraphCommand, GraphCommandType 2 | from .base.graph import (BaseGraph, GraphBuilder, GraphExporter, Operation, 3 | OperationExporter, Opset, Variable) 4 | from .base.opdef import OperationBase, OpSocket, VLink 5 | from .deploy import RunnableGraph 6 | from .morph import GraphFormatter, GraphMerger, GraphReplacer 7 | from .processer import DefaultGraphProcessor, GraphCommandProcessor 8 | from .quantize import (DeviceSwitchOP, QuantableGraph, QuantableOperation, 9 | QuantableVariable) 10 | from .search import (Path, GraphPattern, SearchableGraph, TraversalCommand, 11 | PatternMatchHelper) 12 | from .training import TrainableGraph -------------------------------------------------------------------------------- /ppq/IR/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenPPL/ppq/e39eecb9f7e5f017c28f180cb423f8a685c3db48/ppq/IR/base/__init__.py -------------------------------------------------------------------------------- /ppq/IR/training.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Union 2 | 3 | import torch 4 | 5 | from ppq.core import DataType 6 | 7 | from .base.graph import BaseGraph 8 | from .processer import GraphCommandProcessor 9 | 10 | 11 | class TrainableGraph(GraphCommandProcessor): 12 | """ Trainable Graph offers a bunch of functions that provide training interfaces. 
""" 13 | 14 | def __init__(self, graph_or_processor: Union[BaseGraph, Callable]) -> None: 15 | super().__init__(graph_or_processor) 16 | 17 | def parameters(self) -> List[torch.Tensor]: 18 | parameters = [] 19 | for var in self.graph.variables.values(): 20 | if var.is_parameter and DataType.to_torch(var.dtype) == torch.float: 21 | parameters.append(var.value) 22 | return parameters 23 | 24 | def zero_grad(self): 25 | for var in self.graph.variables.values(): 26 | if var.is_parameter and DataType.to_torch(var.dtype) == torch.float: 27 | if var.value._grad is not None: 28 | var.value._grad.zero_() 29 | 30 | def state_dict(self) -> dict: 31 | parameters = {} 32 | for var in self.graph.variables.values(): 33 | if var.is_parameter and DataType.to_torch(var.dtype) == torch.float: 34 | parameters[var.name] = var.value 35 | return parameters 36 | 37 | def _acceptable_command_types(self): return None 38 | def process(self): return None 39 | -------------------------------------------------------------------------------- /ppq/README.md: -------------------------------------------------------------------------------- 1 | ## Project hierarchy 代码结构 2 | 3 | * IR - PPQ 量化计算图定义,以及图上相关操作(算子变换, 算子融合等),量化计算图是基于 onnx 标准的 4 | * api - 用户接口,包含基本 api 函数 5 | * core - 核心数据结构定义、全局常量定义、编程语言接口等 6 | * executor - PPQ 训练与推理引擎,用于执行 PPQ IR 7 | * parser - 网络读取与导出模块 8 | * quantization - 量化逻辑 9 | * algorithm - 算法相关逻辑 10 | * analyse - 量化误差分析工具 11 | * measure - 损失函数集合 12 | * observer - 量化校准算法集合 13 | * optim - 量化优化过程集合 14 | * qfunction - PPQ 核心量化函数 15 | * quantizer - 量化器集合 16 | * samples - 示例文件 17 | * scheduler - 调度器 18 | * utils - 工具函数 19 | * csrc - C++ / Cuda 高性能算子库 20 | 21 | ## Reading Recommendations 推荐阅读 22 | * core.quant - 核心量化结构抽象 23 | * core.common - 全局常量定义 24 | * IR.search - 图模式匹配库 25 | * IR.quantize - 量化图定义 26 | * executor.torch - 量化推理引擎 27 | * quantization.optim - 量化优化过程 28 | * quantization.analyse - 量化误差分析 29 | * quantization.quantizer - 量化器 30 | * scheduler.perseus - 调度器 31 | * utils.round - 量化取整策略 32 | * csrc - 高性能算子库 33 | -------------------------------------------------------------------------------- /ppq/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | if os.path.dirname(os.path.realpath(__file__)) == os.path.join( 5 | os.path.realpath(os.getcwd()), "ppq" 6 | ): 7 | message = ( 8 | "You are importing ppq within its own root folder ({}). " 9 | ) 10 | warnings.warn(message.format(os.getcwd())) 11 | 12 | # This file defines export functions & class of PPQ. 
13 | from ppq.api.setting import (ActivationQuantizationSetting, DispatchingTable, 14 | EqualizationSetting, GraphFormatSetting, 15 | LSQSetting, ParameterQuantizationSetting, 16 | QuantizationFusionSetting, QuantizationSetting, 17 | QuantizationSettingFactory, TemplateSetting) 18 | from ppq.core import * 19 | from ppq.executor import (BaseGraphExecutor, TorchExecutor, 20 | TorchQuantizeDelegator) 21 | from ppq.IR import (BaseGraph, GraphBuilder, GraphCommand, GraphExporter, 22 | GraphFormatter, Operation, QuantableGraph, SearchableGraph, 23 | Variable, TrainableGraph) 24 | from ppq.IR.deploy import RunnableGraph 25 | from ppq.IR.quantize import QuantableOperation, QuantableVariable 26 | from ppq.IR.search import SearchableGraph 27 | from ppq.log import NaiveLogger 28 | from ppq.quantization.analyse import (graphwise_error_analyse, 29 | layerwise_error_analyse, 30 | parameter_analyse, statistical_analyse, 31 | variable_analyse) 32 | from ppq.quantization.measure import (torch_cosine_similarity, 33 | torch_cosine_similarity_as_loss, 34 | torch_KL_divergence, 35 | torch_mean_square_error, torch_snr_error) 36 | from ppq.quantization.optim import (BiasCorrectionPass, GRUSplitPass, 37 | HorizontalLayerSplitPass, 38 | LayerwiseEqualizationPass, 39 | MetaxGemmSplitPass, MishFusionPass, 40 | NxpInputRoundingRefinePass, 41 | NxpQuantizeFusionPass, 42 | NXPResizeModeChangePass, 43 | ParameterBakingPass, ParameterQuantizePass, 44 | PassiveParameterQuantizePass, 45 | QuantizationOptimizationPass, 46 | QuantizationOptimizationPipeline, 47 | QuantizeFusionPass, QuantizeSimplifyPass, 48 | RuntimeCalibrationPass, SwishFusionPass) 49 | from ppq.quantization.qfunction import (BaseQuantFunction, 50 | PPQDyamicLinearQuantFunction, 51 | PPQFloatingQuantFunction, 52 | PPQLinearQuant_toInt, 53 | PPQLinearQuantFunction, 54 | PPQuantFunction, PPQuantFunction_toInt) 55 | from ppq.quantization.quantizer import (BaseQuantizer, NXP_Quantizer, 56 | PPL_DSP_Quantizer, PPLCUDAQuantizer, 57 | TensorRTQuantizer) 58 | from ppq.scheduler import (AggresiveDispatcher, ConservativeDispatcher, 59 | GraphDispatcher, PPLNNDispatcher) 60 | from ppq.scheduler.perseus import Perseus 61 | from ppq.utils.round import (ppq_numerical_round, ppq_round_to_power_of_2, 62 | ppq_tensor_round) 63 | -------------------------------------------------------------------------------- /ppq/api/__init__.py: -------------------------------------------------------------------------------- 1 | from ppq.lib import (register_calibration_observer, register_network_exporter, 2 | register_network_parser, register_network_quantizer, 3 | register_operation_handler) 4 | 5 | from .fsys import (compare_cosine_similarity_between_results, create_dir, 6 | dump_internal_results, dump_to_file, 7 | load_calibration_dataset, load_from_file, 8 | split_result_to_directory) 9 | from .interface import (DISABLE_CUDA_KERNEL, ENABLE_CUDA_KERNEL, 10 | UnbelievableUserFriendlyQuantizationSetting, 11 | dispatch_graph, dump_torch_to_onnx, empty_ppq_cache, 12 | export, export_ppq_graph, format_graph, 13 | load_caffe_graph, load_graph, load_native_graph, 14 | load_onnx_graph, manop, quantize, quantize_caffe_model, 15 | quantize_native_model, quantize_onnx_model, 16 | quantize_torch_model, load_torch_model) 17 | from .setting import (ActivationQuantizationSetting, BiasCorrectionSetting, 18 | BlockwiseReconstructionSetting, ChannelSplitSetting, 19 | DispatchingTable, EqualizationSetting, 20 | GraphFormatSetting, LSQSetting, 21 | ParameterQuantizationSetting, QuantizationFusionSetting, 
22 | QuantizationSetting, QuantizationSettingFactory, 23 | SSDEqualizationSetting, TemplateSetting, 24 | WeightSplitSetting) 25 | -------------------------------------------------------------------------------- /ppq/core/config.py: -------------------------------------------------------------------------------- 1 | class PPQ_GLOBAL_CONFIGURATION: 2 | def __init__(self) -> None: 3 | # 是否启动 cuda kernel 加速计算 4 | self.USING_CUDA_KERNEL = False 5 | 6 | # PPQ 的名字 7 | self.NAME = 'PPL Quantization Tool' 8 | 9 | # PPQ 的版本号 10 | self.VERSION = '0.6.6' 11 | 12 | # 导出图时是否导出权重(仅影响 Native 格式导出) 13 | self.DUMP_VALUE_WHEN_EXPORT = True 14 | 15 | # 导出图时,是否导出调度信息 16 | self.EXPORT_PPQ_INTERNAL_INFO = False 17 | 18 | # 开启 PPQ 调试模式,将打印所有量化点插入信息 19 | self.PPQ_DEBUG = False 20 | 21 | PPQ_CONFIG = PPQ_GLOBAL_CONFIGURATION() 22 | -------------------------------------------------------------------------------- /ppq/core/defs.py: -------------------------------------------------------------------------------- 1 | """PPQ Core Decorator & MetaClass definitions PPQ 核心装饰器、元类型定义. 2 | 3 | You are not allowed to modify this 请勿修改此文件 4 | """ 5 | 6 | import gc 7 | from typing import Callable 8 | from torch.cuda import empty_cache 9 | 10 | from .config import PPQ_CONFIG 11 | 12 | 13 | class SingletonMeta(type): 14 | """The Singleton class can be implemented in different ways in Python. Some 15 | possible methods include: base class, decorator, metaclass. We will use the 16 | metaclass because it is best suited for this purpose. 17 | 18 | see also: https://refactoring.guru/design-patterns/singleton/python/example 19 | """ 20 | 21 | _instances = {} 22 | 23 | def __call__(cls, *args, **kwargs): 24 | """Possible changes to the value of the `__init__` argument do not 25 | affect the returned instance.""" 26 | if cls not in cls._instances: 27 | instance = super().__call__(*args, **kwargs) 28 | cls._instances[cls] = instance 29 | return cls._instances[cls] 30 | 31 | 32 | def ppq_legacy(func: str, version: str, adapt_to: str = None): 33 | """Mark an function as legacy function. 34 | 35 | Args: 36 | func (str): _description_ 37 | version (str): _description_ 38 | adapt_to (str, optional): _description_. Defaults to None. 39 | """ 40 | print(f'{func} has been obsoleted since PPQ {version}, use {adapt_to} instead.') 41 | 42 | 43 | def empty_ppq_cache(func: Callable): 44 | """Using empty_ppq_cache decorator to clear ppq memory cache, both gpu 45 | memory and cpu memory will be clear via this function. 46 | 47 | Function which get decorated by this will clear all ppq system cache BEFORE its running. 48 | Args: 49 | func (Callable): decorated function 50 | """ 51 | def _wrapper(*args, **kwargs): 52 | empty_cache() # torch.cuda.empty_cache might requires a sync of all cuda device. 53 | gc.collect() # empty memory. 54 | return func(*args, **kwargs) 55 | return _wrapper 56 | 57 | 58 | def ppq_quant_param_computing_function(func: Callable): 59 | """mark a function to be a scale-computing function. 60 | 61 | Args: 62 | func (Callable): decorated function 63 | """ 64 | def _wrapper(*args, **kwargs): 65 | return func(*args, **kwargs) 66 | return _wrapper 67 | 68 | 69 | def ppq_debug_function(func: Callable): 70 | """mark a function to be a debug function. 
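    The decorated function is expected to return a string (or None); when
    PPQ_CONFIG.PPQ_DEBUG is enabled the wrapper prints that string, otherwise the
    call is skipped entirely and None is returned.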
71 | 72 | Args: 73 | func (Callable): decorated function 74 | """ 75 | def _wrapper(*args, **kwargs): 76 | if PPQ_CONFIG.PPQ_DEBUG: 77 | debug_str = func(*args, **kwargs) 78 | if debug_str is None: return None 79 | assert isinstance(debug_str, str), ( 80 | 'ppq_debug_function should only return string instance, ' 81 | f'while {str(func)} returns {type(debug_str)}') 82 | print(debug_str, end='') 83 | else: return None 84 | 85 | return _wrapper 86 | 87 | 88 | def ppq_file_io(func: Callable): 89 | """mark a function to be a ppq file io function. 90 | 91 | function must have return a file handle. 92 | Args: 93 | func (Callable): decorated function 94 | """ 95 | def _wrapper(*args, **kwargs): 96 | return func(*args, **kwargs) 97 | return _wrapper 98 | 99 | 100 | def ppq_warning(info: str): 101 | print(f'\033[31m[Warning] {info}\033[0m') 102 | 103 | 104 | def ppq_info(info: str): 105 | print(f'\033[33m[Info] {info}\033[0m') 106 | -------------------------------------------------------------------------------- /ppq/csrc/build/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !readme.md 4 | -------------------------------------------------------------------------------- /ppq/csrc/build/readme.md: -------------------------------------------------------------------------------- 1 | # this dir is for jit compilation, please do not delete it 2 | -------------------------------------------------------------------------------- /ppq/csrc/cpu/common.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenPPL/ppq/e39eecb9f7e5f017c28f180cb423f8a685c3db48/ppq/csrc/cpu/common.h -------------------------------------------------------------------------------- /ppq/csrc/cpu/hist_mse.cc: -------------------------------------------------------------------------------- 1 | # include "hist_mse.h" 2 | 3 | float compute_mse_loss( 4 | const vector &hist, 5 | const int start, 6 | const int step, 7 | const int end){ 8 | int64_t num_of_elements = 0; float loss = 0.0; 9 | for(auto v: hist) num_of_elements += v; 10 | for(int idx = 0; idx < hist.size(); idx++){ 11 | float error = 0.0f; 12 | int64_t bin = hist[idx]; 13 | if(idx < start) error = start - idx - 1 + 0.5; 14 | else if(idx > end) error =idx - end + 0.5; 15 | else{ 16 | int64_t l_idx = (idx - start) % step; 17 | int64_t r_idx = step - l_idx - 1; 18 | if(l_idx == r_idx) error = l_idx + 0.25; 19 | else{ 20 | float l_err = (l_idx + 0.5); 21 | float r_err = (r_idx + 0.5); 22 | error = l_err < r_err ? 
l_err: r_err; 23 | } 24 | } 25 | loss += (bin * error * error) / num_of_elements; 26 | } 27 | return loss; 28 | } -------------------------------------------------------------------------------- /ppq/csrc/cpu/hist_mse.h: -------------------------------------------------------------------------------- 1 | # include 2 | # include 3 | 4 | using std::vector; 5 | float compute_mse_loss( 6 | const vector &hist, 7 | const int start, 8 | const int step, 9 | const int end); -------------------------------------------------------------------------------- /ppq/csrc/cuda/common.h: -------------------------------------------------------------------------------- 1 | # include 2 | # include 3 | # include 4 | 5 | using std::vector; 6 | using int64 = long long; 7 | -------------------------------------------------------------------------------- /ppq/csrc/cuda/floating.h: -------------------------------------------------------------------------------- 1 | # include "common.cuh" 2 | 3 | Tensor QuantizeTensor_FT( 4 | const Tensor &value, const Tensor &scale, const Tensor &offset, 5 | const int exponent, const int mantissa, 6 | const float clip_min, const float clip_max, const Rounding rounding); 7 | 8 | Tensor QuantizeTensor_FC( 9 | const Tensor &value, const Tensor &scale, const Tensor &offset, 10 | const int exponent, const int mantissa, 11 | const float clip_min, const float clip_max, const int channel_axis, 12 | const Rounding rounding); 13 | 14 | std::vector QuantizeTensor_FT_B( 15 | const Tensor &value, const Tensor &scales, 16 | const Tensor &offsets, const Tensor &grad_y, 17 | const int exponent, const int mantissa, 18 | const float clip_min, const float clip_max, 19 | const Rounding rounding); 20 | 21 | std::vector QuantizeTensor_FC_B( 22 | const Tensor &value, const Tensor &scales, 23 | const Tensor &offsets, const Tensor &grad_y, 24 | const int exponent, const int mantissa, 25 | const float clip_min, const float clip_max, 26 | const Rounding rounding, const int channel_axis); 27 | -------------------------------------------------------------------------------- /ppq/csrc/cuda/isotone.cc: -------------------------------------------------------------------------------- 1 | # include "isotone.h" 2 | 3 | namespace PPQ_Crt{ 4 | 5 | float SolveIsotoneScale( 6 | const float *first_largest_arr, 7 | const float *second_largest_arr, 8 | const int64 length, const int quant_max){ 9 | /** 10 | * Solving isotonic quantization scale 11 | * Algorithm Complexity is 0(n^2), where n denotes the num of batches 12 | * This can be optimized further to reach O(n log n) complexity. 
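     * For every sample, any scale in [second / (quant_max - 1), 2 * (first - second)]
     * keeps the top-1 prediction unchanged after symmetric quantization: the lower bound
     * keeps the runner-up below the clipping range, the upper bound keeps it distinguishable
     * from the largest value after rounding. The loop below evaluates every interval
     * endpoint as a candidate and keeps the one satisfying the most per-sample bounds.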
13 | */ 14 | float *candidates = new float[length * 2]; 15 | float *s_maxs = new float[length]; 16 | float *s_mins = new float[length]; 17 | 18 | for(int64 i = 0; i < length; i++){ 19 | auto f = first_largest_arr[i]; 20 | auto s = second_largest_arr[i]; 21 | float s_min = s / (quant_max - 1); 22 | float s_max = 2 * (f - s); 23 | 24 | s_maxs[i] = s_max; 25 | s_mins[i] = s_min; 26 | 27 | candidates[2 * i] = s_min; 28 | candidates[2 * i + 1] = s_max; 29 | } 30 | 31 | int64 best_satisified = 0; 32 | float best_scale = 1.0f; 33 | 34 | for(int64 i = 0; i < length * 2; i++){ 35 | auto c = candidates[i]; int64 satisified = 0; 36 | for(int64 j = 0; j < length; j++){ 37 | float s_max = s_maxs[j]; 38 | float s_min = s_mins[j]; 39 | satisified += (c <= s_max) + (c >= s_min); 40 | } 41 | if (satisified > best_satisified){ 42 | best_satisified = satisified; 43 | best_scale = c; 44 | } 45 | } 46 | 47 | delete [] candidates; 48 | delete [] s_maxs; 49 | delete [] s_mins; 50 | return best_scale; 51 | } 52 | 53 | } // end of namespace -------------------------------------------------------------------------------- /ppq/csrc/cuda/isotone.h: -------------------------------------------------------------------------------- 1 | # include "common.h" 2 | 3 | float SolveIsotoneScale( 4 | const float *first_largest_arr, 5 | const float *second_largest_arr, 6 | const int64 length, const int quant_max); -------------------------------------------------------------------------------- /ppq/csrc/cuda/linear.h: -------------------------------------------------------------------------------- 1 | # include "common.cuh" 2 | 3 | Tensor QuantizeTensor_LT( 4 | const Tensor &value, const Tensor &scale, const Tensor &offset, 5 | const int clip_min, const int clip_max, const Rounding rounding); 6 | 7 | Tensor QuantizeTensor_LC( 8 | const Tensor &value, const Tensor &scale, const Tensor &offset, 9 | const int clip_min, const int clip_max, const int channel_axis, 10 | const Rounding rounding); 11 | 12 | std::vector QuantizeTensor_LT_B( 13 | const Tensor &value, const Tensor &scale, 14 | const Tensor &offset, const Tensor &grad_y, 15 | const int clip_min, const int clip_max, 16 | const Rounding rounding); 17 | 18 | std::vector QuantizeTensor_LC_B( 19 | const Tensor &value, const Tensor &scale, 20 | const Tensor &offset, const Tensor &grad_y, 21 | const int clip_min, const int clip_max, 22 | const Rounding rounding, const int channel_axis); 23 | -------------------------------------------------------------------------------- /ppq/csrc/cuda/sort.h: -------------------------------------------------------------------------------- 1 | # include "common.cuh" 2 | 3 | Tensor Quantile_T(const Tensor &source, const float q); 4 | 5 | void Histogram_T( 6 | const Tensor &value, 7 | const float hist_scale, 8 | const bool clip_outliers, 9 | Tensor &hist); 10 | 11 | void Histogram_C( 12 | const Tensor &value, 13 | const int channel_axis, 14 | const float hist_scale, 15 | const bool clip_outliers, 16 | Tensor &hist); 17 | 18 | void Histogram_Asymmetric_T( 19 | const float min, 20 | const float max, 21 | const Tensor &value, 22 | const bool clip_outliers, 23 | Tensor &hist); 24 | 25 | Tensor Isotone_T(const Tensor &source); -------------------------------------------------------------------------------- /ppq/csrc/cuda/test.h: -------------------------------------------------------------------------------- 1 | # include "common.cuh" 2 | 3 | void dummy_pooling_v1( 4 | const Tensor source, 5 | Tensor dest, 6 | const float in_scale, 7 | const float 
out_scale 8 | ); 9 | 10 | void dummy_pooling_v2( 11 | const Tensor source, 12 | Tensor dest, 13 | const float in_scale, 14 | const float out_scale 15 | ); 16 | 17 | void dummy_pooling_v3( 18 | const Tensor source, 19 | Tensor dest, 20 | const float in_scale, 21 | const float out_scale 22 | ); 23 | -------------------------------------------------------------------------------- /ppq/csrc/cuda/train.h: -------------------------------------------------------------------------------- 1 | # include "common.cuh" 2 | 3 | Tensor TensorClip_T( 4 | const Tensor &value, const Tensor &reference, const Tensor &limit); 5 | 6 | Tensor TensorClip_C( 7 | const Tensor &value, const Tensor &reference, const Tensor &limit, 8 | const int channel_axis); 9 | 10 | Tensor RoundingLoss_LT( 11 | const Tensor &value, const Tensor &scale, const Tensor &offset, 12 | const int clip_min, const int clip_max, 13 | Rounding rounding); 14 | 15 | Tensor RoundingLoss_LT_B( 16 | const Tensor &value, const Tensor &dy, 17 | const Tensor &scale, const Tensor &offset, 18 | const int clip_min, const int clip_max, 19 | Rounding rounding); 20 | 21 | Tensor RoundingLoss_LC( 22 | const Tensor &value, const Tensor &scale, const Tensor &offset, 23 | const int clip_min, const int clip_max, 24 | const int channel_axis, Rounding rounding); 25 | 26 | Tensor RoundingLoss_LC_B( 27 | const Tensor &value, const Tensor &dy, 28 | const Tensor &scale, const Tensor &offset, 29 | const int clip_min, const int clip_max, 30 | const int channel_axis, Rounding rounding); 31 | -------------------------------------------------------------------------------- /ppq/csrc/export.cc: -------------------------------------------------------------------------------- 1 | # include "cuda/linear.h" 2 | # include "cuda/sort.h" 3 | # include "cuda/train.h" 4 | # include "cuda/train.h" 5 | # include "cuda/floating.h" 6 | # include "cpu/hist_mse.h" 7 | 8 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 9 | m.def("Quantile_T", Quantile_T, "Quantile_T"); 10 | m.def("Histogram_T", Histogram_T, "Histogram_T"); 11 | m.def("Histogram_Asymmetric_T", Histogram_Asymmetric_T, "Histogram_Asymmetric_T"); 12 | m.def("Histogram_C", Histogram_C, "Histogram_C"); 13 | 14 | m.def("QuantizeTensor_LT", QuantizeTensor_LT, "QuantizeTensor_LT"); 15 | m.def("QuantizeTensor_LC", QuantizeTensor_LC, "QuantizeTensor_LC"); 16 | m.def("QuantizeTensor_LT_B", QuantizeTensor_LT_B, "QuantizeTensor_LT_B"); 17 | m.def("QuantizeTensor_LC_B", QuantizeTensor_LC_B, "QuantizeTensor_LC_B"); 18 | 19 | m.def("QuantizeTensor_FT", QuantizeTensor_FT, "QuantizeTensor_FT"); 20 | m.def("QuantizeTensor_FC", QuantizeTensor_FC, "QuantizeTensor_FC"); 21 | m.def("QuantizeTensor_FT_B", QuantizeTensor_FT_B, "QuantizeTensor_FT_B"); 22 | m.def("QuantizeTensor_FC_B", QuantizeTensor_FC_B, "QuantizeTensor_FC_B"); 23 | 24 | m.def("TensorClip_T", TensorClip_T, "TensorClip_T"); 25 | m.def("TensorClip_C", TensorClip_C, "TensorClip_C"); 26 | 27 | m.def("RoundingLoss_LT", RoundingLoss_LT, "RoundingLoss_LT"); 28 | m.def("RoundingLoss_LC", RoundingLoss_LC, "RoundingLoss_LC"); 29 | m.def("RoundingLoss_LT_B", RoundingLoss_LT_B, "RoundingLoss_LT_B"); 30 | m.def("RoundingLoss_LC_B", RoundingLoss_LC_B, "RoundingLoss_LC_B"); 31 | 32 | m.def("Isotone_T", Isotone_T, "Isotone_T"); 33 | m.def("compute_mse_loss", compute_mse_loss, "compute_mse_loss"); 34 | } 35 | -------------------------------------------------------------------------------- /ppq/executor/__init__.py: -------------------------------------------------------------------------------- 1 | 
from .base import (BaseGraphExecutor, QuantOPRuntimeHook, RuntimeHook, 2 | register_operation_handler) 3 | from .torch import TorchExecutor, TorchQuantizeDelegator 4 | -------------------------------------------------------------------------------- /ppq/executor/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .torch import (DEFAULT_BACKEND_TABLE, NXP_BACKEND_TABLE, EXTENSION_BACKEND_TABLE, 2 | PPL_DSP_BACKEND_TABLE, PPL_GPU_BACKEND_TABLE, ONNX_BACKEND_TABLE, 3 | ACADEMIC_BACKEND_TABLE, TorchBackendContext) 4 | -------------------------------------------------------------------------------- /ppq/executor/op/fp32/fp32_backend.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenPPL/ppq/e39eecb9f7e5f017c28f180cb423f8a685c3db48/ppq/executor/op/fp32/fp32_backend.py -------------------------------------------------------------------------------- /ppq/executor/op/ppl_dsp/ppl_dsp_backend.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenPPL/ppq/e39eecb9f7e5f017c28f180cb423f8a685c3db48/ppq/executor/op/ppl_dsp/ppl_dsp_backend.py -------------------------------------------------------------------------------- /ppq/executor/op/ppl_trt/ppl_gpu_backend.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenPPL/ppq/e39eecb9f7e5f017c28f180cb423f8a685c3db48/ppq/executor/op/ppl_trt/ppl_gpu_backend.py -------------------------------------------------------------------------------- /ppq/executor/op/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from .default import (DEFAULT_BACKEND_TABLE, AveragePool_forward, AdaptiveAvgPool2d_forward, 2 | BatchNormalization_forward, Cast_forward, Clip_forward, 3 | Concat_forward, Constant_forward, 4 | ConstantOfShape_forward, Conv_forward, Eltwise_forward, 5 | Equal_forward, Expand_forward, Flatten_forward, 6 | Gather_forward, GatherND_forward, Gemm_forward, 7 | Greater_forward, MaxPool2d_forward, _NMS_forward, 8 | NonZero_forward, Range_forward, ReduceL2_forward, 9 | ReduceMax_forward, Reshape_forward, Resize_forward, 10 | ScatterElements_forward, ScatterND_forward, 11 | Shape_forward, Slice_forward, Softmax_forward, 12 | Squeeze_forward, Tile_forward, TopK_forward, 13 | Transpose_forward, UnaryEltwise_forward, 14 | Unsqueeze_forward, Where_forward) 15 | from .dsp import PPL_DSP_BACKEND_TABLE 16 | from .cuda import PPL_GPU_BACKEND_TABLE 17 | from .nxp import NXP_BACKEND_TABLE 18 | from .extension import EXTENSION_BACKEND_TABLE 19 | from .base import TorchBackendContext 20 | from .onnx import ONNX_BACKEND_TABLE 21 | from .academic import ACADEMIC_BACKEND_TABLE 22 | -------------------------------------------------------------------------------- /ppq/executor/op/torch/academic.py: -------------------------------------------------------------------------------- 1 | from .default import DEFAULT_BACKEND_TABLE 2 | 3 | ACADEMIC_BACKEND_TABLE = DEFAULT_BACKEND_TABLE.copy() 4 | 5 | # When you trying to implement a custimized function for ppl_dsp platform 6 | # Be aware that you can just overwrite part of DEFAULT_DISPATCHING_TABLE 7 | # rather than rewrite all dispatching table. 
8 | # here an example was given: Sample_Forward 9 | def Sample_Forward(): 10 | return None 11 | 12 | ACADEMIC_BACKEND_TABLE['Sample_Function'] = Sample_Forward 13 | -------------------------------------------------------------------------------- /ppq/executor/op/torch/base.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | from ppq.core import TargetPlatform 4 | from ppq.IR import Operation 5 | from ppq.IR.quantize import QuantableOperation 6 | 7 | import torch 8 | 9 | 10 | class TorchBackendContext: 11 | def __init__(self, executing_device: str) -> None: 12 | self.executing_device = executing_device 13 | 14 | def ASSERT_NUM_OF_INPUT(op: Operation, values: List[torch.Tensor], 15 | min_num_of_input: int = -1, max_num_of_input: int = 99): 16 | if min_num_of_input == max_num_of_input: 17 | if len(values) != min_num_of_input: 18 | raise ValueError(f'Can not feed value to operation {op.name}, ' 19 | f'expects exact {min_num_of_input} inputs, however {len(values)} was given') 20 | elif len(values) > max_num_of_input: 21 | raise ValueError(f'Too many input value for {op.name}, ' 22 | f'expects {max_num_of_input} inputs at most, however {len(values)} was given') 23 | elif len(values) < min_num_of_input: 24 | raise ValueError(f'Too few input value for {op.name}, ' 25 | f'expects {min_num_of_input} inputs at least, however {len(values)} was given') 26 | 27 | def GET_ATTRIBUTE_FROM_OPERATION(op: Operation, attribute: str, compulsive: bool = False, default: Any = None): 28 | """Try to get an attribute from operation. If an attribute is compulsive, 29 | then operation must give a value of it, otherwise an error will be thrown. 30 | If an attribute is not compulsive, a default value will be given if 31 | operation.attributes do not holds a value of requesting attribute. 32 | 33 | Args: 34 | op (Operation): Operation instance. 35 | attribute (str): Attribute name. 36 | compulsive (bool): Whether is a compulsive attribute. 37 | default (Any, optional): [description]. default value of attribute. 38 | """ 39 | if attribute in op.attributes: 40 | return op.attributes[attribute] 41 | else: 42 | if compulsive: 43 | raise KeyError( 44 | f'Operation {op.name} is supposed to have a value of attribute {attribute}. ', 45 | 'However this value is missing from currecnt operation.') 46 | else: 47 | return default 48 | 49 | def GET_VALUE_FROM_INPUTS(values: list, idx: int) -> torch.Tensor: 50 | assert isinstance(idx, int) 51 | assert idx > 0 52 | if len(values) > idx: return values[idx] 53 | else: return None 54 | 55 | def ASSERT_IS_QUANT_OP(op: QuantableOperation): 56 | if not isinstance(op, QuantableOperation): 57 | raise TypeError(f'Given Operation is expected as a QuantableOperation, however {type(op)} was given.') 58 | 59 | def FORCE_CONVERT_DEVICE(value: torch.Tensor, device: str) -> torch.Tensor: 60 | # SET LOG HERE FOR DEBUG. 
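# (copy=True guarantees a fresh tensor even when value already sits on the target device,
#  so callers may modify the returned tensor without touching the cached source value.)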
61 | # value = value.clone() 62 | return value.to(device=device, copy=True) 63 | 64 | def VALUE_TO_EXECUTING_DEVICE(op: Operation, ctx: TorchBackendContext, values: List[torch.Tensor]) -> List[torch.Tensor]: 65 | if ctx is None: device = values[0].device 66 | else: device = ctx.executing_device 67 | for idx, (plat, value) in enumerate(zip(op.socket.in_plat, values)): 68 | if value is None: continue 69 | if plat == TargetPlatform.SOI or op.platform == TargetPlatform.SOI: 70 | values[idx] = value.cpu() 71 | else: values[idx] = value.to(device) 72 | return values -------------------------------------------------------------------------------- /ppq/executor/op/torch/dsp.py: -------------------------------------------------------------------------------- 1 | from .default import DEFAULT_BACKEND_TABLE 2 | 3 | PPL_DSP_BACKEND_TABLE = DEFAULT_BACKEND_TABLE.copy() 4 | 5 | # When you are trying to implement a customized function for the ppl_dsp platform, 6 | # be aware that you can overwrite just part of DEFAULT_BACKEND_TABLE 7 | # rather than rewriting the whole dispatching table. 8 | # An example is given here: Sample_Forward 9 | def Sample_Forward(): 10 | return None 11 | 12 | PPL_DSP_BACKEND_TABLE['Sample_Function'] = Sample_Forward 13 | -------------------------------------------------------------------------------- /ppq/executor/op/torch/extension.py: -------------------------------------------------------------------------------- 1 | from .default import DEFAULT_BACKEND_TABLE 2 | 3 | EXTENSION_BACKEND_TABLE = DEFAULT_BACKEND_TABLE.copy() 4 | 5 | # When you are trying to implement a customized function for the extension platform, 6 | # be aware that you can overwrite just part of DEFAULT_BACKEND_TABLE 7 | # rather than rewriting the whole dispatching table. 8 | # An example is given here: Sample_Forward 9 | def Sample_Forward(): 10 | return None 11 | 12 | EXTENSION_BACKEND_TABLE['Sample_Function'] = Sample_Forward 13 | -------------------------------------------------------------------------------- /ppq/executor/op/torch/nxp.py: -------------------------------------------------------------------------------- 1 | from .default import DEFAULT_BACKEND_TABLE 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | NXP_BACKEND_TABLE = DEFAULT_BACKEND_TABLE.copy() 7 | 8 | 9 | def Resize_forward(op, input_value, device=None): 10 | """The NXP platform has a customized resize implementation, which 11 | gives different results from the default Resize implementation. To correctly simulate hardware 12 | behaviour and produce the same result as NXP, it is necessary to force resize 13 | to run in nearest mode. Any other resize mode will be ignored by this 14 | function. 15 | 16 | Args: 17 | op ([type]): [description] 18 | input_value ([type]): [description] 19 | device ([type], optional): [description]. Defaults to None. 20 | 21 | Returns: 22 | [type]: [description] 23 | """ 24 | input_data = input_value[0] 25 | # Not used roi 26 | # roi = input_value[1] if len(input_value) > 1 else None 27 | scales = input_value[2] if len(input_value) > 2 else None 28 | sizes = input_value[-1].tolist() if len(input_value) == 4 else None 29 | mode = 'nearest' 30 | 31 | # If 'size' is specified, then set scales to empty data (zero shape) in this operator's input list.
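# (ONNX Resize packs its inputs as (X, roi, scales, sizes); only one of scales / sizes should
#  actually carry data, which is why the branches below keep one of them and drop the other.)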
32 | if sizes is None or len(sizes) == 0: 33 | sizes = None 34 | if scales.numel() == 1: 35 | scales = scales.item() 36 | else: 37 | assert scales.numel() % 2 == 0 38 | scales = scales[-2].cpu().numpy().tolist() 39 | else: 40 | # the sizes in onnx is 4-D while in pytorch is 2-D 41 | # check the dim.0 & dim.1 is equal, then remain dim.2 and dim.3 42 | scales = None 43 | assert (sizes[:2] == list(input_data.shape[:2])) 44 | sizes = sizes[2:] 45 | 46 | trans_mode = op.attributes.get('coordinate_transformation_mode', 'half_pixel') 47 | if trans_mode == 'align_corners': 48 | output = F.interpolate(input_data, sizes, scales, mode, align_corners=True) 49 | else: 50 | output = F.interpolate(input_data, sizes, scales, mode) 51 | return output 52 | 53 | 54 | # When you trying to implement a custimized function for ppl_dsp platform 55 | # Be aware that you can just overwrite part of DEFAULT_DISPATCHING_TABLE 56 | # rather than rewrite all dispatching table. 57 | # here an example was given: Sample_Forward 58 | def Sample_Forward(): 59 | return None 60 | 61 | 62 | NXP_BACKEND_TABLE['Sample_Function'] = Sample_Forward 63 | # NXP_DISPATCHING_TABLE['Resize'] = Resize_forward 64 | -------------------------------------------------------------------------------- /ppq/executor/op/torch/onnx.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from ppq.IR import Operation 4 | 5 | from .default import DEFAULT_BACKEND_TABLE 6 | from .base import * 7 | 8 | import torch 9 | 10 | ONNX_BACKEND_TABLE = DEFAULT_BACKEND_TABLE.copy() 11 | 12 | # When you trying to implement a custimized function for ppl_gpu platform 13 | # Be aware that you can just overwrite part of DEFAULT_DISPATCHING_TABLE 14 | # rather than rewrite all dispatching table. 15 | # here an example was given: Sample_Forward 16 | def Sample_Forward(): 17 | return None 18 | 19 | ONNX_BACKEND_TABLE['Sample_Forward'] = Sample_Forward 20 | -------------------------------------------------------------------------------- /ppq/lib/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | ## PPQ Foundation Library - PFL 3 | 4 | PPQ 基础类库 5 | 6 | PFL is a collection of basic classes and functions that provides fundamental functionalities. 7 | 8 | - Parser: Get a network parser. 9 | 10 | - Exporter: According to given platform, get a network exporter. 11 | 12 | - OperationForwardFunction: According to given platform and optype, get a forward function. 13 | 14 | - Dispatcher: Get a network dispatcher. 15 | 16 | - FloatingQuantizationConfig: Get a TensorQuantizationConfig for FP8 Quantization. 17 | 18 | - LinearQuantizationConfig: Get a TensorQuantizationConfig for INT8 Quantization. 19 | 20 | - QuantStub: Get a QuantStub class instance. 21 | 22 | - Quantizer: Get a Quantizer corresponding to given platform. 23 | 24 | - Observer: Get a Tensor Observer, which is bound to given TensorQuantizationConfig. 25 | 26 | - Pipeline: Build Optimization Pipeline. 27 | 28 | - QuantFunction: Get PPQ Default Quantize Function. 29 | 30 | PFL also provides a set of functions to register Quantizer, Parser, Exporter to PPQ. 
31 | 32 | - register_network_quantizer 33 | 34 | - register_network_parser 35 | 36 | - register_network_exporter 37 | 38 | - register_operation_handler 39 | 40 | - register_calibration_observer 41 | 42 | """ 43 | 44 | from .extension import (register_calibration_observer, 45 | register_network_exporter, register_network_parser, 46 | register_network_quantizer, register_operation_handler) 47 | from .quant import (Dispatcher, Exporter, FloatingQuantizationConfig, 48 | LinearQuantizationConfig, Observer, 49 | OperationForwardFunction, Parser, Pipeline, QuantFunction, 50 | Quantizer, TensorQuant, ParameterQuant) 51 | -------------------------------------------------------------------------------- /ppq/log/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import NaiveLogger 2 | -------------------------------------------------------------------------------- /ppq/parser/__init__.py: -------------------------------------------------------------------------------- 1 | from ppq.core import NetworkFramework, TargetPlatform 2 | from ppq.IR import BaseGraph, GraphBuilder, GraphExporter 3 | 4 | from .caffe_exporter import (CaffeExporter, PPLDSPCaffeExporter, 5 | PPLDSPTICaffeExporter, SNPECaffeExporter) 6 | from .caffe_parser import CaffeParser 7 | from .extension import ExtensionExporter 8 | from .native import NativeExporter, NativeImporter 9 | from .nxp_exporter import NxpExporter 10 | from .onnx_exporter import OnnxExporter 11 | from .onnx_parser import OnnxParser 12 | from .onnxruntime_exporter import ONNXRUNTIMExporter 13 | from .ppl import PPLBackendExporter 14 | from .tensorRT import TensorRTExporter_QDQ, TensorRTExporter_JSON 15 | from .qnn_exporter import QNNDSPExporter 16 | from .ncnn_exporter import NCNNExporter 17 | from .tengine_exporter import TengineExporter 18 | from .ascend_export import AscendExporter 19 | from .mnn_exporter import MNNExporter 20 | from .openvino_exporter import OpenvinoExporter -------------------------------------------------------------------------------- /ppq/parser/caffe/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenPPL/ppq/e39eecb9f7e5f017c28f180cb423f8a685c3db48/ppq/parser/caffe/__init__.py -------------------------------------------------------------------------------- /ppq/parser/caffe_parser.py: -------------------------------------------------------------------------------- 1 | from google.protobuf import text_format 2 | from ppq.core import NetworkFramework, is_file_exist 3 | from ppq.IR import BaseGraph, GraphBuilder 4 | from ppq.log import NaiveLogger 5 | 6 | from .caffe import ppl_caffe_pb2 7 | from .caffe.caffe_graph_optim import de_inplace, merge_batchnorm_scale 8 | from .caffe.caffe_import_utils import caffe_import_map, get_input_shape 9 | 10 | logger = NaiveLogger.get_logger('PPQ') 11 | 12 | class CaffeParser(GraphBuilder): 13 | def load_graph_and_format(self, prototxt_path: str, caffemodel_path: str) -> ppl_caffe_pb2.NetParameter: 14 | if not is_file_exist(prototxt_path): 15 | raise FileNotFoundError(f'file {prototxt_path} not exist, please check your file path') 16 | elif not is_file_exist(caffemodel_path): 17 | raise FileNotFoundError(f'file {caffemodel_path} not existm please check your file path') 18 | network = ppl_caffe_pb2.NetParameter() 19 | with open(prototxt_path) as f: 20 | text_format.Merge(f.read(), network) 21 | weight = ppl_caffe_pb2.NetParameter() 22 | with open(caffemodel_path, 'rb') as 
f: 23 | weight.ParseFromString(f.read()) 24 | 25 | network = de_inplace(network) 26 | 27 | for i in network.layer: 28 | for j in weight.layer: 29 | if i.name == j.name: 30 | i.ClearField('blobs') 31 | i.blobs.MergeFrom(j.blobs) 32 | break 33 | 34 | network = merge_batchnorm_scale(network) 35 | return network 36 | 37 | def build(self, prototxt_path: str, caffemodel_path: str) -> BaseGraph: 38 | network = self.load_graph_and_format(prototxt_path, caffemodel_path) 39 | graph = BaseGraph(name=network.name, built_from=NetworkFramework.CAFFE) 40 | input_shape = get_input_shape(network) 41 | input_names = list(input_shape.keys()) 42 | 43 | activation_shape = input_shape 44 | top_name_set = set() 45 | for layer in network.layer: 46 | if layer.type not in caffe_import_map: 47 | logger.error(f'{layer.type} Caffe OP is not supported in PPQ import parser yet') 48 | raise NotImplementedError(f'{layer.type} Caffe OP is not supported in PPQ import parser yet') 49 | input_shape = [activation_shape[k] for k in layer.bottom] 50 | caffe_layer = caffe_import_map[layer.type](graph, layer, input_shape) 51 | graph = caffe_layer.trans() 52 | activation_shape.update([(k, v) for k, v in zip(layer.top, caffe_layer.out_shape)]) 53 | 54 | # statistic top_name and get final out var name 55 | for name in layer.bottom: 56 | if name in top_name_set: 57 | top_name_set.remove(name) 58 | for name in layer.top: 59 | top_name_set.add(name) 60 | 61 | # add input and output for graph 62 | try: 63 | for var_name in input_names: 64 | if var_name not in graph.variables: continue 65 | graph.inputs[var_name] = graph.variables[var_name] 66 | for var_name in top_name_set: 67 | graph.outputs[var_name] = graph.variables[var_name] 68 | except KeyError as e: 69 | raise KeyError( 70 | 'seems you got an input/output variable that is not linked to any operation.') 71 | return graph 72 | -------------------------------------------------------------------------------- /ppq/parser/extension.py: -------------------------------------------------------------------------------- 1 | from ppq.core import QuantizationStates 2 | from ppq.IR import BaseGraph, GraphExporter 3 | from ppq.IR.quantize import QuantableOperation 4 | 5 | 6 | class ExtensionExporter(GraphExporter): 7 | """ExtensionExporter is an empty exporter for you to implement customized 8 | logic. rewrite function export in order to dump ppq graph to disk. 9 | 10 | use export_ppq_graph(..., platform=TargetPlatform.EXTENSION) to invoke this class. 
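    A minimal sketch of how this exporter is reached (file names are placeholders, and
    config_save_to is assumed here to be forwarded to the config_path argument below):

        from ppq.api import export_ppq_graph
        from ppq.core import TargetPlatform

        export_ppq_graph(graph=quantized_graph, platform=TargetPlatform.EXTENSION,
                         graph_save_to='model.ext', config_save_to='quant_params.txt')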
11 | 12 | Args: 13 | GraphExporter ([type]): [description] 14 | """ 15 | 16 | def __init__(self) -> None: 17 | super().__init__() 18 | 19 | def export(self, file_path: str, graph: BaseGraph, config_path: str = None): 20 | """Sample Export Function -- export all quantization params into txt""" 21 | 22 | if config_path is None: 23 | raise ValueError('Can not export ppq quantization params, cause configuration path is empty.') 24 | with open(file=config_path, mode='w') as file: 25 | 26 | for op in graph.operations.values(): 27 | if not isinstance(op, QuantableOperation): continue 28 | 29 | for cfg, var in op.config_with_variable: 30 | if QuantizationStates.can_export(cfg.state): 31 | 32 | pass 33 | -------------------------------------------------------------------------------- /ppq/parser/native.py: -------------------------------------------------------------------------------- 1 | from pickle import dump, load 2 | 3 | from ppq.core import PPQ_CONFIG 4 | from ppq.IR import BaseGraph, GraphExporter 5 | from ppq.IR.base.graph import GraphBuilder 6 | 7 | 8 | class NativeExporter(GraphExporter): 9 | def __init__(self) -> None: 10 | super().__init__() 11 | def export(self, file_path: str, graph: BaseGraph, 12 | config_path: str = None, dump_value: bool = True): 13 | def dump_elements_to_file(file, elements: list): 14 | for element in elements: dump(element, file) 15 | 16 | with open(file_path, 'wb') as file: 17 | dump_elements_to_file(file, elements=[ 18 | 'PPQ GRAPH DEFINITION', # PPQ Signature. 19 | PPQ_CONFIG.VERSION, # PPQ Signature. 20 | graph, 21 | ]) 22 | 23 | class NativeImporter(GraphBuilder): 24 | def __init__(self) -> None: 25 | super().__init__() 26 | 27 | def build(self, file_path: str, **kwargs) -> BaseGraph: 28 | def load_elements_from_file(file, num_of_elements: int) -> list: 29 | try: return [load(file) for _ in range(num_of_elements)] 30 | except EOFError as e: 31 | raise Exception('File format parsing error. Unexpected EOF found.') 32 | 33 | with open(file_path, 'rb') as file: 34 | signature, version, graph = load_elements_from_file(file, 3) 35 | if signature != 'PPQ GRAPH DEFINITION': 36 | raise Exception('File format parsing error. Graph Signature has been damaged.') 37 | if str(version) > PPQ_CONFIG.VERSION: 38 | print(f'\033[31mWarning: Dump file is created by PPQ({str(version)}), ' 39 | f'however you are using PPQ({PPQ_CONFIG.VERSION}).\033[0m') 40 | 41 | assert isinstance(graph, BaseGraph), ( 42 | 'File format parsing error. Graph Definition has been damaged.') 43 | try: 44 | for op in graph.operations.values(): 45 | input_copy, _ = op.inputs.copy(), op.inputs.clear() 46 | for name in input_copy: op.inputs.append(graph.variables[name]) 47 | output_copy, _ = op.outputs.copy(), op.outputs.clear() 48 | for name in output_copy: op.outputs.append(graph.variables[name]) 49 | 50 | for var in graph.variables.values(): 51 | dest_copy, _ = var.dest_ops.copy(), var.dest_ops.clear() 52 | for name in dest_copy: var.dest_ops.append(graph.operations[name]) 53 | if var.source_op is not None: 54 | var.source_op = graph.operations[var.source_op] 55 | 56 | graph._graph_inputs = {name: graph.variables[name] for name in graph._graph_inputs} 57 | graph._graph_outputs = {name: graph.variables[name] for name in graph._graph_outputs} 58 | except Exception as e: 59 | raise Exception('File format parsing error. 
Graph Definition has been damaged.') 60 | return graph 61 | -------------------------------------------------------------------------------- /ppq/parser/ncnn_exporter.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | from ppq.core import (DataType, NetworkFramework, QuantizationProperty, 5 | QuantizationStates) 6 | from ppq.IR import BaseGraph, GraphExporter, QuantableOperation 7 | 8 | from .caffe_exporter import CaffeExporter 9 | from .onnx_exporter import OnnxExporter 10 | from .util import convert_value 11 | 12 | 13 | class NCNNExporter(GraphExporter): 14 | def export_quantization_config(self, config_path: str, graph: BaseGraph): 15 | fd = open(config_path, 'w+') 16 | topo_order = graph.topological_sort() 17 | for op in topo_order: 18 | if op.is_computing_op and isinstance(op, QuantableOperation): 19 | fd.write(f'{op.name}_param_0 ') 20 | param_cfg = op.config.input_quantization_config[1] 21 | if not param_cfg.can_export(): continue 22 | 23 | assert param_cfg.state in {QuantizationStates.BAKED, QuantizationStates.ACTIVATED}\ 24 | and param_cfg.observer_algorithm in {'minmax', 'Minmax'} and \ 25 | param_cfg.policy.has_property(QuantizationProperty.PER_CHANNEL) 26 | # a workaround for depthwise conv in ncnn 27 | # will cause mis-alignment between ppq and ncnn 28 | if op.type == 'Conv' and op.attributes.get('group', 1) > 1: 29 | group = op.attributes.get('group', 1) 30 | scale = param_cfg.scale.reshape(group, -1).max(dim=1)[0] 31 | else: 32 | scale = param_cfg.scale 33 | scale = convert_value(1 / scale, False, DataType.FP32) 34 | for s in scale: 35 | fd.write('%f '% s) 36 | fd.write('\n') 37 | 38 | for op in topo_order: 39 | if op.is_computing_op and isinstance(op, QuantableOperation): 40 | fd.write(f'{op.name} ') 41 | input_cfg = op.config.input_quantization_config[0] 42 | assert input_cfg.state == QuantizationStates.ACTIVATED and\ 43 | input_cfg.policy.has_property(QuantizationProperty.PER_TENSOR) 44 | scale = convert_value(1 / input_cfg.scale, True, DataType.FP32) 45 | fd.write('%f '% scale) 46 | fd.write('\n') 47 | fd.close() 48 | 49 | def export(self, file_path: str, graph: BaseGraph, config_path: str = None, input_shapes: List[List[int]] = [[1, 3, 224, 224]]): 50 | if config_path is not None: 51 | self.export_quantization_config(config_path, graph) 52 | 53 | _, ext = os.path.splitext(file_path) 54 | if ext == '.onnx': 55 | exporter = OnnxExporter() 56 | exporter.export(file_path=file_path, graph=graph, config_path=None) 57 | elif ext in {'.prototxt', '.caffemodel'}: 58 | exporter = CaffeExporter() 59 | exporter.export(file_path=file_path, graph=graph, config_path=None, input_shapes=input_shapes) 60 | 61 | # no pre-determined export format, we export according to the 62 | # original model format 63 | elif graph._built_from == NetworkFramework.CAFFE: 64 | exporter = CaffeExporter() 65 | exporter.export(file_path=file_path, graph=graph, config_path=None, input_shapes=input_shapes) 66 | 67 | elif graph._built_from == NetworkFramework.ONNX: 68 | exporter = OnnxExporter() 69 | exporter.export(file_path=file_path, graph=graph, config_path=None) 70 | -------------------------------------------------------------------------------- /ppq/parser/ppl.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from ppq.core import (DataType, QuantizationProperty, QuantizationStates, 4 | TargetPlatform, TensorQuantizationConfig) 5 | from ppq.IR import BaseGraph 6 | 
from ppq.IR.quantize import QuantableOperation 7 | 8 | from .onnx_exporter import OnnxExporter 9 | from .util import convert_value 10 | 11 | 12 | def convert_type(platform: TargetPlatform) -> str: 13 | if platform == TargetPlatform.PPL_CUDA_INT8: return 'INT8' 14 | if platform == TargetPlatform.SOI: return None 15 | if platform == TargetPlatform.FP32: return None 16 | raise TypeError(f'Unsupported platform type. ({str(platform)})') 17 | 18 | 19 | class PPLBackendExporter(OnnxExporter): 20 | def export_quantization_config(self, config_path: str, graph: BaseGraph): 21 | var_quant_info_recorder, op_platform_recorder = {}, {} 22 | for operation in graph.operations.values(): 23 | if not isinstance(operation, QuantableOperation): continue 24 | for config, var in operation.config_with_variable: 25 | if not config.can_export(): continue 26 | 27 | # PATCH 2021.11.25 28 | # REMOVE BIAS FROM CONFIGURATION 29 | if config.num_of_bits > 8: continue 30 | 31 | if config.state in { 32 | QuantizationStates.FP32, 33 | }: continue 34 | # Simply override recorder is acceptable here, 35 | # we do not support mix precision quantization for CUDA backend now. 36 | # All configurations for this variable should keep identical towards each other. 37 | 38 | if config.state == QuantizationStates.PASSIVE and var.name in var_quant_info_recorder: continue 39 | var_quant_info_recorder[var.name] = config 40 | 41 | # ready to render config to json. 42 | for var in var_quant_info_recorder: 43 | config = var_quant_info_recorder[var] 44 | assert isinstance(config, TensorQuantizationConfig) 45 | tensorwise = config.policy.has_property(QuantizationProperty.PER_TENSOR) 46 | var_quant_info_recorder[var] = { 47 | 'bit_width' : config.num_of_bits, 48 | 'per_channel': config.policy.has_property(QuantizationProperty.PER_CHANNEL), 49 | 'quant_flag' : True, 50 | 'sym' : config.policy.has_property(QuantizationProperty.SYMMETRICAL), 51 | 'scale' : convert_value(config.scale, tensorwise, DataType.FP32), 52 | 'zero_point' : convert_value(config.offset, tensorwise, DataType.INT32), 53 | 'tensor_min' : convert_value(config.scale * (config.quant_min - config.offset), tensorwise, DataType.FP32), 54 | 'tensor_max' : convert_value(config.scale * (config.quant_max - config.offset), tensorwise, DataType.FP32), 55 | 'q_min' : config.quant_min, 56 | 'q_max' : config.quant_max, 57 | 'hash' : config.__hash__(), 58 | 'dominator' : config.dominated_by.__hash__() 59 | } 60 | 61 | for op in graph.operations.values(): 62 | if convert_type(op.platform) is not None: 63 | op_platform_recorder[op.name] = { 64 | 'data_type': convert_type(op.platform) 65 | } 66 | 67 | exports = { 68 | 'quant_info': var_quant_info_recorder, 69 | 'op_info': op_platform_recorder} 70 | 71 | with open(file=config_path, mode='w') as file: 72 | json.dump(exports, file, indent=4) 73 | -------------------------------------------------------------------------------- /ppq/parser/qnn_exporter.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from typing import List 4 | 5 | from ppq.core import (DataType, QuantizationStates, 6 | QuantizationVisibility, NetworkFramework, ppq_warning) 7 | from ppq.IR import BaseGraph, GraphExporter 8 | from ppq.IR.quantize import QuantableOperation 9 | 10 | from .onnx_exporter import OnnxExporter 11 | from .caffe_exporter import CaffeExporter 12 | from .util import convert_value 13 | 14 | 15 | class QNNDSPExporter(GraphExporter): 16 | def export_quantization_config(self, config_path: str, 
graph: BaseGraph): 17 | activation_info, param_info = {}, {} 18 | topo_order = graph.topological_sort() 19 | for operation in topo_order: 20 | if not isinstance(operation, QuantableOperation): continue 21 | for config, var in operation.config_with_variable: 22 | if not QuantizationStates.can_export(config.state): 23 | raise PermissionError( 24 | 'Can not export quant config cause not all quantization configurations ' 25 | 'have been correctly initialized(or some of them has been deactivated). ' 26 | f'Operation {operation.name} has an invalid quantization state({config.state}) ' 27 | f'at variable {var.name}.') 28 | 29 | if config.visibility == QuantizationVisibility.INTERNAL: continue 30 | if config.state in { 31 | QuantizationStates.FP32, 32 | QuantizationStates.SOI 33 | }: continue 34 | 35 | if var.source_op is not None and var.source_op.type in {"Softmax", "Sigmoid"}: 36 | if config.dominated_by == config: # changeable. 37 | # fix output range greater than 1 38 | config.scale = torch.clamp(config.scale, max=1.0 / (config.quant_max - config.quant_min)) 39 | 40 | if config.state == QuantizationStates.PASSIVE and var.name in activation_info: continue 41 | info = [{ 42 | 'bitwidth': config.num_of_bits, 43 | 'max' : convert_value(config.scale * (config.quant_max - config.offset), True, DataType.FP32), 44 | 'min' : convert_value(config.scale * (config.quant_min - config.offset), True, DataType.FP32), 45 | 'offset' : convert_value(config.offset, True, DataType.INT32), 46 | 'scale' : convert_value(config.scale, True, DataType.FP32) 47 | }] 48 | if var.is_parameter: 49 | param_info[var.name] = info 50 | else: 51 | activation_info[var.name] = info 52 | 53 | exports = { 54 | 'activation_encodings': activation_info, 55 | 'param_encodings': param_info 56 | } 57 | 58 | with open(file=config_path, mode='w') as file: 59 | json.dump(exports, file, indent=4) 60 | 61 | 62 | def export(self, file_path: str, graph: BaseGraph, config_path: str = None, input_shapes: List[List[int]] = [[1, 3, 224, 224]]): 63 | if config_path is not None: 64 | self.export_quantization_config(config_path, graph) 65 | if graph._built_from == NetworkFramework.CAFFE: 66 | exporter = CaffeExporter() 67 | exporter.export(file_path=file_path, graph=graph, config_path=None, input_shapes=input_shapes) 68 | elif graph._built_from == NetworkFramework.ONNX: 69 | exporter = OnnxExporter() 70 | exporter.export(file_path=file_path, graph=graph, config_path=None) 71 | -------------------------------------------------------------------------------- /ppq/parser/util.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import numpy as np 4 | import torch 5 | from ppq.core import DataType, convert_any_to_numpy 6 | 7 | 8 | def convert_value( 9 | value: Union[int, float, np.ndarray, torch.Tensor], 10 | export_as_float: bool, dtype: DataType = DataType.FP32) -> Union[float, list]: 11 | """Converting value from any to python native data dtype, ready for export. 12 | 13 | Args: 14 | value (Union[int, float, np.ndarray, torch.Tensor]): exporting value. 15 | export_as_list (bool): export value as a list. 16 | dtype (DataType, optional): exporting dtype. 
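    Example (illustrative, mirrors the implementation below):

        convert_value(torch.tensor(0.5), export_as_float=True)            # -> 0.5
        convert_value(torch.tensor([0.5, 0.25]), export_as_float=False)   # -> [0.5, 0.25]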
17 | 18 | Returns: 19 | Union[float, list]: Converted value 20 | """ 21 | if dtype not in {DataType.FP32, DataType.INT32}: 22 | raise ValueError(f'Can Only export dtype fp32 and int32, ' 23 | f'while you are requiring to dump a {dtype.name} value') 24 | value = convert_any_to_numpy(value, accept_none=False) 25 | value = value.astype(dtype=DataType.to_numpy(dtype)) 26 | if export_as_float: 27 | value = value.item() 28 | assert type(value) in {int, float}, ( 29 | f'Trying to dump a tensorwise quantization value {value}. ' 30 | f'It is Expected to be a int or float value, while {type(value)} was given') 31 | return value 32 | else: 33 | value = convert_any_to_numpy(value, accept_none=False) 34 | return value.tolist() 35 | -------------------------------------------------------------------------------- /ppq/quantization/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenPPL/ppq/e39eecb9f7e5f017c28f180cb423f8a685c3db48/ppq/quantization/__init__.py -------------------------------------------------------------------------------- /ppq/quantization/algorithm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenPPL/ppq/e39eecb9f7e5f017c28f180cb423f8a685c3db48/ppq/quantization/algorithm/__init__.py -------------------------------------------------------------------------------- /ppq/quantization/algorithm/exprimental.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch 6 | from ppq.core import TensorQuantizationConfig 7 | from ppq.executor import TorchQuantizeDelegator 8 | from ppq.quantization.qfunction import PPQLinearQuantFunction 9 | from ppq.utils.ema import EMARecorder 10 | 11 | 12 | class BanditDelegator(TorchQuantizeDelegator): 13 | """带有多臂赌博机的量化代理,从 ppq 0.6.2 版本后,我们引入 多臂赌博机算法训练 scale 与 offset。在未来我们可能还会引入其他 14 | 类似的算法,例如UCB,马尔可夫蒙特卡洛估计等。 15 | 16 | 引入这些算法的原因是我们注意到 scale 与 offset 的导数非常不靠谱 17 | 为此我们引入简单的强化学习,直接估计P(r | scale=s, context) 18 | 即再给定上下文 context 的情况下,选取当前 scale 为 s,获利的概率 19 | 20 | Quantization with multi-arm bandit. 21 | 22 | Multi-arm bandits are introduced since PPQ 0.6.2 for training 23 | quantization scale and offset. 
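    A rough usage sketch (cfg stands for a TensorQuantizationConfig; the executor
    registration call is an assumption here, while the roll / mark / finalize protocol
    follows the methods defined below):

        delegator = BanditDelegator(arms=[0.9, 0.95, 1.0, 1.05, 1.1], config=cfg)
        executor.register_quantize_delegate(cfg, delegator)  # hypothetical hook-up
        for _ in range(500):
            loss = ...                 # run the network, measure quantization loss
            delegator.mark(1 - loss)   # reward the arm chosen by the last __call__
        delegator.finalize()           # keep the scale of the best-rewarded arm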
24 | """ 25 | def __init__(self, arms: List[float], config: TensorQuantizationConfig) -> None: 26 | if len(arms) < 2: raise ValueError('Can not initialize bandit with less than 2 arms.') 27 | self.e = 0.1 28 | self.arms = arms 29 | self.num_of_arms = len(arms) 30 | self.rewards = [EMARecorder() for _ in range(self.num_of_arms)] 31 | self.rewards[0].push(1) 32 | self.last_selected = 0 33 | self.reference = config.scale.clone() 34 | self.config = config 35 | self.decay = 0.99 36 | 37 | def roll(self) -> int: 38 | if random.random() > self.e: selected = random.randint(0, len(self.arms) - 1) 39 | else: selected = np.argmax([ema.pop() for ema in self.rewards]) 40 | self.last_selected = selected 41 | return selected 42 | 43 | def mark(self, rewards: float): 44 | self.rewards[self.last_selected].push(rewards) 45 | 46 | def finalize(self) -> bool: 47 | self.config.scale = self.reference * self.arms[np.argmax([ema.pop() for ema in self.rewards])] 48 | 49 | def withdraw(self): 50 | self.config.scale = self.reference 51 | 52 | def __call__(self, tensor: torch.Tensor, 53 | config: TensorQuantizationConfig) -> torch.Tensor: 54 | config.scale = self.reference * self.arms[self.roll()] 55 | return PPQLinearQuantFunction(tensor, config) -------------------------------------------------------------------------------- /ppq/quantization/analyse/__init__.py: -------------------------------------------------------------------------------- 1 | from .graphwise import graphwise_error_analyse, statistical_analyse 2 | from .layerwise import (layerwise_error_analyse, parameter_analyse, 3 | variable_analyse) 4 | -------------------------------------------------------------------------------- /ppq/quantization/measure/__init__.py: -------------------------------------------------------------------------------- 1 | from .cosine import torch_cosine_similarity, numpy_cosine_similarity, torch_cosine_similarity_as_loss 2 | from .statistic import torch_KL_divergence 3 | from .norm import torch_mean_square_error, torch_snr_error 4 | -------------------------------------------------------------------------------- /ppq/quantization/measure/cosine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from numpy import dot, ndarray 3 | from numpy.linalg import norm 4 | 5 | 6 | def torch_cosine_similarity(y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str='mean') -> torch.Tensor: 7 | if y_pred.shape != y_real.shape: 8 | raise ValueError(f'Can not compute mse loss for tensors with different shape. 
' 9 | f'({y_pred.shape} and {y_real.shape})') 10 | reduction = str(reduction).lower() 11 | 12 | if y_pred.ndim == 1: 13 | y_pred = y_pred.unsqueeze(0) 14 | y_real = y_real.unsqueeze(0) 15 | 16 | y_pred = y_pred.flatten(start_dim=1).float() 17 | y_real = y_real.flatten(start_dim=1).float() 18 | 19 | cosine_sim = torch.cosine_similarity(y_pred, y_real, dim=-1) 20 | 21 | if reduction == 'mean': 22 | return torch.mean(cosine_sim) 23 | elif reduction == 'sum': 24 | return torch.sum(cosine_sim) 25 | elif reduction == 'none': 26 | return cosine_sim 27 | else: 28 | raise ValueError(f'Unsupported reduction method.') 29 | 30 | 31 | def numpy_cosine_similarity( 32 | x: ndarray, y: ndarray) -> ndarray: 33 | return dot(x, y) / (norm(x) * norm(y)) 34 | 35 | 36 | def torch_cosine_similarity_as_loss( 37 | y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str='mean') -> torch.Tensor: 38 | return 1 - torch_cosine_similarity(y_pred=y_pred, y_real=y_real, reduction=reduction) 39 | -------------------------------------------------------------------------------- /ppq/quantization/measure/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def torch_mean_square_error(y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str='mean') -> torch.Tensor: 4 | """ 5 | Compute mean square error between y_pred(tensor) and y_real(tensor) 6 | 7 | MSE error can be calcualted as following equation: 8 | 9 | MSE(x, y) = (x - y) ^ 2 10 | 11 | if x and y are matrixs, MSE error over matrix should be the mean value of MSE error over all elements. 12 | 13 | MSE(X, Y) = mean((X - Y) ^ 2) 14 | 15 | By this equation, we can easily tell that MSE is an symmtrical measurement: 16 | MSE(X, Y) == MSE(Y, X) 17 | MSE(0, X) == X ^ 2 18 | 19 | Args: 20 | y_pred (torch.Tensor): _description_ 21 | y_real (torch.Tensor): _description_ 22 | reduction (str, optional): _description_. Defaults to 'mean'. 23 | 24 | Raises: 25 | ValueError: _description_ 26 | ValueError: _description_ 27 | 28 | Returns: 29 | torch.Tensor: _description_ 30 | """ 31 | if y_pred.shape != y_real.shape: 32 | raise ValueError(f'Can not compute mse loss for tensors with different shape. ' 33 | f'({y_pred.shape} and {y_real.shape})') 34 | reduction = str(reduction).lower() 35 | 36 | if y_pred.ndim == 1: 37 | y_pred = y_pred.unsqueeze(0) 38 | y_real = y_real.unsqueeze(0) 39 | 40 | diff = torch.pow(y_pred - y_real, 2).flatten(start_dim=1) 41 | mse = torch.mean(diff, dim=-1) 42 | 43 | if reduction == 'mean': 44 | return torch.mean(mse) 45 | elif reduction == 'sum': 46 | return torch.sum(mse) 47 | elif reduction == 'none': 48 | return mse 49 | else: 50 | raise ValueError(f'Unsupported reduction method.') 51 | 52 | def torch_snr_error(y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str='mean') -> torch.Tensor: 53 | """ 54 | Compute SNR between y_pred(tensor) and y_real(tensor) 55 | 56 | SNR can be calcualted as following equation: 57 | 58 | SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2 59 | 60 | if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements. 61 | 62 | SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2) 63 | 64 | Args: 65 | y_pred (torch.Tensor): _description_ 66 | y_real (torch.Tensor): _description_ 67 | reduction (str, optional): _description_. Defaults to 'mean'. 
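    For example, if y_pred deviates from y_real by 10% everywhere (y_pred = 1.1 * y_real),
    the noise power is (0.1 * y_real) ^ 2 and the signal power is y_real ^ 2, so the returned
    value is about 0.01; in other words this measures noise energy over signal energy, the
    reciprocal of a conventional SNR.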
68 | 69 | Raises: 70 | ValueError: _description_ 71 | ValueError: _description_ 72 | 73 | Returns: 74 | torch.Tensor: _description_ 75 | """ 76 | if y_pred.shape != y_real.shape: 77 | raise ValueError(f'Can not compute snr loss for tensors with different shape. ' 78 | f'({y_pred.shape} and {y_real.shape})') 79 | reduction = str(reduction).lower() 80 | 81 | if y_pred.ndim == 1: 82 | y_pred = y_pred.unsqueeze(0) 83 | y_real = y_real.unsqueeze(0) 84 | 85 | y_pred = y_pred.flatten(start_dim=1) 86 | y_real = y_real.flatten(start_dim=1) 87 | 88 | noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1) 89 | signal_power = torch.pow(y_real, 2).sum(dim=-1) 90 | snr = (noise_power) / (signal_power + 1e-7) 91 | 92 | if reduction == 'mean': 93 | return torch.mean(snr) 94 | elif reduction == 'sum': 95 | return torch.sum(snr) 96 | elif reduction == 'none': 97 | return snr 98 | else: 99 | raise ValueError(f'Unsupported reduction method.') 100 | -------------------------------------------------------------------------------- /ppq/quantization/measure/statistic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def torch_KL_divergence(hist: torch.Tensor, ref_hist: torch.Tensor, eps=1e-30) -> float: 4 | if hist.ndim != 1 or ref_hist.ndim != 1: raise ValueError( 5 | 'Only 1 dimension tensor can compute KL divergence with another tensor. '\ 6 | f'While your input hist has dimension {hist.ndim} and ref_hist has dimension {ref_hist.ndim}') 7 | if len(hist) != len(ref_hist): raise ValueError( 8 | 'Can not compute KL divergence, len(hist) != len(ref_hist') 9 | 10 | # here we compute KL divergence at float64 precision, make sure your hist and ref_hist are stored at cpu. 11 | # otherwise it might be very slow. 12 | return torch.dot(hist.double(), torch.log10(hist.double() + eps) - torch.log10(ref_hist.double() + eps)).item() 13 | -------------------------------------------------------------------------------- /ppq/quantization/observer/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import Any 3 | 4 | from ppq.core import (QuantizationStates, TensorQuantizationConfig, 5 | ppq_debug_function) 6 | from ppq.IR import Variable 7 | 8 | 9 | class BaseTensorObserver(metaclass=ABCMeta): 10 | def __init__(self, watch_on: Variable, quant_cfg: TensorQuantizationConfig): 11 | self._watch_on = watch_on 12 | self._quant_cfg = quant_cfg 13 | 14 | @ abstractmethod 15 | def observe(self, value: Any): 16 | raise NotImplementedError('Implement this function first.') 17 | 18 | @ abstractmethod 19 | def render_quantization_config(self): 20 | raise NotImplementedError('Implement this function first.') 21 | 22 | def __str__(self) -> str: 23 | return 'PPQ Tensor Observer (' + self.__class__.__name__ + ') mount on variable ' + \ 24 | self._watch_on.name + ' observing algorithm: ' + self._quant_cfg.observer_algorithm 25 | 26 | @ ppq_debug_function 27 | def report(self) -> str: 28 | if self._quant_cfg.state == QuantizationStates.ACTIVATED: 29 | return f'Observer on Variable {self._watch_on.name}, '\ 30 | f'computed scale: {self._quant_cfg.scale}, computed offset: {self._quant_cfg.offset}\n' 31 | -------------------------------------------------------------------------------- /ppq/quantization/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .baking import ParameterBakingPass 2 | from .base import 
(QuantizationOptimizationPass, 3 | QuantizationOptimizationPipeline) 4 | from .calibration import (IsotoneCalibrationPass, PPLDSPTIReCalibrationPass, 5 | RuntimeCalibrationPass) 6 | from .equalization import (ActivationEqualizationPass, ChannelwiseSplitPass, 7 | LayerwiseEqualizationPass) 8 | from .extension import ExtensionPass 9 | from .legacy import AdaroundPass 10 | from .morph import (GRUSplitPass, HorizontalLayerSplitPass, MetaxGemmSplitPass, 11 | NCNNFormatGemmPass, NXPResizeModeChangePass) 12 | from .parameters import ParameterQuantizePass, PassiveParameterQuantizePass 13 | from .refine import (MishFusionPass, NxpInputRoundingRefinePass, 14 | NxpQuantizeFusionPass, QuantAlignmentPass, 15 | QuantizeFusionPass, QuantizeSimplifyPass, SwishFusionPass) 16 | from .ssd import SSDEqualizationPass 17 | from .training import BiasCorrectionPass, LearnedStepSizePass, RoundTuningPass 18 | -------------------------------------------------------------------------------- /ppq/quantization/optim/baking.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | from ppq.core import empty_ppq_cache 4 | from ppq.executor import BaseGraphExecutor 5 | from ppq.IR import BaseGraph, QuantableOperation 6 | from ppq.quantization.qfunction import PPQuantFunction 7 | 8 | from .base import QuantizationOptimizationPass 9 | 10 | 11 | class ParameterBakingPass(QuantizationOptimizationPass): 12 | """ParameterBakingPass is a useful tool for quantization simulation 13 | acceleration. By default quantizer will bake network parameters once all 14 | quantization procedures are finished. For a typical Convolution layer or 15 | Gemm layer, which has a non-empty bias tensor, ParameterBakingPass will 16 | speed up the layer execution by 30%-50%. 17 | 18 | ParameterBakingPass will rewrite layer parameters with their quantized version, 19 | the quantization procedure will strictly follow layer quantization configuration. 20 | Once the quantization process finished, this pass will change all parameter quantization configuration states 21 | to QuantizationStates.BAKED. 22 | 23 | State QuantizationStates.BAKED indicates corresponding tensor has been pre-quantized and its value 24 | can be used without further quantization, executor will directly use a baked value during execution. 25 | 26 | ATTENTION: value is baked inplace, so to say it will rewrite all network parameters. 27 | ATTENTION: For platforms using int32 accumulator, a float32 bias tensor might lose precision 28 | during the simulation. If you want PPQ simulator to have a consistent result with hardware, it is 29 | highly-recommended to calling ParameterBakingPass before deployment, baking procedure will limit bias 30 | precision to 23 bits (float32 only has 23 fraction bits). 31 | Args: 32 | quantize_function (BaseQuantFunction): a BaseQuantFunction instance to quantize all parameters. 
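    A minimal usage sketch (quantized_graph stands for a BaseGraph produced by a quantizer;
    in a normal PPQ run the quantizer schedules this pass for you):

        baking = ParameterBakingPass()
        baking.optimize(graph=quantized_graph)  # parameters rewritten in place, states become BAKED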
33 | """ 34 | def __init__(self) -> None: 35 | super().__init__(name='PPQ Parameter Baking Pass') 36 | self._quantize_function = PPQuantFunction 37 | 38 | @ empty_ppq_cache 39 | def optimize( 40 | self, 41 | graph: BaseGraph, 42 | **kwargs 43 | ) -> None: 44 | 45 | for _, operation in graph.operations.items(): 46 | if not isinstance(operation, QuantableOperation): continue 47 | operation.baking_parameters(self._quantize_function) 48 | -------------------------------------------------------------------------------- /ppq/quantization/optim/extension.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | from ppq.IR.base.graph import BaseGraph 3 | 4 | from ppq.executor import BaseGraphExecutor 5 | from ppq.IR import BaseGraph 6 | 7 | from .base import QuantizationOptimizationPass 8 | 9 | 10 | class ExtensionPass(QuantizationOptimizationPass): 11 | """ExtensionPass 并没有什么用,它就是告诉你你可以像这样写一个 pass。 你可以直接改写 ExtensionPass 12 | 的逻辑来实现你的功能,并将修改后的代码提交到 github. 13 | 14 | 不过比较我们已经为 ExtensionPass 创建了一个 TemplateSetting 用来给它传递参数 15 | 你可以去 ppq.api.setting.py 里面找到它 16 | 17 | There is nothing in ExtensionPass, it is literally an empty pass, 18 | -- just show you how to create your own pass. 19 | 20 | A TemplateSetting class has been created for passing parameter to this pass. 21 | You can find it in ppq.api.setting.py 22 | 23 | You can overwrite logic inside this pass. 24 | """ 25 | def __init__(self, parameter: str) -> None: 26 | self.parameter = parameter 27 | super().__init__(name='PPQ Extension Pass') 28 | 29 | def optimize( 30 | self, 31 | graph: BaseGraph, 32 | dataloader: Iterable, 33 | executor: BaseGraphExecutor, 34 | **kwargs 35 | ) -> None: 36 | assert isinstance(graph, BaseGraph) 37 | 38 | print('You are invoking Extension Pass now.') 39 | -------------------------------------------------------------------------------- /ppq/quantization/qfunction/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ppq.core import QuantizationProperty, TensorQuantizationConfig 3 | 4 | from .base import BaseQuantFunction 5 | from .floating import PPQFloatingQuantFunction 6 | from .linear import (PPQDyamicLinearQuantFunction, PPQLinearQuant_toInt, 7 | PPQLinearQuantFunction) 8 | 9 | 10 | def PPQuantFunction(tensor: torch.Tensor, config: TensorQuantizationConfig) -> torch.Tensor: 11 | """ 12 | ## PPQ 核心量化函数 13 | 14 | 根据 config 中描述的策略,量化给定的 tensor. 15 | 16 | 请注意 config.state 必须处在激活状态,该函数起作用。如果 config.state 处于 17 | INITIAL, FP32, PASSIVE_INIT 等未激活状态,该函数不进行任何处理,直接返回 tensor. 18 | 19 | ### 线性量化(QuantizationProperty.LINEAR): 20 | 21 | INT8 = Clip(Round((FP32 / scale))) 22 | 23 | ### 浮点量化(QuantizationProperty.FLOATING): 24 | 25 | FP8 = Clip(FP32_TO_FP8((FP32 / scale))) 26 | 27 | ### 动态线性量化(QuantizationProperty.DYNMAIC) 28 | 29 | scale = max(FP32) / 255 30 | 31 | INT8 = Clip(Round((FP32 / scale))) 32 | 33 | """ 34 | if tensor is None: raise ValueError('Tensor is empty.') 35 | if config.policy.has_property(QuantizationProperty.LINEAR): 36 | if not config.policy.has_property(QuantizationProperty.DYNAMIC): 37 | return PPQLinearQuantFunction(tensor, config) 38 | else: return PPQDyamicLinearQuantFunction(tensor, config) 39 | 40 | if config.policy.has_property(QuantizationProperty.FLOATING): 41 | return PPQFloatingQuantFunction(tensor, config) 42 | 43 | raise ValueError('Unexpected Quantization Property Found in PPQuantFunction. 
' 44 | 'Do not know how to quantize your config yet.') 45 | 46 | 47 | def PPQuantFunction_toInt(tensor: torch.Tensor, config: TensorQuantizationConfig) -> torch.Tensor: 48 | """ 49 | ## PPQ Core Quantization Function 50 | 51 | Following the policy described in config, this function performs linear or dynamic linear quantization, 52 | 53 | but the result comes out directly as integers. 54 | """ 55 | 56 | if config.policy.has_property(QuantizationProperty.LINEAR): 57 | if not config.policy.has_property(QuantizationProperty.DYNAMIC): 58 | return PPQLinearQuant_toInt(tensor, config) 59 | 60 | raise ValueError('Unexpected Quantization Property Found in PPQuantFunction_toInt. ' 61 | 'Do not know how to quantize your config yet.') 62 | 63 | 64 | __all__ = ['PPQuantFunction', 'PPQuantFunction_toInt', 'PPQDyamicLinearQuantFunction', 65 | 'PPQLinearQuantFunction', 'PPQFloatingQuantFunction', 'BaseQuantFunction', 66 | 'PPQLinearQuant_toInt'] 67 | -------------------------------------------------------------------------------- /ppq/quantization/qfunction/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from ppq.core import TensorQuantizationConfig 3 | from typing import Any, Callable 4 | 5 | 6 | class BaseQuantFunction(Callable, metaclass=ABCMeta): 7 | def __init__(self) -> None: 8 | pass 9 | 10 | @ abstractmethod 11 | def __call__(self, input_tensor: Any, quantization_config: TensorQuantizationConfig, **kwargs) -> Any: 12 | raise NotImplementedError('Implement this first.') 13 | -------------------------------------------------------------------------------- /ppq/quantization/quantizer/MNNQuantizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | from ppq.core import (PASSIVE_OPERATIONS, OperationQuantizationConfig, 5 | QuantizationPolicy, QuantizationProperty, 6 | QuantizationStates, RoundingPolicy, TargetPlatform) 7 | from ppq.IR import BaseGraph, GraphCommandProcessor, Operation 8 | 9 | from .base import BaseQuantizer 10 | 11 | 12 | class MNNQuantizer(BaseQuantizer): 13 | def __init__( 14 | self, graph: Union[BaseGraph, GraphCommandProcessor] 15 | ) -> None: 16 | super().__init__(graph=graph) 17 | self._num_of_bits = 8 18 | self._quant_min = - 127 19 | self._quant_max = + 127 20 | 21 | def init_quantize_config( 22 | self, operation: Operation) -> OperationQuantizationConfig: 23 | base_quant_config = self.create_default_quant_config( 24 | policy=self.quantize_policy, rounding=self.rounding_policy, 25 | op=operation, num_of_bits=self._num_of_bits, exponent_bits=0, 26 | quant_max=self._quant_max, quant_min=self._quant_min, 27 | observer_algorithm='percentile' 28 | ) 29 | 30 | if operation.type == 'Conv': 31 | assert operation.num_of_input > 0, 'Seems you got a Conv layer with no parameters.' 32 | 33 | if operation.inputs[1].is_parameter: 34 | conv_weight_config = base_quant_config.input_quantization_config[1] 35 | conv_weight_config.policy = QuantizationPolicy( 36 | QuantizationProperty.SYMMETRICAL + 37 | QuantizationProperty.LINEAR + 38 | QuantizationProperty.PER_CHANNEL 39 | ) 40 | conv_weight_config.channel_axis = 0 41 | conv_weight_config.observer_algorithm = 'minmax' 42 | 43 | if operation.num_of_input > 2: 44 | bias_config = base_quant_config.input_quantization_config[-1] 45 | bias_config.state = QuantizationStates.FP32 46 | 47 | if operation.type in PASSIVE_OPERATIONS: 48 | # These ops are not active ops.
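# (PASSIVE_OPERATIONS covers ops such as pooling or reshape-like operators whose outputs
#  reuse the value range of their inputs, so they are not treated as active quantization ops.)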
49 | base_quant_config.is_active_quant_op = False 50 | return base_quant_config 51 | 52 | @ property 53 | def target_platform(self) -> TargetPlatform: 54 | return TargetPlatform.MNN_INT8 55 | 56 | @ property 57 | def default_platform(self) -> TargetPlatform: 58 | return TargetPlatform.FP32 59 | 60 | @ property 61 | def quant_operation_types(self) -> set: 62 | return { 63 | 'Conv', 'Add', 'Gemm' 64 | } 65 | 66 | @ property 67 | def quantize_policy(self) -> QuantizationPolicy: 68 | return QuantizationPolicy( 69 | QuantizationProperty.SYMMETRICAL + 70 | QuantizationProperty.LINEAR + 71 | QuantizationProperty.PER_TENSOR 72 | ) 73 | 74 | @ property 75 | def rounding_policy(self) -> RoundingPolicy: 76 | return RoundingPolicy.ROUND_HALF_FAR_FORM_ZERO 77 | 78 | @ property 79 | def activation_fusion_types(self) -> set: 80 | return {'Relu', 'Clip', 'Swish', 'Clip', 'SoftPlus', 'Sigmoid'} 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /ppq/quantization/quantizer/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseQuantizer 2 | from .DSPQuantizer import PPL_DSP_Quantizer, PPL_DSP_TI_Quantizer 3 | from .MetaxQuantizer import MetaxChannelwiseQuantizer, MetaxTensorwiseQuantizer 4 | from .MyQuantizer import ExtQuantizer 5 | from .NXPQuantizer import NXP_Quantizer 6 | from .RKNNQuantizer import RKNN_PerChannelQuantizer, RKNN_PerTensorQuantizer 7 | from .PPLQuantizer import PPLCUDAQuantizer 8 | # from .TRTQuantizer import TensorRTQuantizer 9 | from .FPGAQuantizer import FPGAQuantizer 10 | from .NCNNQuantizer import NCNNQuantizer 11 | from .OpenvinoQuantizer import OpenvinoQuantizer 12 | from .TengineQuantizer import TengineQuantizer 13 | from .FP8Quantizer import GraphCoreQuantizer, TensorRTQuantizer_FP8 14 | from .TensorRTQuantizer import TensorRTQuantizer, TensorRTQuantizer_InputOnly 15 | from .AscendQuantizer import AscendQuantizer 16 | from .ORTQuantizer import OnnxruntimeQuantizer 17 | from .MNNQuantizer import MNNQuantizer -------------------------------------------------------------------------------- /ppq/samples/Imagenet/Utilities/Imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | from .imagenet_util import (evaluate_mmlab_module_with_imagenet, 2 | evaluate_onnx_module_with_imagenet, 3 | evaluate_ppq_module_with_imagenet, 4 | evaluate_torch_module_with_imagenet, 5 | load_imagenet_from_directory) 6 | -------------------------------------------------------------------------------- /ppq/samples/Imagenet/Utilities/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | def shape_to_str(shape: List[int]) -> str: 5 | if len(shape) == 1: 6 | return str(shape[0]) 7 | string_builder = str(shape[0]) 8 | for s in shape[1: ]: 9 | string_builder += '_' + str(s) 10 | return string_builder -------------------------------------------------------------------------------- /ppq/samples/Onnxruntime/Example_Benchmark.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # 这个脚本向你展示了如何使用 Onnxruntime 对 PPQ 导出的模型进行推理 3 | # Onnxruntime 提供一系列 providers 实现不同硬件上的神经网络推理 4 | 5 | # CPUExecutionProvider, CUDAExecutionProvider 是 Onnxruntime 官方提供的 6 | # TensortExecutionProvider 是 Nvidia 提供的 7 | # 不同 Provider 对模型格式有不一样的要求,PPQ 导出的是 CPUExecutionProvider 格式的模型 8 | 9 | # Onnxruntime 
没写 INT8 算子的 CUDA 实现,因此当你的模型使用 Onnxruntime 进行部署时,如果使用 10 | # CUDAExecutionProvider, 你无需考虑量化加速 11 | # --------------------------------------------------------------- 12 | 13 | import torchvision 14 | import torch 15 | import ppq 16 | import ppq.api as API 17 | 18 | calibration_dataloader = [torch.rand(size=[1, 3, 224, 224]).cuda()] 19 | model = torchvision.models.shufflenet_v2_x1_0().cuda() 20 | 21 | with API.ENABLE_CUDA_KERNEL(): 22 | quantized = API.quantize_torch_model( 23 | model=model, calib_dataloader=calibration_dataloader, 24 | calib_steps=8, input_shape=[1, 3, 224, 224], platform=ppq.TargetPlatform.ONNXRUNTIME) 25 | 26 | API.export_ppq_graph( 27 | quantized, platform=ppq.TargetPlatform.ONNXRUNTIME, 28 | graph_save_to='Quantized.onnx') 29 | 30 | API.export_ppq_graph( 31 | quantized, platform=ppq.TargetPlatform.ONNX, 32 | graph_save_to='FP32.onnx') 33 | 34 | from ppq.utils.OnnxruntimeUtil import Benchmark, Profile 35 | 36 | Benchmark('FP32.onnx', providers=['CPUExecutionProvider']) 37 | Benchmark('Quantized.onnx', providers=['CPUExecutionProvider']) 38 | 39 | Profile('FP32.onnx', providers=['CPUExecutionProvider']) 40 | Profile('Quantized.onnx', providers=['CPUExecutionProvider']) -------------------------------------------------------------------------------- /ppq/samples/Onnxruntime/Example_Fp32.py: -------------------------------------------------------------------------------- 1 | import onnxruntime 2 | import numpy as np 3 | 4 | # ------------------------------------------------------------------- 5 | # Onnxruntime 需要你提供一个 feed dict 和 output 的名字才能跑推理 6 | # feed dict 就是 input name: data 的形式表示的输入数据 7 | # output name 和 input name 你如果不知道的话,用可视化工具打开 onnx 文件就可以看到了。 8 | # ------------------------------------------------------------------- 9 | 10 | MODEL = 'model.onnx' 11 | FEED_DICT = {'input name': np.zeros(shape=[1, 3, 224, 224])} 12 | OUTPUT_NAMES = ['output name'] 13 | 14 | session = onnxruntime.InferenceSession(MODEL, providers=['CUDAExecutionProvider']) 15 | result = session.run(OUTPUT_NAMES, FEED_DICT) 16 | -------------------------------------------------------------------------------- /ppq/samples/Onnxruntime/Example_PTQ.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # 这个脚本向你展示了如何使用 onnxruntime 对 PPQ 导出的模型进行推理 3 | # 你需要注意,Onnxruntime 可以运行各种各样的量化方案,但模型量化对 Onnxruntime 而言几乎无法起到加速作用 4 | # 你可以使用 Onnxruntime 来验证量化方案以及 ppq 量化的正确性,但这不是一个合理的部署平台 5 | # 修改 QUANT_PLATFROM 来使用不同的量化方案。 6 | 7 | # This Script export ppq internal graph to onnxruntime, 8 | # you should notice that onnx is designed as an Open Neural Network Exchange format. 9 | # It has the capbility to describe most of ppq's quantization policy including combinations of: 10 | # Symmtrical, Asymmtrical, POT, Per-channel, Per-Layer 11 | # However onnxruntime can not accelerate quantized model in most cases, 12 | # you are supposed to use onnxruntime for verifying your network quantization result only. 13 | # --------------------------------------------------------------- 14 | 15 | # For this onnx inference test, all test data is randomly picked. 
16 | # If you want to use real data, just rewrite the defination of SAMPLES 17 | import onnxruntime 18 | import torch 19 | from ppq import * 20 | from ppq.api import * 21 | from tqdm import tqdm 22 | 23 | QUANT_PLATFROM = TargetPlatform.ONNXRUNTIME 24 | MODEL = 'model.onnx' 25 | INPUT_SHAPE = [1, 3, 224, 224] 26 | SAMPLES = [torch.rand(size=INPUT_SHAPE) for _ in range(256)] # rewirte this to use real data. 27 | DEVICE = 'cuda' 28 | FINETUNE = False 29 | QS = QuantizationSettingFactory.default_setting() 30 | EXECUTING_DEVICE = 'cuda' 31 | REQUIRE_ANALYZE = True 32 | 33 | # ------------------------------------------------------------------- 34 | # 下面向你展示了常用参数调节选项: 35 | # ------------------------------------------------------------------- 36 | QS.lsq_optimization = FINETUNE # 启动网络再训练过程,降低量化误差 37 | QS.lsq_optimization_setting.steps = 500 # 再训练步数,影响训练时间,500 步大概几分钟 38 | QS.lsq_optimization_setting.collecting_device = 'cuda' # 缓存数据放在那,cuda 就是放在 gpu,如果显存超了你就换成 'cpu' 39 | 40 | print('正准备量化你的网络,检查下列设置:') 41 | print(f'TARGET PLATFORM : {QUANT_PLATFROM.name}') 42 | print(f'NETWORK INPUTSHAPE : {INPUT_SHAPE}') 43 | 44 | # ENABLE CUDA KERNEL 会加速量化效率 3x ~ 10x,但是你如果没有装相应编译环境的话是编译不了的 45 | # 你可以尝试安装编译环境,或者在不启动 CUDA KERNEL 的情况下完成量化:移除 with ENABLE_CUDA_KERNEL(): 即可 46 | with ENABLE_CUDA_KERNEL(): 47 | qir = quantize_onnx_model( 48 | onnx_import_file=MODEL, calib_dataloader=SAMPLES, calib_steps=128, setting=QS, 49 | input_shape=INPUT_SHAPE, collate_fn=lambda x: x.to(EXECUTING_DEVICE), 50 | platform=QUANT_PLATFROM, do_quantize=True) 51 | 52 | # ------------------------------------------------------------------- 53 | # PPQ 计算量化误差时,使用信噪比的倒数作为指标,即噪声能量 / 信号能量 54 | # 量化误差 0.1 表示在整体信号中,量化噪声的能量约为 10% 55 | # 你应当注意,在 graphwise_error_analyse 分析中,我们衡量的是累计误差 56 | # 网络的最后一层往往都具有较大的累计误差,这些误差是其前面的所有层所共同造成的 57 | # 你需要使用 layerwise_error_analyse 逐层分析误差的来源 58 | # ------------------------------------------------------------------- 59 | print('正计算网络量化误差(SNR),最后一层的误差应小于 0.1 以保证量化精度:') 60 | reports = graphwise_error_analyse( 61 | graph=qir, running_device=EXECUTING_DEVICE, steps=32, 62 | dataloader=SAMPLES, collate_fn=lambda x: x.to(EXECUTING_DEVICE)) 63 | for op, snr in reports.items(): 64 | if snr > 0.1: ppq_warning(f'层 {op} 的累计量化误差显著,请考虑进行优化') 65 | 66 | if REQUIRE_ANALYZE: 67 | print('正计算逐层量化误差(SNR),每一层的独立量化误差应小于 0.1 以保证量化精度:') 68 | layerwise_error_analyse(graph=qir, running_device=EXECUTING_DEVICE, 69 | interested_outputs=None, 70 | dataloader=SAMPLES, collate_fn=lambda x: x.to(EXECUTING_DEVICE)) 71 | 72 | print('网络量化结束,正在生成目标文件:') 73 | export_ppq_graph( 74 | graph=qir, platform=QUANT_PLATFROM, 75 | graph_save_to = 'quantized.onnx') -------------------------------------------------------------------------------- /ppq/samples/Openvino/Example_Benchmark.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # 这个脚本向你展示了如何使用 openvino 对 PPQ 导出的模型进行推理 3 | # 你需要注意,openvino 也可以运行各种各样的量化方案,你甚至可以用 tensorRT 的 policy 4 | # 但总的来说,openvino 需要非对称量化的 activation 和对称量化的 weights 5 | # 现在的写法针对单输入网络哦,多输入的你得自己改改 6 | # --------------------------------------------------------------- 7 | 8 | # For this onnx inference test, all test data is randomly picked. 
9 | # If you want to use real data, just rewrite the definition of SAMPLES 10 | import numpy as np 11 | import openvino 12 | import torch 13 | from tqdm import tqdm 14 | import time 15 | 16 | from ppq import * 17 | from ppq.api import * 18 | 19 | QUANT_PLATFROM = TargetPlatform.OPENVINO_INT8 20 | BATCHSIZE = 1 21 | DEVICE = 'cuda' 22 | INPUTSHAPE = [BATCHSIZE, 3, 640, 640] 23 | SAMPLES = [torch.rand(size=INPUTSHAPE) for _ in range(256)] 24 | BENCHMARK_SAMPLES = 512 25 | MODEL_PATH = 'Models/yolox_s.onnx' 26 | VALIDATION = False 27 | 28 | with ENABLE_CUDA_KERNEL(): 29 | quantized = quantize_onnx_model( 30 | onnx_import_file=MODEL_PATH, calib_dataloader=SAMPLES, collate_fn=lambda x: x.to(DEVICE), 31 | calib_steps=32, input_shape=INPUTSHAPE, 32 | setting=QuantizationSettingFactory.default_setting(), 33 | platform=QUANT_PLATFROM) 34 | 35 | graphwise_error_analyse(graph=quantized, running_device='cuda', 36 | dataloader=SAMPLES, collate_fn=lambda x: x.cuda(), steps=32) 37 | 38 | export_ppq_graph( 39 | graph=quantized, platform=TargetPlatform.ONNX, 40 | graph_save_to='FP32.onnx') 41 | 42 | export_ppq_graph( 43 | graph=quantized, platform=TargetPlatform.OPENVINO_INT8, 44 | graph_save_to='INT8.onnx') 45 | 46 | from ppq.utils.OpenvinoUtil import Benchmark 47 | Benchmark(ir_or_onnx_file='FP32.onnx', samples=500, jobs=4) 48 | Benchmark(ir_or_onnx_file='INT8.onnx', samples=500, jobs=4) -------------------------------------------------------------------------------- /ppq/samples/Openvino/Example_Fp32.py: -------------------------------------------------------------------------------- 1 | import openvino.runtime 2 | import torch 3 | from ppq import * 4 | from tqdm import tqdm 5 | 6 | MODEL = 'models\\resnet18.onnx' 7 | INPUT_SHAPE = [1, 3, 224, 224] 8 | SAMPLES = [torch.rand(size=INPUT_SHAPE) for _ in range(256)] # rewrite this to use real data. 9 | 10 | # ------------------------------------------------------------------- 11 | # Run inference with openvino 12 | # ------------------------------------------------------------------- 13 | openvino_executor = openvino.runtime.Core() 14 | openvino_results = [] 15 | model = openvino_executor.compile_model( 16 | model = openvino_executor.read_model(MODEL), device_name="CPU") 17 | for sample in tqdm(SAMPLES, desc='OPENVINO GENERATING OUTPUTS'): 18 | openvino_results.append(model([convert_any_to_numpy(sample)])) -------------------------------------------------------------------------------- /ppq/samples/Openvino/Example_QAT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torchvision 4 | from absl import logging 5 | 6 | # Install the package imported below first 7 | from pytorch_quantization import nn as quant_nn 8 | 9 | logging.set_verbosity(logging.FATAL) # Disable logging as they are too noisy in notebook 10 | 11 | from pytorch_quantization import quant_modules 12 | 13 | # Call quant_modules.initialize() 14 | # then just train as you normally would ... 15 | quant_modules.initialize() 16 | 17 | model = torchvision.models.resnet50() 18 | model.cuda() 19 | 20 | # Quantization Aware Training is based on Straight Through Estimator (STE) derivative approximation. 21 | # It is sometimes known as “quantization aware training”. 22 | # We don’t use the name because it doesn’t reflect the underlying assumption. 23 | # If anything, it makes training “unaware” of quantization because of the STE approximation. 
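# (illustrative sketch, not part of pytorch_quantization) the straight-through estimator in one
# line: round on the forward pass, but let gradients flow through as if no rounding had happened.
def ste_round(x: torch.Tensor) -> torch.Tensor:
    # forward value: round(x); backward gradient: identity (the STE approximation)
    return x + (torch.round(x) - x).detach()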
24 | 25 | # After calibration is done, Quantization Aware Training is simply selecting a training schedule and continuing to train the calibrated model. 26 | # Usually, it doesn’t need to fine-tune for very long. We usually use around 10% of the original training schedule, 27 | # starting at 1% of the initial training learning rate, 28 | # and a cosine annealing learning rate schedule that follows the decreasing half of a cosine period, 29 | # down to 1% of the initial fine tuning learning rate (0.01% of the initial training learning rate). 30 | 31 | # Quantization Aware Training (essentially a discrete numerical optimization problem) is not a solved problem mathematically. 32 | # Based on our experience, here are some recommendations: 33 | 34 | # For the STE approximation to work well, it is better to use a small learning rate. 35 | # A large learning rate is more likely to enlarge the variance introduced by the STE approximation and destroy the trained network. 36 | 37 | # Do not change the quantization representation (scale) during training, at least not too frequently. 38 | # Changing the scale every step is effectively like changing the data format (e8m7, e5m10, e3m4, etc.) every step, 39 | # which will easily affect convergence. 40 | 41 | # https://github.com/NVIDIA/TensorRT/blob/main/tools/pytorch-quantization/examples/finetune_quant_resnet50.ipynb 42 | 43 | def export_onnx(model, onnx_filename, batch_onnx): 44 | model.eval() 45 | quant_nn.TensorQuantizer.use_fb_fake_quant = True # We have to shift to pytorch's fake quant ops before exporting the model to ONNX 46 | opset_version = 13 47 | 48 | # Export ONNX for multiple batch sizes 49 | print("Creating ONNX file: " + onnx_filename) 50 | dummy_input = torch.randn(batch_onnx, 3, 224, 224, device='cuda') #TODO: switch input dims by model 51 | torch.onnx.export(model, dummy_input, onnx_filename, verbose=False, opset_version=opset_version, enable_onnx_checker=False, do_constant_folding=True) 52 | return True 53 | -------------------------------------------------------------------------------- /ppq/samples/RKNN/Example_PTQ.py: -------------------------------------------------------------------------------- 1 | # TO BE CONTINUED -------------------------------------------------------------------------------- /ppq/samples/TensorRT/Example_Benchmark.py: -------------------------------------------------------------------------------- 1 | from ppq.utils.TensorRTUtil import Benchmark, Profiling 2 | 3 | Benchmark(engine_file='Output/INT8.engine') 4 | Benchmark(engine_file='Output/FP16.engine') 5 | Benchmark(engine_file='Output/FP32.engine') 6 | -------------------------------------------------------------------------------- /ppq/samples/TensorRT/Example_Fp32.py: -------------------------------------------------------------------------------- 1 | MODEL = 'model.onnx' 2 | INPUT_SHAPE = [1, 3, 224, 224] 3 | SAMPLES = [torch.rand(size=INPUT_SHAPE) for _ in range(256)] # rewrite this to use real data. 
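# (note) this sample assumes a preamble like the following at the top of the file; the helper
# imports mirror what the other TensorRT samples in this repository use and are assumed to be
# importable from the ppq top level.
#     import torch
#     from tqdm import tqdm
#     from ppq import convert_any_to_numpy, convert_any_to_torch_tensor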
4 | 5 | # ------------------------------------------------------------------- 6 | # 打开 trt_infer 看到具体细节,这个文件是 nvidia 的官方实例 7 | # ------------------------------------------------------------------- 8 | from trt_infer import EngineBuilder 9 | builder = EngineBuilder() 10 | builder.create_network('model_fp32.onnx') 11 | builder.create_engine(engine_path='model_fp32.engine', precision="fp16") 12 | 13 | # ------------------------------------------------------------------- 14 | # 启动 tensorRT 进行推理,你先装一下 trt 15 | # ------------------------------------------------------------------- 16 | import tensorrt as trt 17 | import trt_infer 18 | 19 | samples = [convert_any_to_numpy(sample) for sample in SAMPLES] 20 | logger = trt.Logger(trt.Logger.INFO) 21 | with open('model_fp32.engine', 'rb') as f, trt.Runtime(logger) as runtime: 22 | engine = runtime.deserialize_cuda_engine(f.read()) 23 | 24 | results = [] 25 | with engine.create_execution_context() as context: 26 | inputs, outputs, bindings, stream = trt_infer.allocate_buffers(context.engine) 27 | for sample in tqdm(samples, desc='TensorRT is running...'): 28 | inputs[0].host = convert_any_to_numpy(sample) 29 | [output] = trt_infer.do_inference( 30 | context, bindings=bindings, inputs=inputs, 31 | outputs=outputs, stream=stream, batch_size=1) 32 | results.append(convert_any_to_torch_tensor(output)) 33 | -------------------------------------------------------------------------------- /ppq/samples/TensorRT/Example_PTQ.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | import torchvision 6 | 7 | from ppq import TargetPlatform, graphwise_error_analyse, TorchExecutor 8 | from ppq.api import ENABLE_CUDA_KERNEL, export_ppq_graph, load_torch_model 9 | from ppq.core import convert_any_to_numpy 10 | from ppq.quantization.optim import * 11 | import ppq.lib as PFL 12 | 13 | calibration_dataloader = [] 14 | for file in os.listdir('imagenet'): 15 | path = os.path.join('imagenet', file) 16 | arr = np.fromfile(path, dtype=np.dtype('float32')).reshape([1, 3, 224, 224]) 17 | calibration_dataloader.append(torch.tensor(arr)) 18 | 19 | with ENABLE_CUDA_KERNEL(): 20 | model = torchvision.models.mnasnet1_0(pretrained=True).cuda() 21 | graph = load_torch_model(model=model, sample=torch.zeros(size=[1, 3, 224, 224]).cuda()) 22 | # ------------------------------------------------------------ 23 | # 我们首先进行标准的量化流程,为所有算子初始化量化信息,并进行 Calibration 24 | # ------------------------------------------------------------ 25 | quantizer = PFL.Quantizer(platform=TargetPlatform.TRT_INT8, graph=graph) # 取得 TRT_INT8 所对应的量化器 26 | dispatching = PFL.Dispatcher(graph=graph).dispatch( # 生成调度表 27 | quant_types=quantizer.quant_operation_types) 28 | 29 | # 为算子初始化量化信息 30 | for op in graph.operations.values(): 31 | quantizer.quantize_operation( 32 | op_name = op.name, platform = dispatching[op.name]) 33 | 34 | # 初始化执行器 35 | collate_fn = lambda x: x.to('cuda') 36 | executor = TorchExecutor(graph=graph, device='cuda') 37 | executor.tracing_operation_meta(inputs=torch.zeros(size=[1, 3, 224, 224]).cuda()) 38 | executor.load_graph(graph=graph) 39 | 40 | # ------------------------------------------------------------ 41 | # 创建优化管线,由于后续还要继续训练我们的模型,我们不能在此处调用 42 | # ParameterBakingPass(),一旦模型权重完成烘焙,则它们不能被进一步调整 43 | # ------------------------------------------------------------ 44 | pipeline = PFL.Pipeline([ 45 | QuantizeSimplifyPass(), 46 | QuantizeFusionPass( 47 | activation_type=quantizer.activation_fusion_types), 48 | 
ParameterQuantizePass(), 49 | RuntimeCalibrationPass(), 50 | PassiveParameterQuantizePass(), 51 | QuantAlignmentPass(force_overlap=True), 52 | ]) 53 | 54 | with ENABLE_CUDA_KERNEL(): 55 | # 调用管线完成量化 56 | pipeline.optimize( 57 | graph=graph, dataloader=calibration_dataloader, verbose=True, 58 | calib_steps=32, collate_fn=collate_fn, executor=executor) 59 | 60 | graphwise_error_analyse( 61 | graph=graph, running_device='cuda', dataloader=calibration_dataloader, 62 | collate_fn=lambda x: x.cuda()) 63 | 64 | export_ppq_graph( 65 | graph=graph, platform=TargetPlatform.TRT_INT8, 66 | graph_save_to='Output/quantized.onnx', 67 | config_save_to='Output/quantized.json') 68 | 69 | results, executor = [], TorchExecutor(graph=graph) 70 | for idx, data in enumerate(calibration_dataloader): 71 | arr = convert_any_to_numpy(executor(data.cuda())[0]) 72 | arr.tofile(f'Output/Result/{idx}.bin') 73 | 74 | from ppq.utils.TensorRTUtil import build_engine 75 | build_engine(onnx_file='Output/quantized.onnx', int8_scale_file='Output/quantized.json', engine_file='Output/INT8.engine', int8=True, fp16=True) 76 | build_engine(onnx_file='Output/quantized.onnx', engine_file='Output/FP16.engine', int8=False, fp16=True) 77 | build_engine(onnx_file='Output/quantized.onnx', engine_file='Output/FP32.engine', int8=False, fp16=False) 78 | 79 | -------------------------------------------------------------------------------- /ppq/samples/TensorRT/Example_Profiling.py: -------------------------------------------------------------------------------- 1 | from ppq.utils.TensorRTUtil import Benchmark, Profiling 2 | 3 | print('Profiling with Int8 Model') 4 | Profiling(engine_file='Output/INT8.engine') 5 | print('-------------------------------------------') 6 | 7 | print('Profiling with Fp16 Model') 8 | Profiling(engine_file='Output/FP16.engine') 9 | print('-------------------------------------------') 10 | 11 | print('Profiling with Fp32 Model') 12 | Profiling(engine_file='Output/FP32.engine') 13 | print('-------------------------------------------') -------------------------------------------------------------------------------- /ppq/samples/TensorRT/Example_QAT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torchvision 4 | from absl import logging 5 | 6 | # 装一下下面这个库 7 | from pytorch_quantization import nn as quant_nn 8 | 9 | logging.set_verbosity(logging.FATAL) # Disable logging as they are too noisy in notebook 10 | 11 | from pytorch_quantization import quant_modules 12 | 13 | # 调用这个 quant_modules.initialize() 14 | # 然后你正常训练就行了 ... 15 | quant_modules.initialize() 16 | 17 | model = torchvision.models.resnet50() 18 | model.cuda() 19 | 20 | # Quantization Aware Training is based on Straight Through Estimator (STE) derivative approximation. 21 | # It is some time known as “quantization aware training”. 22 | # We don’t use the name because it doesn’t reflect the underneath assumption. 23 | # If anything, it makes training being “unaware” of quantization because of the STE approximation. 24 | 25 | # After calibration is done, Quantization Aware Training is simply select a training schedule and continue training the calibrated model. 26 | # Usually, it doesn’t need to fine tune very long. 
We usually use around 10% of the original training schedule, 27 | # starting at 1% of the initial training learning rate, 28 | # and a cosine annealing learning rate schedule that follows the decreasing half of a cosine period, 29 | # down to 1% of the initial fine tuning learning rate (0.01% of the initial training learning rate). 30 | 31 | # Quantization Aware Training (Essentially a discrete numerical optimization problem) is not a solved problem mathematically. 32 | # Based on our experience, here are some recommendations: 33 | 34 | # For STE approximation to work well, it is better to use small learning rate. 35 | # Large learning rate is more likely to enlarge the variance introduced by STE approximation and destroy the trained network. 36 | 37 | # Do not change quantization representation (scale) during training, at least not too frequently. 38 | # Changing scale every step, it is effectively like changing data format (e8m7, e5m10, e3m4, et.al) every step, 39 | # which will easily affect convergence. 40 | 41 | # https://github.com/NVIDIA/TensorRT/blob/main/tools/pytorch-quantization/examples/finetune_quant_resnet50.ipynb 42 | 43 | def export_onnx(model, onnx_filename, batch_onnx): 44 | model.eval() 45 | quant_nn.TensorQuantizer.use_fb_fake_quant = True # We have to shift to pytorch's fake quant ops before exporting the model to ONNX 46 | opset_version = 13 47 | 48 | # Export ONNX for multiple batch sizes 49 | print("Creating ONNX file: " + onnx_filename) 50 | dummy_input = torch.randn(batch_onnx, 3, 224, 224, device='cuda') #TODO: switch input dims by model 51 | torch.onnx.export(model, dummy_input, onnx_filename, verbose=False, opset_version=opset_version, enable_onnx_checker=False, do_constant_folding=True) 52 | return True 53 | -------------------------------------------------------------------------------- /ppq/samples/TensorRT/Example_Torch2trt.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # 这个脚本向你展示了如何使用 torch2trt 加速 pytorch 推理 3 | # 截止目前为止 torch2trt 的适配能力有限,不要尝试运行特别奇怪的模型 4 | # 你可以把模型分块来绕开那些不支持的算子。 5 | 6 | # 使用之前你必须先装好 TensorRT, torch2trt等工具包 7 | # https://github.com/NVIDIA-AI-IOT/torch2trt 8 | 9 | # --------------------------------------------------------------- 10 | 11 | import torch 12 | import torch.utils.data 13 | import torchvision 14 | from torch2trt import torch2trt 15 | from tqdm import tqdm 16 | 17 | 18 | SAMPLES = [torch.zeros(1, 3, 224, 224) for _ in range(1024)] 19 | MODEL = torchvision.models.resnet18() 20 | FP16_MODE = True 21 | 22 | # Model has to be the eval mode, and deploy to cuda. 
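# (note, a rough description of torch2trt's default behaviour) torch2trt builds the TensorRT
# network by running the module once with the example inputs passed below, so the resulting
# engine is specialised for that input shape unless dynamic shapes are configured explicitly.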
23 | MODEL.eval() 24 | MODEL.cuda() 25 | 26 | def trace_handler(prof): 27 | print(prof.key_averages().table( 28 | sort_by="self_cuda_time_total", row_limit=-1)) 29 | 30 | # Benckmark with pytorch 31 | for sample in tqdm(SAMPLES, desc='Torch Executing'): 32 | MODEL.forward(sample.cuda()) 33 | 34 | # Convert torch.nn.Module to tensorrt 35 | # 在转换过后,你模型中的执行函数将会被 trt 替换,同时进行图融合 36 | model_trt = torch2trt(MODEL, [sample.cuda()], fp16_mode=FP16_MODE) 37 | for sample in tqdm(SAMPLES, desc='TRT Executing'): 38 | model_trt.forward(sample.cuda()) 39 | 40 | # Test performance metrics using torch.profiler 41 | with torch.profiler.profile( 42 | activities=[ 43 | torch.profiler.ProfilerActivity.CPU, 44 | torch.profiler.ProfilerActivity.CUDA], 45 | schedule=torch.profiler.schedule( 46 | wait=2, 47 | warmup=1, 48 | active=7), 49 | on_trace_ready=trace_handler 50 | # on_trace_ready=torch.profiler.tensorboard_trace_handler('log') 51 | # used when outputting for tensorboard 52 | ) as p: 53 | for iter in range(10): 54 | model_trt.forward(sample.cuda()) 55 | # send a signal to the profiler that the next iteration has started 56 | p.step() 57 | 58 | with torch.profiler.profile( 59 | activities=[ 60 | torch.profiler.ProfilerActivity.CPU, 61 | torch.profiler.ProfilerActivity.CUDA], 62 | schedule=torch.profiler.schedule( 63 | wait=2, 64 | warmup=1, 65 | active=7), 66 | on_trace_ready=trace_handler 67 | # on_trace_ready=torch.profiler.tensorboard_trace_handler('log') 68 | # used when outputting for tensorboard 69 | ) as p: 70 | for iter in range(10): 71 | MODEL.forward(sample.cuda()) 72 | # send a signal to the profiler that the next iteration has started 73 | p.step() 74 | -------------------------------------------------------------------------------- /ppq/samples/TensorRT/create_engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import tensorrt as trt 5 | 6 | TRT_LOGGER = trt.Logger() 7 | 8 | EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) 9 | 10 | def GiB(val): 11 | return val * 1 << 30 12 | 13 | def json_load(filename): 14 | with open(filename) as json_file: 15 | data = json.load(json_file) 16 | return data 17 | 18 | def setDynamicRange(network, json_file): 19 | """Sets ranges for network layers.""" 20 | quant_param_json = json_load(json_file) 21 | act_quant = quant_param_json["act_quant_info"] 22 | 23 | for i in range(network.num_inputs): 24 | input_tensor = network.get_input(i) 25 | if act_quant.__contains__(input_tensor.name): 26 | value = act_quant[input_tensor.name] 27 | tensor_max = abs(value) 28 | tensor_min = -abs(value) 29 | input_tensor.dynamic_range = (tensor_min, tensor_max) 30 | 31 | 32 | for i in range(network.num_layers): 33 | layer = network.get_layer(i) 34 | 35 | for output_index in range(layer.num_outputs): 36 | tensor = layer.get_output(output_index) 37 | 38 | if act_quant.__contains__(tensor.name): 39 | value = act_quant[tensor.name] 40 | tensor_max = abs(value) 41 | tensor_min = -abs(value) 42 | tensor.dynamic_range = (tensor_min, tensor_max) 43 | else: 44 | print("\033[1;32m%s\033[0m" % tensor.name) 45 | 46 | 47 | def build_engine(onnx_file, json_file, engine_file): 48 | builder = trt.Builder(TRT_LOGGER) 49 | network = builder.create_network(EXPLICIT_BATCH) 50 | 51 | config = builder.create_builder_config() 52 | 53 | # If it is a dynamic onnx model , you need to add the following. 
54 | # profile = builder.create_optimization_profile() 55 | # profile.set_shape("input_name", (batch, channels, min_h, min_w), (batch, channels, opt_h, opt_w), (batch, channels, max_h, max_w)) 56 | # config.add_optimization_profile(profile) 57 | 58 | 59 | parser = trt.OnnxParser(network, TRT_LOGGER) 60 | config.max_workspace_size = GiB(1) 61 | 62 | if not os.path.exists(onnx_file): 63 | quit('ONNX file {} not found'.format(onnx_file)) 64 | 65 | with open(onnx_file, 'rb') as model: 66 | if not parser.parse(model.read()): 67 | print('ERROR: Failed to parse the ONNX file.') 68 | for error in range(parser.num_errors): 69 | print(parser.get_error(error)) 70 | return None 71 | 72 | config.set_flag(trt.BuilderFlag.INT8) 73 | 74 | setDynamicRange(network, json_file) 75 | 76 | engine = builder.build_engine(network, config) 77 | 78 | with open(engine_file, "wb") as f: 79 | f.write(engine.serialize()) 80 | 81 | 82 | if __name__ == '__main__': 83 | # Add plugins if needed 84 | # import ctypes 85 | # ctypes.CDLL("libmmdeploy_tensorrt_ops.so") 86 | parser = argparse.ArgumentParser(description='Writing qparams to onnx to convert tensorrt engine.') 87 | parser.add_argument('--onnx', type=str, default=None) 88 | parser.add_argument('--qparam_json', type=str, default=None) 89 | parser.add_argument('--engine', type=str, default=None) 90 | arg = parser.parse_args() 91 | 92 | build_engine(arg.onnx, arg.qparam_json, arg.engine) 93 | print("\033[1;32mgenerate %s\033[0m" % arg.engine) 94 | 95 | 96 | -------------------------------------------------------------------------------- /ppq/samples/TensorRT/lenet_demo/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.1) 2 | project(lenet) 3 | 4 | option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) 5 | set(CMAKE_EXPORT_COMPILE_COMMANDS ON) 6 | set(CMAKE_BUILD_TYPE Debug) 7 | set(TARGET_NAME "lenet_int8") 8 | set(CMAKE_CXX_STANDARD 17) 9 | set(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/bin) 10 | 11 | message("TENSORRT_INCLUDE:" ${TENSORRT_INCLUDE}) 12 | 13 | # specify the header file path 14 | include_directories(${CMAKE_SOURCE_DIR}/common 15 | /usr/local/cuda/include 16 | /opt/TensorRT-8.4.1.5/include) 17 | # specify the library file path 18 | link_directories(/usr/local/cuda/lib64 19 | /opt/TensorRT-8.4.1.5/lib) 20 | 21 | file(GLOB MTF_SRC ${PROJECT_SOURCE_DIR}/common/*.h 22 | ${PROJECT_SOURCE_DIR}/common/*.cpp) 23 | add_library(trt_deploy_common ${MTF_SRC}) 24 | 25 | add_executable(${TARGET_NAME} lenet_int8.cpp) 26 | target_link_libraries(${TARGET_NAME} nvinfer cudart trt_deploy_common) 27 | -------------------------------------------------------------------------------- /ppq/samples/TensorRT/lenet_demo/common/logger.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #include "logger.h" 18 | #include "ErrorRecorder.h" 19 | #include "logging.h" 20 | 21 | SampleErrorRecorder gRecorder; 22 | namespace sample 23 | { 24 | Logger gLogger{Logger::Severity::kINFO}; 25 | LogStreamConsumer gLogVerbose{LOG_VERBOSE(gLogger)}; 26 | LogStreamConsumer gLogInfo{LOG_INFO(gLogger)}; 27 | LogStreamConsumer gLogWarning{LOG_WARN(gLogger)}; 28 | LogStreamConsumer gLogError{LOG_ERROR(gLogger)}; 29 | LogStreamConsumer gLogFatal{LOG_FATAL(gLogger)}; 30 | 31 | void setReportableSeverity(Logger::Severity severity) 32 | { 33 | gLogger.setReportableSeverity(severity); 34 | gLogVerbose.setReportableSeverity(severity); 35 | gLogInfo.setReportableSeverity(severity); 36 | gLogWarning.setReportableSeverity(severity); 37 | gLogError.setReportableSeverity(severity); 38 | gLogFatal.setReportableSeverity(severity); 39 | } 40 | } // namespace sample 41 | -------------------------------------------------------------------------------- /ppq/samples/TensorRT/lenet_demo/common/logger.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef LOGGER_H 18 | #define LOGGER_H 19 | 20 | #include "logging.h" 21 | 22 | class SampleErrorRecorder; 23 | extern SampleErrorRecorder gRecorder; 24 | namespace sample 25 | { 26 | extern Logger gLogger; 27 | extern LogStreamConsumer gLogVerbose; 28 | extern LogStreamConsumer gLogInfo; 29 | extern LogStreamConsumer gLogWarning; 30 | extern LogStreamConsumer gLogError; 31 | extern LogStreamConsumer gLogFatal; 32 | 33 | void setReportableSeverity(Logger::Severity severity); 34 | } // namespace sample 35 | 36 | #endif // LOGGER_H 37 | -------------------------------------------------------------------------------- /ppq/samples/TensorRT/lenet_demo/common/macros.h: -------------------------------------------------------------------------------- 1 | #ifndef __MACROS_H 2 | #define __MACROS_H 3 | 4 | #if NV_TENSORRT_MAJOR >= 8 5 | #define TRT_NOEXCEPT noexcept 6 | #define TRT_CONST_ENQUEUE const 7 | #else 8 | #define TRT_NOEXCEPT 9 | #define TRT_CONST_ENQUEUE 10 | #endif 11 | 12 | #endif // __MACROS_H 13 | -------------------------------------------------------------------------------- /ppq/samples/TensorRT/lenet_demo/generate_onnx.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class Lenet5(nn.Module): 6 | """ 7 | for cifar10 dataset. 
8 | """ 9 | def __init__(self): 10 | super(Lenet5, self).__init__() 11 | 12 | self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0) 13 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) 14 | self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0) 15 | self.fc1 = nn.Linear(16*5*5, 120) 16 | self.fc2 = nn.Linear(120, 84) 17 | self.fc3 = nn.Linear(84, 10) 18 | 19 | def forward(self, x): 20 | print('input: ', x.shape) 21 | x = F.relu(self.conv1(x)) 22 | print('conv1',x.shape) 23 | x = self.pool1(x) 24 | print('pool1: ', x.shape) 25 | x = F.relu(self.conv2(x)) 26 | print('conv2',x.shape) 27 | x = self.pool1(x) 28 | print('pool2',x.shape) 29 | x = x.view(x.size(0), -1) 30 | print('view: ', x.shape) 31 | x = F.relu(self.fc1(x)) 32 | print('fc1: ', x.shape) 33 | x = F.relu(self.fc2(x)) 34 | x = F.softmax(self.fc3(x), dim=1) 35 | return x 36 | 37 | def export_onnx(onnx_filename): 38 | points = torch.full((1, 1, 32, 32), 1.5, dtype=torch.float32).cuda() 39 | inputs = (points, ) 40 | model = Lenet5().cuda() 41 | torch.onnx.export(model, inputs, onnx_filename, opset_version=11) 42 | print("The output of raw network: ", torch.mean(model(*inputs)).detach().cpu().numpy()) 43 | 44 | def print_onnx_model(onnx_filename): 45 | # Load the ONNX model 46 | import onnx 47 | model = onnx.load(onnx_filename) 48 | # Check that the IR is well formed 49 | onnx.checker.check_model(model) 50 | # Print a human readable representation of the graph 51 | print(onnx.helper.printable_graph(model.graph)) 52 | 53 | 54 | def main(): 55 | # get onnx 56 | onnx_filename = "lenet.onnx" 57 | export_onnx(onnx_filename) 58 | print_onnx_model(onnx_filename) 59 | 60 | if __name__ == '__main__': 61 | main() 62 | -------------------------------------------------------------------------------- /ppq/samples/Tutorial/dequantize.py: -------------------------------------------------------------------------------- 1 | # This Example shows you how to analyse your quantized network. 2 | # Check quantization error for each layer. 3 | 4 | from typing import Iterable 5 | 6 | import torch 7 | import torchvision 8 | from ppq import QuantableOperation, QuantizationSettingFactory, TargetPlatform 9 | from ppq.api import quantize_torch_model 10 | from ppq.core.quant import QuantizationStates 11 | from torch.utils.data import DataLoader 12 | 13 | BATCHSIZE = 32 14 | INPUT_SHAPE = [3, 224, 224] 15 | DEVICE = 'cuda' # only cuda is fully tested :( For other executing device there might be bugs. 16 | PLATFORM = TargetPlatform.PPL_CUDA_INT8 # identify a target platform for your network. 17 | 18 | def load_calibration_dataset() -> Iterable: 19 | return [torch.rand(size=INPUT_SHAPE) for _ in range(32)] 20 | 21 | def collate_fn(batch: torch.Tensor) -> torch.Tensor: 22 | return batch.to(DEVICE) 23 | 24 | # Load a pretrained mobilenet v2 model 25 | model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True) 26 | model = model.to(DEVICE) 27 | 28 | # create a setting for quantizing your network with PPL CUDA. 29 | quant_setting = QuantizationSettingFactory.pplcuda_setting() 30 | quant_setting.equalization = True # use layerwise equalization algorithm. 31 | quant_setting.dispatcher = 'conservative' # dispatch this network in conservertive way. 32 | 33 | # Load training data for creating a calibration dataloader. 34 | calibration_dataset = load_calibration_dataset() 35 | calibration_dataloader = DataLoader( 36 | dataset=calibration_dataset, 37 | batch_size=BATCHSIZE, shuffle=True) 38 | 39 | # quantize your model. 
40 | quantized = quantize_torch_model( 41 | model=model, calib_dataloader=calibration_dataloader, 42 | calib_steps=32, input_shape=[BATCHSIZE] + INPUT_SHAPE, 43 | setting=quant_setting, collate_fn=collate_fn, platform=PLATFORM, 44 | onnx_export_file='Output/onnx.model', device=DEVICE, verbose=0) 45 | 46 | # dequantize with operation.dequantize() function: 47 | for operation in quantized.operations.values(): 48 | 49 | if isinstance(operation, QuantableOperation): 50 | # all parameters of this operation will be dequantized, baked value will be replaced. 51 | # input and output of this operation will not be quantized since now. 52 | operation.dequantize() 53 | 54 | # restore quantization state: 55 | for operation in quantized.operations.values(): 56 | 57 | if isinstance(operation, QuantableOperation): 58 | # all parameters of this operation will restore its quantization result. 59 | # input and output of this operation will be quantized since now. 60 | operation.restore_quantize_state() 61 | 62 | # manually dequantize an operation: 63 | for operation in quantized.operations.values(): 64 | if isinstance(operation, QuantableOperation): 65 | for cfg, var in operation.config_with_variable: 66 | 67 | if var.is_parameter and cfg.state == QuantizationStates.BAKED: 68 | print(f'Variable {var.name} is pre-baked, simply overriding its state takes no effects.') 69 | else: 70 | # once state is changed to FP32 71 | # executor will skip this quantization during executing. 72 | cfg.state = QuantizationStates.FP32 73 | -------------------------------------------------------------------------------- /ppq/samples/Tutorial/execute.py: -------------------------------------------------------------------------------- 1 | # This Example shows you how to execute a quantized network and get its result. 2 | from typing import Iterable 3 | 4 | import torch 5 | import torchvision 6 | from ppq import (BaseGraph, QuantableOperation, QuantizationSettingFactory, 7 | TargetPlatform, TorchExecutor) 8 | from ppq.api import quantize_torch_model 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | BATCHSIZE = 32 13 | INPUT_SHAPE = [3, 224, 224] 14 | DEVICE = 'cuda' # only cuda is fully tested :( For other executing device there might be bugs. 15 | PLATFORM = TargetPlatform.PPL_CUDA_INT8 # identify a target platform for your network. 16 | 17 | def load_calibration_dataset() -> Iterable: 18 | return [torch.rand(size=INPUT_SHAPE) for _ in range(32)] 19 | 20 | def collate_fn(batch: torch.Tensor) -> torch.Tensor: 21 | return batch.to(DEVICE) 22 | 23 | # Load a pretrained mobilenet v2 model 24 | model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True) 25 | model = model.to(DEVICE) 26 | 27 | # create a setting for quantizing your network with PPL CUDA. 28 | quant_setting = QuantizationSettingFactory.pplcuda_setting() 29 | quant_setting.equalization = True # use layerwise equalization algorithm. 30 | quant_setting.dispatcher = 'conservative' # dispatch this network in conservertive way. 31 | 32 | # Load training data for creating a calibration dataloader. 33 | calibration_dataset = load_calibration_dataset() 34 | calibration_dataloader = DataLoader( 35 | dataset=calibration_dataset, 36 | batch_size=BATCHSIZE, shuffle=True) 37 | 38 | # quantize your model. 
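# (roughly speaking) quantize_torch_model first dumps the module to ONNX (written to
# onnx_export_file), parses it into a PPQ graph, dispatches and calibrates it with the
# dataloader above for calib_steps batches, and returns the quantized BaseGraph.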
39 | quantized = quantize_torch_model( 40 | model=model, calib_dataloader=calibration_dataloader, 41 | calib_steps=32, input_shape=[BATCHSIZE] + INPUT_SHAPE, 42 | setting=quant_setting, collate_fn=collate_fn, platform=PLATFORM, 43 | onnx_export_file='Output/onnx.model', device=DEVICE, verbose=0) 44 | 45 | # Quantization Result is a PPQ BaseGraph instance. 46 | assert isinstance(quantized, BaseGraph) 47 | 48 | # build an executor: 49 | executor = TorchExecutor(graph=quantized, device=DEVICE) 50 | 51 | # run with your network, results are torch.Tensors 52 | for data in tqdm(calibration_dataloader, desc='Running with executor.'): 53 | results = executor.forward(inputs=data.to(DEVICE)) 54 | 55 | # extract result for specific variables: 56 | interested_vars = [] 57 | for operation in quantized.operations.values(): 58 | if isinstance(operation, QuantableOperation) and operation.type == 'Conv': 59 | interested_vars.append(operation.outputs[0].name) 60 | 61 | # results contains all convolution layers' output. 62 | results = executor.forward(inputs=data.to(DEVICE), output_names=interested_vars) 63 | print(f'There are {len(results)} convolution results.') 64 | -------------------------------------------------------------------------------- /ppq/samples/Yolo/00_FloatModel.py: -------------------------------------------------------------------------------- 1 | ONNX_PATH = 'models/yolov5s.v5.onnx' # 你的模型位置 2 | ENGINE_PATH = 'Output/yolov5s.v5(fp32).engine' # 生成的 Engine 位置 3 | 4 | # ------------------------------------------------------------------- 5 | # 打开 trt_infer 看到具体细节,这个文件是 nvidia 的官方实例 6 | # ------------------------------------------------------------------- 7 | from trt_infer import EngineBuilder 8 | builder = EngineBuilder() 9 | builder.create_network(ONNX_PATH) 10 | builder.create_engine(engine_path=ENGINE_PATH, precision="fp32") 11 | -------------------------------------------------------------------------------- /ppq/samples/Yolo/02_Quantization.py: -------------------------------------------------------------------------------- 1 | # 这里我们展示两种不同的方法去生成 TensorRT Engine 2 | 3 | # Plan B: PPQ 导出 engine 4 | 5 | import os 6 | import torchvision.transforms as transforms 7 | from PIL import Image 8 | from ppq import * 9 | from ppq.api import * 10 | 11 | ONNX_PATH = 'models/yolov5s6.onnx' # 你的模型位置 12 | OUTPUT_PATH = 'Output' # 生成的量化模型的位置 13 | CALIBRATION_PATH = 'imgs' # 校准数据集 14 | BATCHSIZE = 1 15 | EXECUTING_DEVICE = 'cuda' 16 | # create dataloader 17 | imgs = [] 18 | trans = transforms.Compose([ 19 | transforms.Resize([640, 640]), # [h,w] 20 | transforms.ToTensor(), 21 | ]) 22 | for file in os.listdir(path=CALIBRATION_PATH): 23 | path = os.path.join(CALIBRATION_PATH, file) 24 | img = Image.open(path).convert('RGB') 25 | img = trans(img) 26 | imgs.append(img) # img is 0 - 1 27 | 28 | from torch.utils.data import DataLoader 29 | dataloader = DataLoader(dataset=imgs, batch_size=BATCHSIZE) 30 | 31 | with ENABLE_CUDA_KERNEL(): 32 | qir = quantize_onnx_model( 33 | platform=TargetPlatform.TRT_INT8, 34 | onnx_import_file=ONNX_PATH, 35 | calib_dataloader=dataloader, 36 | calib_steps=32, device=EXECUTING_DEVICE, 37 | input_shape=[BATCHSIZE, 3, 640, 640], 38 | collate_fn=lambda x: x.to(EXECUTING_DEVICE)) 39 | 40 | snr_report = graphwise_error_analyse( 41 | graph=qir, running_device=EXECUTING_DEVICE, 42 | dataloader=dataloader, collate_fn=lambda x: x.to(EXECUTING_DEVICE)) 43 | 44 | snr_report = layerwise_error_analyse( 45 | graph=qir, running_device=EXECUTING_DEVICE, 46 | dataloader=dataloader, 
collate_fn=lambda x: x.to(EXECUTING_DEVICE)) 47 | 48 | export_ppq_graph( 49 | qir, platform=TargetPlatform.TRT_INT8, 50 | graph_save_to=OUTPUT_PATH.join('/INT8.onnx'), 51 | config_save_to=OUTPUT_PATH.join('/INT8.json')) 52 | 53 | from ppq.utils.TensorRTUtil import build_engine, Benchmark, Profiling 54 | build_engine( 55 | onnx_file=OUTPUT_PATH.join('/INT8.onnx'), 56 | engine_file=OUTPUT_PATH.join('/INT8.engine'), int8=True, 57 | int8_scale_file=OUTPUT_PATH.join('/INT8.json')) 58 | 59 | Benchmark(OUTPUT_PATH.join('/INT8.engine')) 60 | Profiling(OUTPUT_PATH.join('/INT8.engine')) -------------------------------------------------------------------------------- /ppq/samples/Yolo/yolo_5.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | 5 | import ppq.lib as PFL 6 | from ppq import TargetPlatform, TorchExecutor, graphwise_error_analyse 7 | from ppq.api import ENABLE_CUDA_KERNEL, export_ppq_graph, load_onnx_graph 8 | from ppq.quantization.optim import * 9 | 10 | calibration_dataloader = [torch.rand([1, 3, 640, 640]) for _ in range(32)] 11 | 12 | with ENABLE_CUDA_KERNEL(): 13 | graph = load_onnx_graph('Models/yolov5s.v5.onnx') 14 | 15 | quantizer = PFL.Quantizer(platform=TargetPlatform.OPENVINO_INT8, graph=graph) # 取得 OPENVINO_INT8 所对应的量化器 16 | dispatching = PFL.Dispatcher(graph=graph).dispatch( # 生成调度表 17 | quant_types=quantizer.quant_operation_types) 18 | 19 | # ------------------------------------------------------------ 20 | # Yolo5 前面可能有一坨 Slice, Concat 算子 21 | # 后面可能有一坨后处理算子,我们不希望量化它们,你可用下面的方法将它们解除量化 22 | # 不同模型中层的名字可能不一样,你需要按照你的模型对它们进行手动修改 23 | # ------------------------------------------------------------ 24 | 25 | # Concat_40 往前的所有算子不量化 26 | from ppq.IR import SearchableGraph 27 | search_engine = SearchableGraph(graph) 28 | for op in search_engine.opset_matching( 29 | sp_expr=lambda x: x.name == 'Concat_40', 30 | rp_expr=lambda x, y: True, 31 | ep_expr=None, direction='up' 32 | ): 33 | dispatching[op.name] = TargetPlatform.FP32 34 | 35 | # Sigmoid_280 往后的所有算子不量化 36 | # Sigmoid_246 往后的所有算子不量化 37 | # Sigmoid_314 往后的所有算子不量化 38 | for op in search_engine.opset_matching( 39 | sp_expr=lambda x: x.name in {'Sigmoid_246', 'Sigmoid_280', 'Sigmoid_314'}, 40 | rp_expr=lambda x, y: True, 41 | ep_expr=None, direction='down' 42 | ): 43 | dispatching[op.name] = TargetPlatform.FP32 44 | 45 | # 为算子初始化量化信息 46 | for op in graph.operations.values(): 47 | quantizer.quantize_operation( 48 | op_name = op.name, platform = dispatching[op.name]) 49 | 50 | # 初始化执行器 51 | collate_fn = lambda x: x.to('cuda') 52 | executor = TorchExecutor(graph=graph, device='cuda') 53 | executor.tracing_operation_meta(inputs=torch.zeros(size=[1, 3, 640, 640]).cuda()) 54 | executor.load_graph(graph=graph) 55 | 56 | # ------------------------------------------------------------ 57 | # 创建优化管线,由于后续还要继续训练我们的模型,我们不能在此处调用 58 | # ParameterBakingPass(),一旦模型权重完成烘焙,则它们不能被进一步调整 59 | # ------------------------------------------------------------ 60 | pipeline = PFL.Pipeline([ 61 | QuantizeSimplifyPass(), 62 | QuantizeFusionPass( 63 | activation_type=quantizer.activation_fusion_types), 64 | ParameterQuantizePass(), 65 | RuntimeCalibrationPass(), 66 | PassiveParameterQuantizePass(), 67 | QuantAlignmentPass(force_overlap=True), 68 | ParameterBakingPass() 69 | ]) 70 | 71 | with ENABLE_CUDA_KERNEL(): 72 | # 调用管线完成量化 73 | pipeline.optimize( 74 | graph=graph, dataloader=calibration_dataloader, verbose=True, 75 | calib_steps=32, collate_fn=collate_fn, executor=executor) 76 | 
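    # (optional, a sketch) if the graphwise analysis below reports a large cumulative error,
    # a per-layer breakdown helps localize its source; layerwise_error_analyse is called the
    # same way in the other samples of this repository.
    from ppq import layerwise_error_analyse
    layerwise_error_analyse(
        graph=graph, running_device='cuda',
        dataloader=calibration_dataloader, collate_fn=collate_fn)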
77 | graphwise_error_analyse( 78 | graph=graph, running_device='cuda', dataloader=calibration_dataloader, 79 | collate_fn=lambda x: x.cuda()) 80 | 81 | export_ppq_graph( 82 | graph=graph, platform=TargetPlatform.TRT_INT8, 83 | graph_save_to='Output/quantized.onnx', 84 | config_save_to='Output/quantized.json') 85 | 86 | from ppq.utils.TensorRTUtil import Benchmark, Profiling, build_engine 87 | 88 | build_engine( 89 | onnx_file='Output/quantized.onnx', 90 | engine_file='Output/quantized.engine', 91 | int8=True, int8_scale_file='Output/quantized.json') 92 | Benchmark('Output/quantized.engine') 93 | Profiling('Output/quantized.engine') -------------------------------------------------------------------------------- /ppq/samples/Yolo/yolo_x.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import openvino 3 | import torch 4 | from tqdm import tqdm 5 | import time 6 | 7 | from ppq import * 8 | from ppq.api import * 9 | 10 | QUANT_PLATFROM = TargetPlatform.OPENVINO_INT8 11 | BATCHSIZE = 1 12 | DEVICE = 'cuda' 13 | INPUTSHAPE = [BATCHSIZE, 3, 640, 640] 14 | SAMPLES = [torch.rand(size=INPUTSHAPE) for _ in range(256)] 15 | BENCHMARK_SAMPLES = 512 16 | MODEL_PATH = 'Models/yolox_s.onnx' 17 | VALIDATION = False 18 | 19 | with ENABLE_CUDA_KERNEL(): 20 | quantized = quantize_onnx_model( 21 | onnx_import_file=MODEL_PATH, calib_dataloader=SAMPLES, collate_fn=lambda x: x.to(DEVICE), 22 | calib_steps=32, input_shape=INPUTSHAPE, 23 | setting=QuantizationSettingFactory.default_setting(), 24 | platform=QUANT_PLATFROM) 25 | 26 | graphwise_error_analyse(graph=quantized, running_device='cuda', 27 | dataloader=SAMPLES, collate_fn=lambda x: x.cuda(), steps=32) 28 | 29 | export_ppq_graph( 30 | graph=quantized, platform=TargetPlatform.ONNX, 31 | graph_save_to='FP32.onnx') 32 | 33 | export_ppq_graph( 34 | graph=quantized, platform=TargetPlatform.OPENVINO_INT8, 35 | graph_save_to='INT8.onnx') 36 | 37 | from ppq.utils.OpenvinoUtil import Benchmark 38 | Benchmark(ir_or_onnx_file='FP32.onnx', samples=500, jobs=4) 39 | Benchmark(ir_or_onnx_file='INT8.onnx', samples=500, jobs=4) -------------------------------------------------------------------------------- /ppq/samples/bypass_nms.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------- 2 | # 这个脚本向你展示了如何使用绕过那些跟量化没什么关系的算子 3 | # 当你的算子处于网络的最后,其后也没有什么需要量化的算子了 4 | # 你就可以给它定义一个假的 forward 函数,从而帮助 PPQ 完成量化 5 | # PPQ 不再需要收集其后的数据信息,所以错误额计算过程也能得到正确的量化结果 6 | # 7 | # 当然如果你的自定义算子如果会干涉到量化过程,你还是需要向 PPQ 提供一个的执行函数 8 | # --------------------------------------------------------------- 9 | 10 | # For this inference test, all test data is randomly picked. 11 | # If you want to use real data, just rewrite the defination of SAMPLES 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | from ppq import * 16 | from ppq.api import * 17 | 18 | BATCHSIZE = 32 19 | INPUT_SHAPE = [3, 224, 224] 20 | DEVICE = 'cuda' # only cuda is fully tested :( For other executing device there might be bugs. 21 | PLATFORM = TargetPlatform.TRT_INT8 # identify a target platform for your network. 
22 | 23 | def load_calibration_dataset() -> Iterable: 24 | return [torch.rand(size=INPUT_SHAPE) for _ in range(32)] 25 | 26 | def collate_fn(batch: torch.Tensor) -> torch.Tensor: 27 | return batch.to(DEVICE) 28 | 29 | def happy_forward_func(op: Operation, values, ctx, **kwargs): 30 | """你必须保证函数签名满足要求,即函数的输入和返回值都满足 PPQ 的系统要求 31 | 32 | 你的执行函数将接收 op, values, ctx 三个元素作为输入 33 | 其中 op 反应了当前执行的算子信息,values是一个数组,包含了算子所有输入 34 | ctx 是 PPQ 执行上下文 35 | 36 | 你将返回一个 torch.Tensor 或者多个 torch.Tensor 作为结果 37 | 这取决于你的算子在onnx中有多少个输出 38 | """ 39 | return torch.zeros(size=[1, 100, 5]), torch.zeros(size=[1, 100]) 40 | 41 | # --------------------------------------------- 42 | # 注册一个假的函数让我们绕过 nms 43 | # --------------------------------------------- 44 | register_operation_handler( 45 | happy_forward_func, 46 | operation_type="TRTBatchedNMS", 47 | platform=TargetPlatform.FP32) 48 | 49 | quant_setting = QuantizationSettingFactory.default_setting() 50 | 51 | # For pytorch user, just dump your network to disk with onnx first 52 | unquantized = load_onnx_graph(onnx_import_file='Output/onnx.model') 53 | 54 | # Load training data for creating a calibration dataloader. 55 | calibration_dataset = load_calibration_dataset() 56 | calibration_dataloader = DataLoader( 57 | dataset=calibration_dataset, 58 | batch_size=BATCHSIZE, shuffle=False) 59 | 60 | # quantize your model. 61 | quantized = quantize_native_model( 62 | model=unquantized, calib_dataloader=calibration_dataloader, 63 | calib_steps=32, input_shape=[BATCHSIZE] + INPUT_SHAPE, 64 | setting=quant_setting, collate_fn=collate_fn, platform=PLATFORM, 65 | device=DEVICE, verbose=0) 66 | 67 | # Quantization Result is a PPQ BaseGraph instance. 68 | assert isinstance(quantized, BaseGraph) 69 | 70 | # export quantized graph. 71 | export_ppq_graph(graph=quantized, platform=PLATFORM, 72 | graph_save_to='Output/quantized(onnx).onnx', 73 | config_save_to='Output/quantized(onnx).json') 74 | -------------------------------------------------------------------------------- /ppq/samples/custimize_quant_func.py: -------------------------------------------------------------------------------- 1 | # This Example shows you how to replace quantization function with your own quant delegator. 2 | 3 | from typing import Iterable 4 | 5 | import torch 6 | import torchvision 7 | from ppq import * 8 | from ppq.api import quantize_torch_model 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | class MyQuantDelegator(TorchQuantizeDelegator): 13 | """Use This class to realize your quantization logic. 14 | 15 | Inherit class TorchQuantizeDelegate, implement interface __call__, then 16 | register your delegator with executor.register_quantize_delegate 17 | """ 18 | def __call__(self, tensor: torch.Tensor, config: TensorQuantizationConfig) -> torch.Tensor: 19 | if config.policy.has_property(QuantizationProperty.ASYMMETRICAL): 20 | raise ValueError('Sorry, this delegator handles only Symmetrical Quantizations.') 21 | print('You are invoking cusitmized quant function now.') 22 | return torch.round(tensor / config.scale) * config.scale 23 | 24 | BATCHSIZE = 32 25 | INPUT_SHAPE = [3, 224, 224] 26 | DEVICE = 'cuda' # only cuda is fully tested :( For other executing device there might be bugs. 27 | PLATFORM = TargetPlatform.PPL_CUDA_INT8 # identify a target platform for your network. 
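# (a sketch, not part of ppq) the delegator above only rounds and rescales; a fuller symmetric
# linear fake-quant would also clamp to the integer range carried by the config. quant_min and
# quant_max are assumed to be the usual TensorQuantizationConfig fields.
def symmetric_fake_quant(tensor: torch.Tensor, config: TensorQuantizationConfig) -> torch.Tensor:
    quantized = torch.clamp(torch.round(tensor / config.scale), config.quant_min, config.quant_max)
    return quantized * config.scale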
28 | 29 | def load_calibration_dataset() -> Iterable: 30 | return [torch.rand(size=INPUT_SHAPE) for _ in range(32)] 31 | 32 | def collate_fn(batch: torch.Tensor) -> torch.Tensor: 33 | return batch.to(DEVICE) 34 | 35 | # Load a pretrained mobilenet v2 model 36 | model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True) 37 | model = model.to(DEVICE) 38 | 39 | # create a setting for quantizing your network with PPL CUDA. 40 | quant_setting = QuantizationSettingFactory.pplcuda_setting() 41 | quant_setting.equalization = True # use layerwise equalization algorithm. 42 | quant_setting.dispatcher = 'conservative' # dispatch this network in conservertive way. 43 | 44 | # Load training data for creating a calibration dataloader. 45 | calibration_dataset = load_calibration_dataset() 46 | calibration_dataloader = DataLoader( 47 | dataset=calibration_dataset, 48 | batch_size=BATCHSIZE, shuffle=True) 49 | 50 | # quantize your model. 51 | quantized = quantize_torch_model( 52 | model=model, calib_dataloader=calibration_dataloader, 53 | calib_steps=32, input_shape=[BATCHSIZE] + INPUT_SHAPE, 54 | setting=quant_setting, collate_fn=collate_fn, platform=PLATFORM, 55 | onnx_export_file='Output/onnx.model', device=DEVICE, verbose=0) 56 | 57 | # Quantization Result is a PPQ BaseGraph instance. 58 | assert isinstance(quantized, BaseGraph) 59 | 60 | # build an executor: 61 | executor = TorchExecutor(graph=quantized, device=DEVICE) 62 | 63 | # register quant delegator 64 | for var in quantized.variables.values(): 65 | if isinstance(var, QuantableVariable): 66 | if var.source_op_config is not None: 67 | executor.register_quantize_delegate(var.source_op_config, MyQuantDelegator) 68 | 69 | # run with your network, results are torch.Tensors 70 | for data in tqdm(calibration_dataloader, desc='Running with executor.'): 71 | results = executor.forward(inputs=data.to(DEVICE)) 72 | break 73 | 74 | # remove delegators 75 | for var in quantized.variables.values(): 76 | if isinstance(var, QuantableVariable): 77 | if var.source_op_config is not None: 78 | executor.remove_quantize_delegate(var.source_op_config) 79 | -------------------------------------------------------------------------------- /ppq/samples/custimized_quant.py: -------------------------------------------------------------------------------- 1 | """这个脚本将教会你如何使用 PPQ 量化自定义算子""" 2 | 3 | import torch 4 | from ppq import * 5 | from ppq.api import * 6 | from ppq.quantization.quantizer import TensorRTQuantizer 7 | 8 | B = 1 9 | T = 64 10 | MODEL_PATH = 'models\encoder_ln.onnx' 11 | 12 | def generate_samples(num_of_samples: int = 32): 13 | """生成样本数据,把这个函数改成真实数据读入就可以完成量化了 14 | 这个语音数据量很小 我建议你把整个数据集直接全部送上CUDA 15 | """ 16 | sample = { 17 | 'speech': torch.rand(size=[B, T, 80]).float().cuda(), 18 | 'speech_lengths': torch.ones(size=[B]).int().cuda()} 19 | samples = [sample for _ in range(num_of_samples)] 20 | return samples 21 | SAMPLES = generate_samples() 22 | 23 | # 定义一个自己的量化器,定制量化行为,继承于 TensorRTQuantizer 量化器 24 | class MyTensorRTQuantizer(TensorRTQuantizer): 25 | @ property 26 | def quant_operation_types(self) -> set: 27 | """覆盖 quant_operation_types 自定义需要量化的算子""" 28 | return {'LayerNormPlugin'} 29 | 30 | def init_quantize_config( 31 | self, operation: Operation) -> OperationQuantizationConfig: 32 | config = super().init_quantize_config(operation=operation) 33 | """针对 LayerNormPlugin 生成量化配置信息""" 34 | if operation.type == 'LayerNormPlugin': 35 | wconfig = config.input_quantization_config[1] # weight config 36 | bconfig = config.input_quantization_config[2] 37 | 
38 | wconfig.policy = QuantizationPolicy( # weight 做 Per channel 量化 39 | QuantizationProperty.SYMMETRICAL + 40 | QuantizationProperty.LINEAR + 41 | QuantizationProperty.PER_TENSOR) 42 | bconfig.state = QuantizationStates.FP32 # bias 不量化 43 | ''' 44 | # 将 weight config 升级为 ChannelwiseTensorQuantizationConfig 45 | config.input_quantization_config[1] = ( 46 | ChannelwiseTensorQuantizationConfig.convert_from_tensor_config( 47 | convert_from = wconfig, 48 | offsets = None, scales = None, 49 | channel_axis = 0)) 50 | ''' 51 | config.input_quantization_config[1].observer_algorithm = 'Minmax' 52 | return config 53 | 54 | def layernorm_forward(op: Operation, values: List[torch.Tensor], ctx = None, **kwargs): 55 | """自定义算子的前向传播函数""" 56 | return values[0] 57 | 58 | register_operation_handler( 59 | handler=layernorm_forward, 60 | operation_type='LayerNormPlugin', 61 | platform=TargetPlatform.TRT_INT8) 62 | 63 | register_network_quantizer( 64 | quantizer=MyTensorRTQuantizer, 65 | platform=TargetPlatform.TRT_INT8) 66 | 67 | with ENABLE_CUDA_KERNEL(): 68 | QS = QuantizationSettingFactory.trt_setting() 69 | ir = load_onnx_graph(onnx_import_file=MODEL_PATH) 70 | # 默认调度失效,直接手动调度所有 LayerNormPlugin 送上量化区 71 | for op in ir.operations.values(): 72 | if op.type == 'LayerNormPlugin': 73 | QS.dispatching_table.append(operation=op.name, platform=TargetPlatform.TRT_INT8) 74 | 75 | qir = quantize_native_model( 76 | model=ir, calib_dataloader=SAMPLES, calib_steps=32, 77 | input_shape=None, inputs=SAMPLES[0], 78 | platform=TargetPlatform.TRT_INT8, setting=QS) 79 | 80 | graphwise_error_analyse( 81 | graph=qir, running_device='cuda', 82 | dataloader=SAMPLES) 83 | 84 | export_ppq_graph( 85 | graph=qir, platform=TargetPlatform.ONNXRUNTIME, 86 | graph_save_to='quantized.onnx') 87 | -------------------------------------------------------------------------------- /ppq/samples/dynamic_shape.py: -------------------------------------------------------------------------------- 1 | # This example shows how to make a dynamic-shape network 2 | # dynamic shape is only supported by onnx 3 | 4 | # first of all, load your model from anywhere 5 | from ppq import * 6 | from ppq.api import * 7 | 8 | YOU_WANT_TO_QUANTIZE_IT = True 9 | 10 | ir = load_onnx_graph('onnx model path') 11 | input_shape = [1, 3, 224, 224] 12 | samples = [torch.zeros(size=input_shape).cuda()] 13 | 14 | if YOU_WANT_TO_QUANTIZE_IT: 15 | ir = dispatch_graph(ir, platform=TargetPlatform.NCNN_INT8, 16 | setting=QuantizationSettingFactory.ncnn_setting()) 17 | 18 | ir = quantize_native_model( 19 | model=ir, calib_dataloader=samples, calib_steps=32, 20 | input_shape=input_shape, setting=QuantizationSettingFactory.ncnn_setting()) 21 | 22 | # You are supposed to set dynamic shape input/output variable just before export. 23 | # Get variable instance from ir by its name, set shape attribute as your wish. 24 | var = ir.variables['input variable name'] 25 | var.shape = ['Batch', 3, 'Width', 'Height'] 26 | # text, None, int are both acceptable here. 27 | 28 | export_ppq_graph(graph=ir, platform=TargetPlatform.NCNN_INT8, 29 | graph_save_to='onnx model save to', config_save_to='config save to') -------------------------------------------------------------------------------- /ppq/samples/enable_cuda_kernel.py: -------------------------------------------------------------------------------- 1 | # Since ppq 0.6.4, PPQ_CONFIG.USING_CUDA_KERNEL = False is the defualt execution option in ppq. 
2 | # However you should notice that if you are able to compile ppq kernel functions, the execution speed wiil boost at least 3x. 3 | # This example will show you how to enable kernel function within ppq. 4 | # if you want to use kernel function everywhere, just rewrite ppq.core.config.PPQ_CONFIG.USING_CUDA_KERNEL = True 5 | 6 | from typing import Iterable 7 | 8 | import torch 9 | import torchvision 10 | from torch.utils.data import DataLoader 11 | 12 | from ppq import BaseGraph, QuantizationSettingFactory, TargetPlatform 13 | from ppq.api import export_ppq_graph, quantize_torch_model, ENABLE_CUDA_KERNEL 14 | 15 | BATCHSIZE = 32 16 | INPUT_SHAPE = [3, 224, 224] 17 | DEVICE = 'cuda' # only cuda is fully tested :( For other executing device there might be bugs. 18 | PLATFORM = TargetPlatform.PPL_CUDA_INT8 # identify a target platform for your network. 19 | 20 | def load_calibration_dataset() -> Iterable: 21 | return [torch.rand(size=INPUT_SHAPE) for _ in range(32)] 22 | 23 | def collate_fn(batch: torch.Tensor) -> torch.Tensor: 24 | return batch.to(DEVICE) 25 | 26 | # Load a pretrained mobilenet v2 model 27 | model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True) 28 | model = model.to(DEVICE) 29 | 30 | # use this to wrap up your code 31 | # all functions inside ENABLE_CUDA_KERNEL will switch to ppq kernel functions. 32 | with ENABLE_CUDA_KERNEL(): 33 | 34 | # create a setting for quantizing your network with PPL CUDA. 35 | quant_setting = QuantizationSettingFactory.pplcuda_setting() 36 | quant_setting.equalization = True # use layerwise equalization algorithm. 37 | quant_setting.dispatcher = 'conservative' # dispatch this network in conservertive way. 38 | quant_setting.lsq_optimization = True # finetune your network. 39 | 40 | # Load training data for creating a calibration dataloader. 41 | calibration_dataset = load_calibration_dataset() 42 | calibration_dataloader = DataLoader( 43 | dataset=calibration_dataset, 44 | batch_size=BATCHSIZE, shuffle=True) 45 | 46 | # quantize your model. 47 | quantized = quantize_torch_model( 48 | model=model, calib_dataloader=calibration_dataloader, 49 | calib_steps=32, input_shape=[BATCHSIZE] + INPUT_SHAPE, 50 | setting=quant_setting, collate_fn=collate_fn, platform=PLATFORM, 51 | onnx_export_file='Output/onnx.model', device=DEVICE, verbose=0) 52 | 53 | # Quantization Result is a PPQ BaseGraph instance. 54 | assert isinstance(quantized, BaseGraph) 55 | 56 | # export quantized graph. 57 | export_ppq_graph(graph=quantized, platform=PLATFORM, 58 | graph_save_to='Output/quantized(onnx).onnx', 59 | config_save_to='Output/quantized(onnx).json') -------------------------------------------------------------------------------- /ppq/samples/onnx_converter.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | from onnx import version_converter, helper 3 | 4 | # Preprocessing: load the model to be converted. 
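# (a small sketch) to check which opset the model currently declares before converting,
# the loaded ModelProto exposes it through its opset_import field:
#     print(onnx.load(model_path).opset_import)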
5 | model_path = 'models/mbfill.onnx' 6 | original_model = onnx.load(model_path) 7 | 8 | # A full list of supported adapters can be found here: 9 | # https://github.com/onnx/onnx/blob/main/onnx/version_converter.py#L21 10 | # Apply the version conversion on the original model 11 | converted_model = version_converter.convert_version(original_model, 11) 12 | 13 | onnx.save(converted_model, 'models/mbfill_11.onnx') -------------------------------------------------------------------------------- /ppq/samples/quantize_caffe_model.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | import torch 4 | import torchvision 5 | from torch.utils.data import DataLoader 6 | 7 | from ppq import BaseGraph, QuantizationSettingFactory, TargetPlatform 8 | from ppq.api import export_ppq_graph, quantize_caffe_model 9 | 10 | BATCHSIZE = 32 11 | INPUT_SHAPE = [3, 96, 96] 12 | DEVICE = 'cuda' # only cuda is fully tested :( For other executing devices there might be bugs. 13 | PLATFORM = TargetPlatform.PPL_CUDA_INT8 # identify a target platform for your network. 14 | PROTO_PATH = 'Models/model.prototxt' # For successfully loading a caffe model, a .prototxt file is required. 15 | MODEL_PATH = 'Models/model.caffemodel' # For successfully loading a caffe model, a .caffemodel file is required. 16 | 17 | def load_calibration_dataset() -> Iterable: 18 | return [torch.rand(size=INPUT_SHAPE) for _ in range(32)] 19 | 20 | def collate_fn(batch: torch.Tensor) -> torch.Tensor: 21 | return batch.to(DEVICE) 22 | 23 | # Load a pretrained mobilenet v2 model 24 | model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True) 25 | model = model.to(DEVICE) 26 | 27 | # create a setting for quantizing your network with PPL CUDA. 28 | quant_setting = QuantizationSettingFactory.pplcuda_setting() 29 | quant_setting.equalization = True # use layerwise equalization algorithm. 30 | quant_setting.dispatcher = 'conservative' # dispatch this network in a conservative way. 31 | 32 | # Load training data for creating a calibration dataloader. 33 | calibration_dataset = load_calibration_dataset() 34 | calibration_dataloader = DataLoader( 35 | dataset=calibration_dataset, 36 | batch_size=BATCHSIZE, shuffle=True) 37 | 38 | # quantize your model. 39 | quantized = quantize_caffe_model( 40 | caffe_model_file=MODEL_PATH, caffe_proto_file=PROTO_PATH, 41 | calib_dataloader=calibration_dataloader, 42 | calib_steps=32, input_shape=[BATCHSIZE] + INPUT_SHAPE, 43 | setting=quant_setting, collate_fn=collate_fn, platform=PLATFORM, 44 | device=DEVICE, verbose=0) 45 | 46 | # Quantization Result is a PPQ BaseGraph instance. 47 | assert isinstance(quantized, BaseGraph) 48 | 49 | # export quantized graph. 50 | export_ppq_graph(graph=quantized, platform=TargetPlatform.CAFFE, 51 | graph_save_to='Output/quantized(caffe)', 52 | config_save_to='Output/quantized(caffe).json') 53 | -------------------------------------------------------------------------------- /ppq/samples/quantize_dsp.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | import torch 4 | import torchvision 5 | from ppq import BaseGraph, QuantizationSettingFactory, TargetPlatform 6 | from ppq.api import export_ppq_graph, quantize_torch_model 7 | from torch.utils.data import DataLoader 8 | 9 | BATCHSIZE = 32 10 | INPUT_SHAPE = [3, 224, 224] 11 | DEVICE = 'cuda' # only cuda is fully tested :( For other executing devices there might be bugs.
12 | PLATFORM = TargetPlatform.PPL_DSP_INT8 # identify a target platform for your network. 13 | 14 | def load_calibration_dataset() -> Iterable: 15 | return [torch.rand(size=INPUT_SHAPE) for _ in range(32)] 16 | 17 | def collate_fn(batch: torch.Tensor) -> torch.Tensor: 18 | return batch.to(DEVICE) 19 | 20 | # Load a pretrained mobilenet v2 model 21 | model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True) 22 | model = model.to(DEVICE) 23 | 24 | # create a setting for quantizing your network with PPL CUDA. 25 | quant_setting = QuantizationSettingFactory.pplcuda_setting() 26 | quant_setting.equalization = True # use layerwise equalization algorithm. 27 | quant_setting.dispatcher = 'conservative' # dispatch this network in a conservative way. 28 | 29 | # Load training data for creating a calibration dataloader. 30 | calibration_dataset = load_calibration_dataset() 31 | calibration_dataloader = DataLoader( 32 | dataset=calibration_dataset, 33 | batch_size=BATCHSIZE, shuffle=True) 34 | 35 | # quantize your model. 36 | quantized = quantize_torch_model( 37 | model=model, calib_dataloader=calibration_dataloader, 38 | calib_steps=32, input_shape=[BATCHSIZE] + INPUT_SHAPE, 39 | setting=quant_setting, collate_fn=collate_fn, platform=PLATFORM, 40 | onnx_export_file='Output/onnx.model', device=DEVICE, verbose=0) 41 | 42 | # Quantization Result is a PPQ BaseGraph instance. 43 | assert isinstance(quantized, BaseGraph) 44 | 45 | # export quantized graph. 46 | export_ppq_graph(graph=quantized, platform=PLATFORM, 47 | graph_save_to='Output/quantized(onnx).onnx', 48 | config_save_to='Output/quantized(onnx).json') 49 | -------------------------------------------------------------------------------- /ppq/samples/quantize_onnx_model.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | from ppq import BaseGraph, QuantizationSettingFactory, TargetPlatform 7 | from ppq.api import export_ppq_graph, quantize_onnx_model 8 | 9 | BATCHSIZE = 32 10 | INPUT_SHAPE = [3, 224, 224] 11 | DEVICE = 'cuda' # only cuda is fully tested :( For other executing devices there might be bugs. 12 | PLATFORM = TargetPlatform.TRT_INT8 # identify a target platform for your network. 13 | ONNX_PATH = 'Models/cls_model/mobilenet_v2.onnx' 14 | 15 | def load_calibration_dataset() -> Iterable: 16 | return [torch.rand(size=INPUT_SHAPE) for _ in range(32)] 17 | 18 | def collate_fn(batch: torch.Tensor) -> torch.Tensor: 19 | return batch.to(DEVICE) 20 | 21 | quant_setting = QuantizationSettingFactory.trt_setting() 22 | 23 | # Load training data for creating a calibration dataloader. 24 | calibration_dataset = load_calibration_dataset() 25 | calibration_dataloader = DataLoader( 26 | dataset=calibration_dataset, 27 | batch_size=BATCHSIZE, shuffle=True) 28 | 29 | # quantize your model. 30 | quantized = quantize_onnx_model( 31 | onnx_import_file=ONNX_PATH, calib_dataloader=calibration_dataloader, 32 | calib_steps=32, input_shape=[BATCHSIZE] + INPUT_SHAPE, 33 | setting=quant_setting, collate_fn=collate_fn, platform=PLATFORM, 34 | device=DEVICE, verbose=0) 35 | 36 | # Quantization Result is a PPQ BaseGraph instance. 37 | assert isinstance(quantized, BaseGraph) 38 | 39 | # export quantized graph.
40 | export_ppq_graph(graph=quantized, platform=PLATFORM, 41 | graph_save_to='Output/quantized(onnx).onnx', 42 | config_save_to='Output/quantized(onnx).json') 43 | -------------------------------------------------------------------------------- /ppq/samples/quantize_torch_model.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | 3 | import torch 4 | import torchvision 5 | from torch.utils.data import DataLoader 6 | 7 | from ppq import BaseGraph, QuantizationSettingFactory, TargetPlatform 8 | from ppq.api import export_ppq_graph, quantize_torch_model 9 | 10 | BATCHSIZE = 32 11 | INPUT_SHAPE = [3, 224, 224] 12 | DEVICE = 'cuda' # only cuda is fully tested :( For other executing devices there might be bugs. 13 | PLATFORM = TargetPlatform.PPL_CUDA_INT8 # identify a target platform for your network. 14 | 15 | def load_calibration_dataset() -> Iterable: 16 | return [torch.rand(size=INPUT_SHAPE) for _ in range(32)] 17 | 18 | def collate_fn(batch: torch.Tensor) -> torch.Tensor: 19 | return batch.to(DEVICE) 20 | 21 | # Load a pretrained mobilenet v2 model 22 | model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True) 23 | model = model.to(DEVICE) 24 | 25 | # create a setting for quantizing your network with PPL CUDA. 26 | quant_setting = QuantizationSettingFactory.pplcuda_setting() 27 | quant_setting.equalization = True # use layerwise equalization algorithm. 28 | quant_setting.dispatcher = 'conservative' # dispatch this network in a conservative way. 29 | 30 | # Load training data for creating a calibration dataloader. 31 | calibration_dataset = load_calibration_dataset() 32 | calibration_dataloader = DataLoader( 33 | dataset=calibration_dataset, 34 | batch_size=BATCHSIZE, shuffle=True) 35 | 36 | # quantize your model. 37 | quantized = quantize_torch_model( 38 | model=model, calib_dataloader=calibration_dataloader, 39 | calib_steps=32, input_shape=[BATCHSIZE] + INPUT_SHAPE, 40 | setting=quant_setting, collate_fn=collate_fn, platform=PLATFORM, 41 | onnx_export_file='Output/onnx.model', device=DEVICE, verbose=0) 42 | 43 | # Quantization Result is a PPQ BaseGraph instance. 44 | assert isinstance(quantized, BaseGraph) 45 | 46 | # export quantized graph. 47 | export_ppq_graph(graph=quantized, platform=PLATFORM, 48 | graph_save_to='Output/quantized(onnx).onnx', 49 | config_save_to='Output/quantized(onnx).json') 50 | -------------------------------------------------------------------------------- /ppq/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import (GraphDispatcher, reverse_tracing_pattern, 2 | value_tracing_pattern) 3 | from .dispatchers import AggresiveDispatcher, ConservativeDispatcher, PPLNNDispatcher, PointDispatcher 4 | from .allin import AllinDispatcher 5 | from .perseus import Perseus 6 | # Do not forget to register your dispatcher here.
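# Registration sketch (MyDispatcher and its module are hypothetical, not part of this package):
#   from .mydispatcher import MyDispatcher
#   DISPATCHER_TABLE['mydispatcher'] = MyDispatcher
# The string key is what a QuantizationSetting selects, e.g. quant_setting.dispatcher = 'mydispatcher'.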
7 | 8 | DISPATCHER_TABLE = { 9 | "conservative": ConservativeDispatcher, 10 | "pplnn": PPLNNDispatcher, 11 | "aggresive": AggresiveDispatcher, 12 | "pointwise": PointDispatcher, 13 | "allin": AllinDispatcher, 14 | 'perseus': Perseus 15 | } -------------------------------------------------------------------------------- /ppq/scheduler/allin.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Set 2 | 3 | from ppq.core import TargetPlatform 4 | from ppq.IR import BaseGraph 5 | from .base import GraphDispatcher 6 | 7 | 8 | class AllinDispatcher(GraphDispatcher): 9 | """Graph Dispatcher cuts a graph into parts, each part of the graph will 10 | be dispatched to a specific platform for further execution and quantization. 11 | ATTENTION: this dispatcher will dispatch all ops whose type is in quant_types to quant_platform. 12 | """ 13 | def __init__(self, graph: BaseGraph) -> None: 14 | super().__init__() 15 | self.graph = graph 16 | 17 | def dispatch( 18 | self, quant_types: Set[str], 19 | quant_platform: TargetPlatform = TargetPlatform.UNSPECIFIED, 20 | fp32_platform: TargetPlatform = TargetPlatform.FP32, 21 | SOI_platform: TargetPlatform = TargetPlatform.SOI, **kwargs 22 | ) -> Dict[str, TargetPlatform]: 23 | """ 24 | We assume all ops in the original model can be quantized. 25 | This is suitable for some npu platforms. 26 | Args: 27 | graph (BaseGraph): graph object which is going to be dispatched by this dispatcher. 28 | quant_types(Set[str]): all quantable types for given platforms. 29 | quant_platform (TargetPlatform): 30 | platform object where quantable parts will go to. 31 | fp32_platform (TargetPlatform): 32 | platform object where fp32 parts will go to. 33 | SOI_platform (TargetPlatform): 34 | platform object where SOI parts will go to. 35 | Returns: 36 | Dict[str, TargetPlatform]: dispatching table that maps each operation name to its target platform. 37 | """ 38 | graph = self.graph 39 | 40 | dispatching_table = {} 41 | for op in graph.operations.values(): 42 | if op.type in quant_types: 43 | dispatching_table[op.name] = TargetPlatform.UNSPECIFIED 44 | else: 45 | dispatching_table[op.name] = TargetPlatform.FP32 46 | 47 | return dispatching_table -------------------------------------------------------------------------------- /ppq/utils/OnnxruntimeUtil.py: -------------------------------------------------------------------------------- 1 | # https://onnxruntime.ai/docs/api/python/auto_examples/plot_profiling.html 2 | import onnxruntime as ort 3 | import numpy as np 4 | from tqdm import tqdm 5 | from time import time 6 | 7 | def Benchmark(onnx_file: str, steps: int = 10000, providers=['CUDAExecutionProvider'], provider_options=None) -> float: 8 | """ Benchmark with Onnxruntime 9 | 10 | * Quantized models of Onnxruntime - TensorrtExecutionProvider and Onnxruntime - CUDAExecutionProvider have different formats. 11 | 12 | * Onnx files generated by PPQ are not supported by TensorrtExecutionProvider. 13 | 14 | * Set providers=['CUDAExecutionProvider'] before running this benchmark.
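* Usage sketch (the file name is hypothetical): Benchmark(onnx_file='quantized.onnx', steps=1000, providers=['CUDAExecutionProvider'])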
15 | 16 | """ 17 | sess = ort.InferenceSession(path_or_bytes=onnx_file, providers=providers, provider_options=provider_options) 18 | 19 | feed_dict, output_names = {}, [] 20 | for input_meta in sess.get_inputs(): 21 | name, dtype, shape = input_meta.name, input_meta.type, input_meta.shape 22 | 23 | for element in shape: 24 | if element is None or type(element) == str: 25 | raise TypeError('Dynamic input is not supported by this function.') 26 | 27 | if dtype == 'tensor(float)': 28 | feed_dict[name] = np.random.random(size=shape).astype(np.float32) 29 | else: 30 | raise Exception(f'Input {name} has unexpected data type.') 31 | 32 | for output_meta in sess.get_outputs(): 33 | output_names.append(output_meta.name) 34 | 35 | tick = time() 36 | for _ in tqdm(range(steps)): 37 | sess.run(output_names=output_names, input_feed=feed_dict) 38 | tok = time() 39 | 40 | print(f'Time span: {tok - tick : .4f} sec') 41 | return tok - tick # return the elapsed time in seconds 42 | 43 | 44 | def Profile(onnx_file: str, steps: int = 1, providers=['CUDAExecutionProvider'], provider_options=None): 45 | """ Profile with Onnxruntime 46 | 47 | * Quantized models of Onnxruntime - TensorrtExecutionProvider and Onnxruntime - CUDAExecutionProvider have different formats. 48 | 49 | * Onnx files generated by PPQ are not supported by TensorrtExecutionProvider. 50 | 51 | * Set providers=['CUDAExecutionProvider'] before running this profile. 52 | 53 | """ 54 | options = ort.SessionOptions() 55 | options.enable_profiling = True 56 | sess = ort.InferenceSession( 57 | path_or_bytes=onnx_file, providers=providers, 58 | provider_options=provider_options, sess_options=options) 59 | 60 | feed_dict, output_names = {}, [] 61 | for input_meta in sess.get_inputs(): 62 | name, dtype, shape = input_meta.name, input_meta.type, input_meta.shape 63 | 64 | for element in shape: 65 | if element is None or type(element) == str: 66 | raise TypeError('Dynamic input is not supported by this function.') 67 | 68 | if dtype == 'tensor(float)': 69 | feed_dict[name] = np.random.random(size=shape).astype(np.float32) 70 | else: 71 | raise Exception(f'Input {name} has unexpected data type.') 72 | 73 | for output_meta in sess.get_outputs(): 74 | output_names.append(output_meta.name) 75 | 76 | for _ in tqdm(range(steps)): 77 | sess.run(output_names=output_names, input_feed=feed_dict) 78 | 79 | prof_file = sess.end_profiling() 80 | print(f'Profile file is generated at {prof_file}, open it with your web browser chrome://tracing/') -------------------------------------------------------------------------------- /ppq/utils/OpenvinoUtil.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | import numpy as np 4 | import openvino.runtime as ov 5 | from tqdm import tqdm 6 | 7 | if ov.get_version() < '2022.1.0': 8 | raise Exception('Please Install Openvino >= 2022.1.0') 9 | 10 | def Benchmark(ir_or_onnx_file: str, samples: int = 500, jobs: int = 4) -> float: 11 | """ Run a performance benchmark with a given onnx model. (Or Openvino IR) 12 | 13 | By default this function will run with Async Mode.
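Usage sketch (the file name is hypothetical): Benchmark(ir_or_onnx_file='quantized.onnx', samples=500, jobs=4)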
14 | """ 15 | # https://docs.openvino.ai/latest/api/ie_python_api/_autosummary/openvino.runtime.InferRequest.html 16 | core = ov.Core() 17 | # core.add_extension("path_to_extension_library.so") 18 | model = core.read_model(ir_or_onnx_file) 19 | compiled_model = core.compile_model(model, 'CPU') 20 | 21 | infer_request = compiled_model.create_infer_request() 22 | print(f'Openvino Model Loaded: {len(infer_request.input_tensors)} Input Tensors, {len(infer_request.output_tensors)} Output Tensors') 23 | 24 | feed_dict = [] 25 | for tensor in infer_request.input_tensors: 26 | feed_dict.append(np.random.random(size=tensor.shape).astype(tensor.element_type.to_dtype())) 27 | 28 | # Start async inference on a single infer request 29 | infer_request.start_async() 30 | # Wait for 1 millisecond 31 | infer_request.wait_for(1) 32 | # Wait for inference completion 33 | infer_request.wait() 34 | infer_queue = ov.AsyncInferQueue(compiled_model, jobs=jobs) 35 | 36 | tick = time() 37 | for _ in tqdm(range(samples)): 38 | # Wait for at least one available infer request and start asynchronous inference 39 | infer_queue.start_async(feed_dict) 40 | # Wait for all requests to complete 41 | infer_queue.wait_all() 42 | tok = time() 43 | 44 | print(f'Time span: {tok - tick : .4f} sec') 45 | return tok - tick # return the elapsed time in seconds 46 | 47 | """ 48 | infer_request.infer(feed_dict) 49 | for record in infer_request.get_profiling_info(): 50 | print(record.node_name, record.node_type, record.cpu_time.total_seconds(), record.real_time.total_seconds()) 51 | """ -------------------------------------------------------------------------------- /ppq/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .attribute import process_attribute 2 | -------------------------------------------------------------------------------- /ppq/utils/attribute.py: -------------------------------------------------------------------------------- 1 | from ppq.log import NaiveLogger 2 | from ppq.core.defs import ppq_legacy 3 | 4 | logger = NaiveLogger.get_logger('PPQ') 5 | 6 | # attribute checker and preprocess 7 | def process_attribute(attr, input_shape, kernel_shape=None, op_type=None): 8 | # ASSUME input is 2D 9 | # assert len(input_shape) == 2 10 | # Get default attr value 11 | auto_pad = attr.get('auto_pad', 'NOTSET') 12 | strides = attr.get('strides', [1, 1]) 13 | dilations = attr.get('dilations', [1, 1]) 14 | kernels = attr.get('kernel_shape', kernel_shape) 15 | pad_needed = None 16 | 17 | if op_type == 'ConvTranspose' and 'output_shape' in attr: 18 | output_shape = attr['output_shape'] 19 | out_pad = [0, 1] if output_shape % 2 != 0 else [0, 0] 20 | pad_needed = [(input_shape[i] - 1) * strides[i] + dilations[i] * (kernels[i] - 1) + 1 + out_pad[i] - 21 | output_shape[i] for i in range(len(input_shape))] 22 | 23 | if auto_pad != 'NOTSET': 24 | if 'pads' in attr: 25 | logger.warning('auto_pad conflicts with the pads attribute. Use pads here.')
26 | elif auto_pad == 'VALID': 27 | attr['pads'] = [0, 0, 0, 0] 28 | elif auto_pad in ('SAME_UPPER', 'SAME_LOWER'): 29 | if op_type == 'ConvTranspose': 30 | # `output_padding` is only used to find output shape, but does not actually add zero-padding to output 31 | out_pad = attr.get('output_padding', [0, 0]) 32 | output_shape = [input_shape[i] * strides[i] for i in range(len(input_shape))] 33 | pad_needed = [(input_shape[i] - 1) * strides[i] + dilations[i] * (kernels[i] - 1) + 1 + out_pad[i] - 34 | output_shape[i] for i in range(len(input_shape))] 35 | else: 36 | output_shape = [(input_shape[i] + strides[i] - 1) // strides[i] for i in range(len(input_shape))] 37 | pad_needed = [(output_shape[i] - 1) * strides[i] + dilations[i] * (kernels[i] - 1) + 1 - input_shape[i] 38 | for i in range(len(input_shape))] 39 | else: 40 | raise ValueError(f'Invalid auto_pad value {auto_pad}') 41 | 42 | if pad_needed is not None: 43 | pads = [] 44 | for item in pad_needed: 45 | pads.append((item if auto_pad == 'SAME_UPPER' else item + 1) // 2) 46 | # onnx pads format should be as follows: [x1_begin, x2_begin...x1_end, x2_end,...] 47 | pads = pads + [pad_needed[i] - p for i, p in enumerate(pads)] 48 | attr['pads'] = pads 49 | # onnx pads attribute cannot be used simultaneously with auto_pad attribute 50 | attr.pop('auto_pad') 51 | 52 | 53 | def preprocess_attr(attr, op_type=None): 54 | processed_attribute = {} 55 | if 'kernel_shape' in attr and op_type == 'Pooling': 56 | processed_attribute['kernel_size'] = attr['kernel_shape'] 57 | if 'group' in attr: 58 | processed_attribute['groups'] = attr['group'] 59 | if 'pads' in attr: 60 | # Change pads from start-end to torch format 61 | pads = attr['pads'] 62 | assert (len(pads) % 2 == 0) 63 | if len(pads) == 4: 64 | begin_pad = pads[:2] 65 | end_pad = pads[2:] 66 | if begin_pad == end_pad: 67 | processed_attribute['padding'] = begin_pad 68 | else: 69 | raise ValueError('Torch function only supports begin_pad == end_pad in a layer') 70 | else: 71 | processed_attribute['padding'] = pads 72 | 73 | if 'dilations' in attr: 74 | processed_attribute['dilation'] = attr['dilations'] 75 | if 'strides' in attr: 76 | processed_attribute['stride'] = attr['strides'] 77 | if 'ceil_mode' in attr: 78 | processed_attribute['ceil_mode'] = bool(attr['ceil_mode']) 79 | return processed_attribute 80 | -------------------------------------------------------------------------------- /ppq/utils/ema.py: -------------------------------------------------------------------------------- 1 | from math import pow 2 | 3 | class EMARecorder(): 4 | """Exponential Moving Average (EMA) with bias correction.""" 5 | def __init__(self, beta: float = 0.98): 6 | self.beta = beta 7 | self.t = 0 8 | self.value = 0 9 | 10 | def push(self, value: float): 11 | self.value = (1.0 - self.beta) * value + self.beta * self.value 12 | self.t += 1 13 | 14 | def pop(self) -> float: 15 | if self.t == 0: return 0 16 | return self.value / (1 - pow(self.beta, self.t)) -------------------------------------------------------------------------------- /ppq/utils/graph_editor.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from ppq.IR import BaseGraph 3 | from ppq.IR import GraphFormatter 4 | 5 | 6 | def truncate_graph(graph: BaseGraph, outputs: List[str]): 7 | """truncate your graph, so that all operations behind outputs (function 8 | parameter) will be eliminated. A list of output variables is given as 9 | the parameter of this function.
PPQ will go forward from all those variables, 10 | marking all downstream operations for removal. 11 | 12 | A truncated graph object will be returned as the result. 13 | 14 | ATTENTION: do not attempt to delete input variables. 15 | ATTENTION: you should invoke this function before quantization. 16 | 17 | Args: 18 | graph (BaseGraph): graph to be truncated 19 | outputs (List[str]): truncating from where 20 | 21 | Raises: 22 | KeyError: truncating variable is not in graph 23 | 24 | Returns: 25 | [type]: truncated graph 26 | """ 27 | for output in outputs: 28 | if output not in graph.variables: 29 | raise KeyError(f'Can not find variable {output} in current graph.') 30 | processor = GraphFormatter(graph) 31 | 32 | for output in outputs: 33 | output_var = graph.variables[output] 34 | processor.truncate_on_var(output_var, mark_as_output=True) 35 | processor.delete_isolated() 36 | return graph 37 | -------------------------------------------------------------------------------- /ppq/utils/write_qparams_onnx2trt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import tensorrt as trt 5 | 6 | TRT_LOGGER = trt.Logger() 7 | 8 | EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) 9 | 10 | def GiB(val): 11 | return val * 1 << 30 12 | 13 | def json_load(filename): 14 | with open(filename) as json_file: 15 | data = json.load(json_file) 16 | return data 17 | 18 | def setDynamicRange(network, json_file): 19 | """Sets dynamic ranges for network layers.""" 20 | quant_param_json = json_load(json_file) 21 | act_quant = quant_param_json["act_quant_info"] 22 | 23 | for i in range(network.num_inputs): 24 | input_tensor = network.get_input(i) 25 | if act_quant.__contains__(input_tensor.name): 26 | print(input_tensor.name) 27 | value = act_quant[input_tensor.name] 28 | tensor_max = abs(value) 29 | tensor_min = -abs(value) 30 | input_tensor.dynamic_range = (tensor_min, tensor_max) 31 | 32 | 33 | for i in range(network.num_layers): 34 | layer = network.get_layer(i) 35 | 36 | for output_index in range(layer.num_outputs): 37 | tensor = layer.get_output(output_index) 38 | 39 | if act_quant.__contains__(tensor.name): 40 | value = act_quant[tensor.name] 41 | tensor_max = abs(value) 42 | tensor_min = -abs(value) 43 | tensor.dynamic_range = (tensor_min, tensor_max) 44 | else: 45 | print("\033[1;32m%s\033[0m" % tensor.name) 46 | 47 | 48 | def build_engine(onnx_file, json_file, engine_file): 49 | builder = trt.Builder(TRT_LOGGER) 50 | network = builder.create_network(EXPLICIT_BATCH) 51 | 52 | config = builder.create_builder_config() 53 | 54 | # If it is a dynamic onnx model, you need to add the following.
55 | # profile = builder.create_optimization_profile() 56 | # profile.set_shape("input_name", (batch, channels, min_h, min_w), (batch, channels, opt_h, opt_w), (batch, channels, max_h, max_w)) 57 | # config.add_optimization_profile(profile) 58 | 59 | 60 | parser = trt.OnnxParser(network, TRT_LOGGER) 61 | config.max_workspace_size = GiB(1) 62 | 63 | if not os.path.exists(onnx_file): 64 | quit('ONNX file {} not found'.format(onnx_file)) 65 | 66 | with open(onnx_file, 'rb') as model: 67 | if not parser.parse(model.read()): 68 | print('ERROR: Failed to parse the ONNX file.') 69 | for error in range(parser.num_errors): 70 | print(parser.get_error(error)) 71 | return None 72 | 73 | config.set_flag(trt.BuilderFlag.INT8) 74 | 75 | setDynamicRange(network, json_file) 76 | 77 | engine = builder.build_engine(network, config) 78 | 79 | with open(engine_file, "wb") as f: 80 | f.write(engine.serialize()) 81 | 82 | 83 | if __name__ == '__main__': 84 | # Add plugins if needed 85 | # import ctypes 86 | # ctypes.CDLL("libmmdeploy_tensorrt_ops.so") 87 | parser = argparse.ArgumentParser(description='Writing qparams to onnx to convert tensorrt engine.') 88 | parser.add_argument('--onnx', type=str, default=None) 89 | parser.add_argument('--qparam_json', type=str, default=None) 90 | parser.add_argument('--engine', type=str, default=None) 91 | arg = parser.parse_args() 92 | 93 | build_engine(arg.onnx, arg.qparam_json, arg.engine) 94 | print("\033[1;32mgenerate %s\033[0m" % arg.engine) 95 | 96 | 97 | -------------------------------------------------------------------------------- /ppq/utils/write_qparams_to_snpe_dlc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import snpe 5 | import qti.aisw.dlc_utils as dlc 6 | 7 | parser = argparse.ArgumentParser(description='Write ppq qparams to snpe dlc') 8 | parser.add_argument('--input_dlc_model', default='snpe_quantized.dlc', help='path to snpe quantized dlc model') 9 | parser.add_argument('--output_dlc_model', default='ppq_export.dlc', help='path to export quantized dlc') 10 | parser.add_argument('--qparam', default='quantized.json', help='path to ppq qparams json') 11 | 12 | def json_load(filename): 13 | with open(filename) as json_file: 14 | data = json.load(json_file) 15 | return data 16 | 17 | def write_qparams_to_dlc_model(input_dlc, output_dlc, activation_qparams): 18 | model = dlc.modeltools.Model() 19 | model.load(input_dlc) 20 | model.set_tf_encoding_type("TF") 21 | 22 | for snpe_layer in model.get_layers(): 23 | print('\n write qparams to {}'.format(snpe_layer['name'])) 24 | for snpe_layer_out_ind, snpe_layer_out in enumerate(snpe_layer['output_names']): 25 | layer_name = snpe_layer['name'] 26 | print('original quant encodings : ', model.get_tf_output_encoding_by_index(name=layer_name, index=snpe_layer_out_ind)) 27 | top = snpe_layer['output_names'][0] 28 | 29 | if top not in activation_qparams.keys(): 30 | # Before the Reshape layer, SNPE will insert the shape conversion layer(xxx.ncs) 31 | # Because the SNPE data is arranged as NHWC 32 | assert top.endswith('.ncs'), '{} ranges not exists'.format(top) 33 | bottom = snpe_layer['input_names'][0] 34 | new_enc = activation_qparams[bottom][0] #List[dict] 35 | else: 36 | new_enc = activation_qparams[top][0] #List[dict] 37 | 38 | model.set_tf_output_encoding_by_index(name=layer_name, index=snpe_layer_out_ind, bitwidth=8, min=new_enc["min"], max=new_enc["max"]) 39 | print('ppq quant encodings : ', 
model.get_tf_output_encoding_by_index(name=layer_name, index=snpe_layer_out_ind)) 40 | model.quantize_weights(should_quantize=True) 41 | model.save(output_dlc) 42 | 43 | if __name__ == '__main__': 44 | args = parser.parse_args() 45 | act_ranges = json_load(args.qparam)['activation_encodings'] 46 | write_qparams_to_dlc_model(args.input_dlc_model, args.output_dlc_model, act_ranges) 47 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | onnx >= 1.9.0 3 | protobuf 4 | torch >= 1.6.0 5 | tqdm 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | from ppq.core import PPQ_CONFIG 3 | 4 | def readme(): 5 | with open('README.md', encoding='utf-8') as f: 6 | content = f.read() 7 | return content 8 | 9 | setup(author='ppq', 10 | author_email='dcp-ppq@sensetime.com', 11 | description='PPQ is an offline quantization tool', 12 | long_description=readme(), 13 | long_description_content_type='text/markdown', 14 | install_requires=open('requirements.txt').readlines(), 15 | python_requires='>=3.6', 16 | name='ppq', 17 | packages=find_packages(), 18 | classifiers=[ 19 | 'Development Status :: 3 - Alpha', 20 | 'License :: OSI Approved :: Apache Software License', 21 | 'Operating System :: OS Independent', 22 | 'Programming Language :: Python :: 3.6', 23 | 'Programming Language :: Python :: 3.7', 24 | 'Programming Language :: Python :: 3.8', 25 | 'Programming Language :: Python :: 3.9', 26 | ], 27 | license='Apache License 2.0', 28 | include_package_data=True, 29 | version=PPQ_CONFIG.VERSION, 30 | zip_safe=False 31 | ) 32 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """ppq/tests contains all PPQ test scripts, all scripts are managed by pytest. 2 | 3 | We have been building this test collection since ppq 0.6.2; you are welcome to 4 | contribute code to this collection.
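Example (sketch): run 'python -m pytest tests' from the repository root to execute the whole collection.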
5 | """ 6 | -------------------------------------------------------------------------------- /tests/test_activation_fusion.py: -------------------------------------------------------------------------------- 1 | from ppq import * 2 | from ppq.IR.morph import GraphMerger 3 | from ppq.api import * 4 | import torch 5 | 6 | graph = BaseGraph(name='test', built_from=NetworkFramework.ONNX) 7 | matmul = \ 8 | graph.create_operation(op_type='Matmul', name='matmul', 9 | platform=TargetPlatform.UNSPECIFIED, 10 | inputs=[graph.create_variable(), graph.create_variable(is_parameter=True, value=torch.zeros(size=[10, 10]))], 11 | outputs=[graph.create_variable()]) 12 | graph.create_operation(op_type='Relu', name='relu', platform=TargetPlatform.UNSPECIFIED, 13 | inputs=[matmul.outputs[0], graph.create_variable(is_parameter=True, value=torch.zeros(size=[10, ]))], 14 | outputs=[graph.create_variable()]) 15 | processor = QuantableGraph(graph) 16 | processor.quantize_operation('matmul', target_platform=TargetPlatform.PPL_CUDA_INT8) 17 | processor.quantize_operation('relu', target_platform=TargetPlatform.PPL_CUDA_INT8) 18 | -------------------------------------------------------------------------------- /tests/test_block.py: -------------------------------------------------------------------------------- 1 | from tests.tmodel import * 2 | from tests.tscheme import * 3 | from ..ppq import * 4 | from ..ppq.api import * 5 | 6 | DEVICE = 'cuda' 7 | 8 | with ENABLE_CUDA_KERNEL(): 9 | for scheme in TEST_SCHEMES: 10 | for case in TORCH_TEST_BLOCKS: 11 | try: 12 | print(f'PPQ System test start with model {case.model_name}, Scheme: {scheme.name}') 13 | 14 | dataset = [case.input_generator().to(DEVICE) for _ in range(8)] 15 | model = case.model_builder().to(DEVICE).eval() 16 | reference_outputs = torch.cat([model(batch) for batch in dataset]) 17 | 18 | quantized = quantize_torch_model( 19 | model=model, 20 | calib_dataloader=dataset, 21 | calib_steps=8, 22 | input_shape=case.input_generator().shape, 23 | platform=scheme.quant_platform, 24 | setting=scheme.setting, 25 | verbose=False) 26 | 27 | executor = TorchExecutor(quantized) 28 | for op in quantized.operations.values(): 29 | if isinstance(op, QuantableOperation): op.dequantize() 30 | ppq_outputs = torch.cat([executor(batch)[0] for batch in dataset]) 31 | 32 | for op in quantized.operations.values(): 33 | if isinstance(op, QuantableOperation): op.restore_quantize_state() 34 | quant_outputs = torch.cat([executor(batch)[0] for batch in dataset]) 35 | assert torch_snr_error(ppq_outputs, reference_outputs).item() < 1e-4, ( 36 | f'Network Simulating Failed, expect error < 1e-3, ' 37 | f'got {torch_snr_error(ppq_outputs, reference_outputs).item()}') 38 | assert torch_snr_error(quant_outputs, reference_outputs).item() < 0.1, ( 39 | f'Network Quantization Failed, expect error < 0.1, ' 40 | f'got {torch_snr_error(quant_outputs, reference_outputs).item()}') 41 | 42 | if (case.depoly_platforms is None or 43 | scheme.export_platform in case.depoly_platforms): 44 | export_ppq_graph( 45 | graph=quantized, 46 | platform=scheme.export_platform, 47 | graph_save_to='tworkingspace/export', 48 | config_save_to='tworkingspace/export.json') 49 | except NotImplementedError as e: 50 | pass -------------------------------------------------------------------------------- /tests/test_block_split.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ppq.core.quant import TargetPlatform 4 | 5 | class MyModel(torch.nn.Module): 6 | def 
__init__(self) -> None: 7 | super().__init__() 8 | self.gemm_1 = torch.nn.Linear(in_features=10, out_features=10) 9 | self.gemm_2 = torch.nn.Linear(in_features=10, out_features=10) 10 | self.gemm_3 = torch.nn.Linear(in_features=10, out_features=10) 11 | self.gemm_4 = torch.nn.Linear(in_features=10, out_features=10) 12 | self.gemm_5 = torch.nn.Linear(in_features=10, out_features=10) 13 | self.gemm_6 = torch.nn.Linear(in_features=10, out_features=10) 14 | self.gemm_7 = torch.nn.Linear(in_features=10, out_features=10) 15 | self.gemm_8 = torch.nn.Linear(in_features=10, out_features=10) 16 | self.gemm_9 = torch.nn.Linear(in_features=10, out_features=10) 17 | self.gemm_10 = torch.nn.Linear(in_features=10, out_features=10) 18 | self.gemm_J = torch.nn.Linear(in_features=10, out_features=10) 19 | self.gemm_Q = torch.nn.Linear(in_features=10, out_features=10) 20 | self.gemm_K = torch.nn.Linear(in_features=10, out_features=10) 21 | self.gemm_A = torch.nn.Linear(in_features=10, out_features=10) 22 | 23 | def forward(self, x: torch.Tensor) -> torch.Tensor: 24 | x = self.gemm_1(x) 25 | x = torch.relu(x) 26 | 27 | x2 = torch.relu(self.gemm_2(x)) 28 | x3 = torch.relu(self.gemm_3(x)) 29 | x4 = torch.relu(self.gemm_4(x)) 30 | x5 = torch.relu(self.gemm_5(x)) 31 | x6 = torch.relu(self.gemm_6(x)) 32 | 33 | x2 = self.gemm_7(x2) 34 | x3 = self.gemm_8(x3) 35 | x4 = self.gemm_9(x4) 36 | x5 = self.gemm_10(x5) 37 | x6 = self.gemm_J(x6) 38 | 39 | x7 = torch.relu(self.gemm_Q(x)) 40 | x7 = self.gemm_K(x7) 41 | 42 | x8 = torch.max_pool1d(x7, kernel_size=2) 43 | return torch.cat([x2, x3, x4, x5, x6, x7, x8], dim=-1) 44 | 45 | model = MyModel().cuda() 46 | model.forward(torch.zeros(size=[10, 10]).cuda()) 47 | 48 | from ppq.api import load_torch_model 49 | from ppq.api import quantize_native_model 50 | from ppq.api import QuantizationSettingFactory 51 | 52 | graph = load_torch_model(model=model, sample=torch.zeros(size=[10, 10]).cuda()) 53 | s = QuantizationSettingFactory.default_setting() 54 | s.lsq_optimization = True 55 | s.lsq_optimization_setting.block_size = 4 56 | 57 | quantize_native_model( 58 | model=graph, calib_dataloader=[torch.zeros(size=[10, 10])], input_shape=[10, 10], 59 | calib_steps=8, collate_fn=lambda x: x.cuda(), platform=TargetPlatform.TRT_INT8, 60 | setting=s) -------------------------------------------------------------------------------- /tests/test_gemm_fusion.py: -------------------------------------------------------------------------------- 1 | from ppq import * 2 | from ppq.IR.morph import GraphMerger 3 | from ppq.api import * 4 | import torch 5 | 6 | graph = BaseGraph(name='test', built_from=NetworkFramework.ONNX) 7 | matmul = \ 8 | graph.create_operation(op_type='Matmul', name='matmul', 9 | platform=TargetPlatform.UNSPECIFIED, 10 | inputs=[graph.create_variable(), graph.create_variable(is_parameter=True, value=torch.zeros(size=[10, 10]))], 11 | outputs=[graph.create_variable()]) 12 | graph.create_operation(op_type='Add', name='add', platform=TargetPlatform.UNSPECIFIED, 13 | inputs=[matmul.outputs[0], graph.create_variable(is_parameter=True, value=torch.zeros(size=[10, ]))], 14 | outputs=[graph.create_variable()]) 15 | processor = GraphMerger(graph) 16 | processor.fuse_gemm() 17 | 18 | assert len(graph.operations) == 1 19 | assert len(graph.operations['matmul'].inputs) == 3 20 | assert graph.operations['matmul'].type == 'Gemm' 21 | 22 | graph = BaseGraph(name='test', built_from=NetworkFramework.ONNX) 23 | matmul = \ 24 | graph.create_operation(op_type='Matmul', name='matmul', 25 | 
platform=TargetPlatform.UNSPECIFIED, 26 | inputs=[graph.create_variable(), graph.create_variable(is_parameter=True, value=torch.zeros(size=[10, 10]))], 27 | outputs=[graph.create_variable()]) 28 | test = \ 29 | graph.create_operation(op_type='Test', name='test', platform=TargetPlatform.UNSPECIFIED, 30 | inputs=[], outputs=[graph.create_variable()]) 31 | graph.create_operation(op_type='Add', name='add', platform=TargetPlatform.UNSPECIFIED, 32 | inputs=[matmul.outputs[0], test.outputs[0]], 33 | outputs=[graph.create_variable()]) 34 | processor = GraphMerger(graph) 35 | processor.fuse_gemm() 36 | 37 | assert len(graph.operations) == 3 38 | assert len(graph.operations['matmul'].inputs) == 2 39 | assert graph.operations['matmul'].type == 'Gemm' 40 | 41 | 42 | graph = BaseGraph(name='test', built_from=NetworkFramework.ONNX) 43 | matmul = \ 44 | graph.create_operation(op_type='Matmul', name='matmul', 45 | platform=TargetPlatform.UNSPECIFIED, 46 | inputs=[graph.create_variable(), graph.create_variable(is_parameter=True, value=torch.zeros(size=[10, 10]))], 47 | outputs=[graph.create_variable()]) 48 | graph.create_operation(op_type='Add', name='add', platform=TargetPlatform.UNSPECIFIED, 49 | inputs=[matmul.outputs[0], graph.create_variable(is_parameter=True, value=torch.zeros(size=[1, ]))], 50 | outputs=[graph.create_variable()]) 51 | processor = GraphMerger(graph) 52 | processor.fuse_gemm() 53 | 54 | assert len(graph.operations) == 2 55 | assert len(graph.operations['matmul'].inputs) == 2 56 | assert graph.operations['matmul'].type == 'Gemm' -------------------------------------------------------------------------------- /tests/test_isotone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ppq.IR import Variable 4 | from ppq.lib import LinearQuantizationConfig 5 | from ppq.quantization.observer import TorchIsotoneObserver, TorchMinMaxObserver 6 | from ppq.quantization.qfunction import PPQLinearQuantFunction 7 | 8 | from ppq import QuantizationStates 9 | 10 | TQC = LinearQuantizationConfig() 11 | v = Variable(name='TestVariable') 12 | 13 | for i in range(10000): 14 | isotone_observer = TorchIsotoneObserver(v, TQC) 15 | minmax_observer = TorchMinMaxObserver(v, TQC) 16 | 17 | TQC.state = QuantizationStates.INITIAL 18 | value = torch.softmax(torch.rand(size=[1, 10]), dim=-1) 19 | value, _ = torch.sort(value, dim=-1) 20 | isotone_observer.observe(value) 21 | isotone_observer.render_quantization_config() 22 | 23 | isotone_quant = PPQLinearQuantFunction(value, TQC) 24 | o = torch.argmax(value, dim=-1) 25 | q = torch.argmax(isotone_quant, dim=-1) 26 | isotone_error_num = torch.sum(o != q).item() 27 | isotone_scale = TQC.scale 28 | 29 | TQC.state = QuantizationStates.INITIAL 30 | minmax_observer.observe(value) 31 | minmax_observer.render_quantization_config() 32 | 33 | minmax_quant = PPQLinearQuantFunction(value, TQC) 34 | o = torch.argmax(value, dim=-1) 35 | q = torch.argmax(minmax_quant, dim=-1) 36 | minmax_error_num = torch.sum(o != q).item() 37 | minmax_scale = TQC.scale 38 | 39 | if not isotone_error_num <= minmax_error_num: 40 | print(isotone_observer.s_candidates) 41 | print(isotone_error_num, minmax_error_num) 42 | print(value) 43 | print(isotone_quant, isotone_scale) 44 | print(minmax_quant, minmax_scale) 45 | raise Exception('Test Failed.') 46 | 47 | 48 | TQC = LinearQuantizationConfig(symmetrical=False) 49 | v = Variable(name='TestVariable') 50 | 51 | for i in range(10000): 52 | isotone_observer = TorchIsotoneObserver(v, TQC) 53 | 
minmax_observer = TorchMinMaxObserver(v, TQC) 54 | 55 | TQC.state = QuantizationStates.INITIAL 56 | value = torch.softmax(torch.rand(size=[1, 10]), dim=-1) 57 | value, _ = torch.sort(value, dim=-1) 58 | isotone_observer.observe(value) 59 | isotone_observer.render_quantization_config() 60 | 61 | isotone_quant = PPQLinearQuantFunction(value, TQC) 62 | o = torch.argmax(value, dim=-1) 63 | q = torch.argmax(isotone_quant, dim=-1) 64 | isotone_error_num = torch.sum(o != q).item() 65 | isotone_scale = TQC.scale 66 | 67 | TQC.state = QuantizationStates.INITIAL 68 | minmax_observer.observe(value) 69 | minmax_observer.render_quantization_config() 70 | 71 | minmax_quant = PPQLinearQuantFunction(value, TQC) 72 | o = torch.argmax(value, dim=-1) 73 | q = torch.argmax(minmax_quant, dim=-1) 74 | minmax_error_num = torch.sum(o != q).item() 75 | minmax_scale = TQC.scale 76 | 77 | if not isotone_error_num <= minmax_error_num: 78 | print(isotone_observer.s_candidates) 79 | print(isotone_error_num, minmax_error_num) 80 | print(value) 81 | print(isotone_quant, isotone_scale) 82 | print(minmax_quant, minmax_scale) 83 | raise Exception('Test Failed.') -------------------------------------------------------------------------------- /tests/test_rounding.py: -------------------------------------------------------------------------------- 1 | from ppq.core.quant import RoundingPolicy 2 | from ppq.utils.round import ppq_numerical_round, ppq_round_to_power_of_2 3 | 4 | if __name__ == '__main__': 5 | assert ppq_numerical_round(1.5, policy=RoundingPolicy.ROUND_HALF_EVEN) == 2 6 | assert ppq_numerical_round(2.5, policy=RoundingPolicy.ROUND_HALF_EVEN) == 2 7 | assert ppq_numerical_round(0.5, policy=RoundingPolicy.ROUND_HALF_EVEN) == 0 8 | assert ppq_numerical_round(-0.5, policy=RoundingPolicy.ROUND_HALF_EVEN) == 0 9 | assert ppq_numerical_round(1.1, policy=RoundingPolicy.ROUND_HALF_EVEN) == 1 10 | assert ppq_numerical_round(1.2, policy=RoundingPolicy.ROUND_HALF_EVEN) == 1 11 | assert ppq_numerical_round(1.3, policy=RoundingPolicy.ROUND_HALF_EVEN) == 1 12 | assert ppq_numerical_round(-1.1, policy=RoundingPolicy.ROUND_HALF_EVEN) == -1 13 | assert ppq_numerical_round(-1.2, policy=RoundingPolicy.ROUND_HALF_EVEN) == -1 14 | assert ppq_numerical_round(-1.3, policy=RoundingPolicy.ROUND_HALF_EVEN) == -1 15 | 16 | assert ppq_numerical_round(1.5, policy=RoundingPolicy.ROUND_HALF_UP) == 2 17 | assert ppq_numerical_round(2.5, policy=RoundingPolicy.ROUND_HALF_UP) == 3 18 | assert ppq_numerical_round(0.5, policy=RoundingPolicy.ROUND_HALF_UP) == 1 19 | assert ppq_numerical_round(-0.5, policy=RoundingPolicy.ROUND_HALF_UP) == 0 20 | 21 | assert ppq_numerical_round(1.5, policy=RoundingPolicy.ROUND_HALF_DOWN) == 1 22 | assert ppq_numerical_round(2.5, policy=RoundingPolicy.ROUND_HALF_DOWN) == 2 23 | assert ppq_numerical_round(0.5, policy=RoundingPolicy.ROUND_HALF_DOWN) == 0 24 | assert ppq_numerical_round(-0.5, policy=RoundingPolicy.ROUND_HALF_DOWN) == -1 25 | 26 | assert ppq_numerical_round(1.5, policy=RoundingPolicy.ROUND_HALF_TOWARDS_ZERO) == 1 27 | assert ppq_numerical_round(2.5, policy=RoundingPolicy.ROUND_HALF_TOWARDS_ZERO) == 2 28 | assert ppq_numerical_round(0.5, policy=RoundingPolicy.ROUND_HALF_TOWARDS_ZERO) == 0 29 | 30 | # I do not know why the following assertion fails 31 | # assert ppq_numerical_round(-0.5, policy=RoundingPolicy.ROUND_HALF_TOWARDS_ZERO) == 0 32 | 33 | assert ppq_round_to_power_of_2(1.0) == 1 34 | assert ppq_round_to_power_of_2(1.2) == 2 35 | assert ppq_round_to_power_of_2(3.2) == 4 36 | assert ppq_round_to_power_of_2(0.26) == 0.5 37 | assert
ppq_round_to_power_of_2(0.24) == 0.25 38 | -------------------------------------------------------------------------------- /tests/test_system.py: -------------------------------------------------------------------------------- 1 | from tmodel import * 2 | from tscheme import * 3 | from ppq import * 4 | from ppq.api import * 5 | import sys, time 6 | 7 | DEVICE = 'cuda' 8 | 9 | with ENABLE_CUDA_KERNEL(): 10 | for scheme in TEST_SCHEMES: 11 | for case in TORCH_TEST_CASES: 12 | try: 13 | print(f'PPQ System test start with model {case.model_name}, Scheme: {scheme.name}') 14 | dataset = [case.input_generator().to(DEVICE) for _ in range(8)] 15 | model = case.model_builder().to(DEVICE) 16 | 17 | quantized = quantize_torch_model( 18 | model=model, 19 | calib_dataloader=dataset, 20 | calib_steps=8, 21 | input_shape=case.input_generator().shape, 22 | platform=scheme.quant_platform, 23 | setting=scheme.setting) 24 | 25 | if (case.deploy_platforms is None or 26 | scheme.export_platform in case.deploy_platforms): 27 | export_ppq_graph( 28 | graph=quantized, 29 | platform=scheme.export_platform, 30 | graph_save_to='tworkingspace/export', 31 | config_save_to='tworkingspace/export.json') 32 | except NotImplementedError as e: 33 | print(f'{time.strftime("%Y-%m-%d %H:%M:%S")} | Error occurred: {e}') 34 | sys.exit(1) 35 | -------------------------------------------------------------------------------- /tests/tmodel/__init__.py: -------------------------------------------------------------------------------- 1 | from .torchmodels import TORCH_TEST_CASES 2 | from .testblocks import TORCH_TEST_BLOCKS 3 | from .base import * -------------------------------------------------------------------------------- /tests/tmodel/base.py: -------------------------------------------------------------------------------- 1 | """This file defines all ppq test models.""" 2 | 3 | from enum import Enum 4 | from typing import Callable, List 5 | 6 | import torch 7 | from ppq.core import TargetPlatform 8 | 9 | 10 | class ModelType(Enum): 11 | CLASSIFY = 1 # image classification 12 | DETECTION = 2 # object detection 13 | SEGMENTATION = 3 # image segmentation 14 | SUPERRES = 4 # super resolution 15 | POINTCLOUD = 5 # 3D point cloud 16 | OCR = 6 # OCR 17 | TEXT_CLASSIFY = 7 # text classification 18 | TEXT_LABELING = 8 # text sequence labeling 19 | GAN = 9 # generative adversarial network 20 | SEQ2SEQ = 10 # text generation 21 | NERF = 11 # neural radiance field 22 | REC = 12 # recommendation system 23 | BLOCK = 13 # small sub-network 24 | 25 | 26 | class PPQTestCase(): 27 | def __init__(self, model_builder: Callable, 28 | input_generator: Callable, model_type: ModelType, 29 | model_name: str, running_device = 'cuda', 30 | deploy_platforms: List[TargetPlatform] = None) -> None: 31 | self.deploy_platforms = deploy_platforms 32 | self.model_builder = model_builder 33 | self.input_generator = input_generator 34 | self.model_type = model_type 35 | self.model_name = model_name 36 | self.running_device = running_device 37 | 38 | 39 | def rand_tensor_generator(shape: List[int]): 40 | return torch.rand(size=shape) 41 | -------------------------------------------------------------------------------- /tests/tscheme/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * -------------------------------------------------------------------------------- /tests/tscheme/base.py: -------------------------------------------------------------------------------- 1 | from ppq import * 2 | 3 | class PPQTestScheme(): 4 | def __init__( 5 | self, name: str, quant_platform: TargetPlatform, 6 | export_platform: TargetPlatform, setting: QuantizationSetting): 7 | self.name = name 8 |
self.quant_platform = quant_platform 9 | self.export_platform = export_platform 10 | self.setting = setting 11 | 12 | TEST_SCHEMES = [ 13 | PPQTestScheme( 14 | name = 'Tengine', 15 | quant_platform=TargetPlatform.TENGINE_INT8, 16 | export_platform=TargetPlatform.TENGINE_INT8, 17 | setting=QuantizationSettingFactory.pplcuda_setting()), 18 | 19 | PPQTestScheme( 20 | name = 'TRT FP8', 21 | quant_platform=TargetPlatform.TRT_FP8, 22 | export_platform=TargetPlatform.TRT_FP8, 23 | setting=QuantizationSettingFactory.default_setting()), 24 | 25 | PPQTestScheme( 26 | name = 'TRT INT8', 27 | quant_platform=TargetPlatform.TRT_INT8, 28 | export_platform=TargetPlatform.TRT_INT8, 29 | setting=QuantizationSettingFactory.default_setting()), 30 | 31 | PPQTestScheme( 32 | name = 'Sensetime Caffe[DSP INT8]', 33 | quant_platform=TargetPlatform.PPL_DSP_INT8, 34 | export_platform=TargetPlatform.PPL_DSP_INT8, 35 | setting=QuantizationSettingFactory.dsp_setting()), 36 | 37 | PPQTestScheme( 38 | name = 'Sensetime Caffe[SNPE INT8]', 39 | quant_platform=TargetPlatform.SNPE_INT8, 40 | export_platform=TargetPlatform.SNPE_INT8, 41 | setting=QuantizationSettingFactory.dsp_setting()), 42 | 43 | PPQTestScheme( 44 | name = 'Sensetime PPL[GPU INT8]', 45 | quant_platform=TargetPlatform.PPL_CUDA_INT8, 46 | export_platform=TargetPlatform.PPL_CUDA_INT8, 47 | setting=QuantizationSettingFactory.pplcuda_setting()), 48 | 49 | PPQTestScheme( 50 | name = 'Sensetime PPL[GPU INT8 - ONNX RUNTIME EXPORT]', 51 | quant_platform=TargetPlatform.PPL_CUDA_INT8, 52 | export_platform=TargetPlatform.ONNXRUNTIME, 53 | setting=QuantizationSettingFactory.pplcuda_setting()), 54 | 55 | PPQTestScheme( 56 | name = 'ONNX RUNTIME[MetaX INT8]', 57 | quant_platform=TargetPlatform.METAX_INT8_T, 58 | export_platform=TargetPlatform.ONNXRUNTIME, 59 | setting=QuantizationSettingFactory.pplcuda_setting()), 60 | 61 | PPQTestScheme( 62 | name = 'ONNX RUNTIME[MetaX INT8 Channelwise]', 63 | quant_platform=TargetPlatform.METAX_INT8_C, 64 | export_platform=TargetPlatform.ONNXRUNTIME, 65 | setting=QuantizationSettingFactory.pplcuda_setting()), 66 | 67 | PPQTestScheme( 68 | name = 'ONNX RUNTIME OP ORIENTED[INT8]', 69 | quant_platform=TargetPlatform.RKNN_INT8, 70 | export_platform=TargetPlatform.RKNN_INT8, 71 | setting=QuantizationSettingFactory.pplcuda_setting()), 72 | 73 | PPQTestScheme( 74 | name = 'NXP [NXP INT8]', 75 | quant_platform=TargetPlatform.NXP_INT8, 76 | export_platform=TargetPlatform.NXP_INT8, 77 | setting=QuantizationSettingFactory.nxp_setting()), 78 | 79 | PPQTestScheme( 80 | name = 'Native', 81 | quant_platform=TargetPlatform.PPL_CUDA_INT8, 82 | export_platform=TargetPlatform.NATIVE, 83 | setting=QuantizationSettingFactory.pplcuda_setting()), 84 | 85 | ] -------------------------------------------------------------------------------- /tests/tworkingspace/placeholder.py: -------------------------------------------------------------------------------- 1 | # Oh, this file is not actually useful; it only exists to keep the folder structure when the repo is uploaded to git 2 | --------------------------------------------------------------------------------