├── LICENSE ├── README.md ├── draw-engine.py ├── export_qat.py ├── install_dependencies.sh ├── models ├── experimental_trt.py ├── quantize.py └── quantize_rules.py ├── patch_yolov9.sh ├── qat.py ├── scripts ├── generate_trt_engine.sh └── val_trt.sh ├── segment └── qat_seg.py └── val_trt.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YOLOv9 QAT for TensorRT 10.9 Detection / Segmentation 2 | 3 | This repository contains an implementation of YOLOv9 with Quantization-Aware Training (QAT), specifically designed for deployment on platforms utilizing TensorRT for hardware-accelerated inference.
4 | This implementation aims to provide an efficient, low-latency version of YOLOv9 for real-time detection applications.
5 | If you do not intend to deploy your model using TensorRT, it is recommended not to proceed with this implementation. 6 | 7 | - The files in this repository represent a patch that adds QAT functionality to the original [YOLOv9 repository](https://github.com/WongKinYiu/yolov9/). 8 | - This patch is intended to be applied to the main YOLOv9 repository to incorporate the ability to train with QAT. 9 | - The implementation is optimized to work efficiently with TensorRT, an inference library that leverages hardware acceleration to enhance inference performance. 10 | - Users interested in implementing object detection using YOLOv9 with QAT on TensorRT platforms can benefit from this repository as it provides a ready-to-use solution. 11 | 12 | 13 | We use [TensorRT's pytorch quantization tool](https://github.com/NVIDIA/TensorRT/tree/main/tools/pytorch-quantization) to fine-tune a QAT YOLOv9 model from the pre-trained weights, then export the model to ONNX and deploy it with TensorRT. The accuracy and performance can be found in the tables below. 14 | 15 | For those who are not familiar with QAT, I highly recommend watching this video:
[Quantization explained with PyTorch - Post-Training Quantization, Quantization-Aware Training](https://www.youtube.com/watch?v=0VdNflU08yA) 16 | 17 | **Important**
18 | Evaluation of the segmentation model using TensorRT is currently under development. Once I have more available time, I will complete and release this work. 19 | 20 | 🌟 There are still plenty of nodes where Q/DQ placement can be improved, and we rely on the community's contributions to enhance this project, benefiting us all. Let's collaborate and make it even better! 🚀 21 | 22 | ## Release Highlights 23 | - This release includes an upgrade from TensorRT 8 to TensorRT 10, ensuring compatibility with the CUDA version supported by the latest NVIDIA Ada Lovelace GPUs. 24 | - Inference has been upgraded to use `enqueueV3` instead of `enqueueV2`.
25 | - To maintain legacy support for TensorRT 8, a [dedicated branch](https://github.com/levipereira/yolov9-qat/tree/TensorRT-8) has been created. **Outdated**
26 | - We've added a new option `val_trt.sh --generate-graph` which enables [Graph Rendering](#generate-tensort-profiling-and-svg-image) functionality. This feature facilitates the creation of graphical representations of the engine plan in SVG image format. 27 | 28 | 29 | # Perfomance / Accuracy 30 | [Full Report](#benchmark) 31 | 32 | 33 | ## Accuracy Report 34 | 35 | **YOLOv9-C** 36 | 37 | ### Evaluation Results 38 | 39 | ## Detection 40 | #### Activation SiLU 41 | 42 | | Eval Model | AP | AP50 | Precision | Recall | 43 | |------------|--------|--------|-----------|--------| 44 | | **Origin (Pytorch)** | 0.529 | 0.699 | 0.743 | 0.634 | 45 | | **INT8 (Pytorch)** | 0.529 | 0.702 | 0.742 | 0.63 | 46 | | **INT8 (TensorRT)** | 0.529 | 0.696 | 0.739 | 0.635 | 47 | 48 | 49 | #### Activation ReLU 50 | 51 | | Eval Model | AP | AP50 | Precision | Recall | 52 | |------------|--------|--------|-----------|--------| 53 | | **Origin (Pytorch)** | 0.519 | 0.69 | 0.719 | 0.629 | 54 | | **INT8 (Pytorch)** | 0.518 | 0.69 | 0.726 | 0.625 | 55 | | **INT8 (TensorRT)** | 0.517 | 0.685 | 0.723 | 0.626 | 56 | 57 | ### Evaluation Comparison 58 | 59 | #### Activation SiLU 60 | | Eval Model | AP | AP50 | Precision | Recall | 61 | |----------------------|------|------|-----------|--------| 62 | | **INT8 (TensorRT)** vs **Origin (Pytorch)** | | | | | 63 | | | 0.000 | -0.003 | -0.004 | +0.001 | 64 | 65 | #### Activation ReLU 66 | | Eval Model | AP | AP50 | Precision | Recall | 67 | |----------------------|------|------|-----------|--------| 68 | | **INT8 (TensorRT)** vs **Origin (Pytorch)** | | | | | 69 | | | -0.002 | -0.005 | +0.004 | -0.003 | 70 | 71 | ## Segmentation 72 | | Model | Box | | | | Mask | | | | 73 | |--------|-----|--|--|--|------|--|--|--| 74 | | | P | R | mAP50 | mAP50-95 | P | R | mAP50 | mAP50-95 | 75 | | Origin | 0.729 | 0.632 | 0.691 | 0.521 | 0.717 | 0.611 | 0.657 | 0.423 | 76 | | PTQ | 0.729 | 0.626 | 0.688 | 0.520 | 0.717 | 0.604 | 0.654 | 0.421 | 77 | | QAT | 0.725 | 0.631 | 0.689 | 0.521 | 0.714 | 0.609 | 0.655 | 0.421 | 78 | 79 | 80 | ## Latency/Throughput Report - TensorRT 81 | 82 | ![image](https://github.com/levipereira/yolov9-qat/assets/22964932/61a46206-9784-4c75-bcd4-6534eba51223) 83 | 84 | ## Device 85 | | **GPU** | | 86 | |---------------------------|------------------------------| 87 | | Device | **NVIDIA GeForce RTX 4090** | 88 | | Compute Capability | 8.9 | 89 | | SMs | 128 | 90 | | Device Global Memory | 24207 MiB | 91 | | Application Compute Clock Rate | 2.58 GHz | 92 | | Application Memory Clock Rate | 10.501 GHz | 93 | 94 | 95 | ### Latency/Throughput 96 | 97 | | Model Name | Batch Size | Latency (99%) | Throughput (qps) | Total Inferences (IPS) | 98 | |-----------------|------------|----------------|------------------|------------------------| 99 | | **(FP16) SiLU** | 1 | 1.25 ms | 803 | 803 | 100 | | | 4 | 3.37 ms | 300 | 1200 | 101 | | | 8 | 6.6 ms | 153 | 1224 | 102 | | | 12 | 10 ms | 99 | 1188 | 103 | | | | | | | 104 | | **INT8 (SiLU)** | 1 | 0.97 ms | 1030 | 1030 | 105 | | | 4 | 2,06 ms | 486 | 1944 | 106 | | | 8 | 3.69 ms | 271 | 2168 | 107 | | | 12 | 5.36 ms | 189 | 2268 | 108 | | | | | | | 109 | | **INT8 (ReLU)** | 1 | 0.87 ms | 1150 | 1150 | 110 | | | 4 | 1.78 ms | 562 | 2248 | 111 | | | 8 | 3.06 ms | 327 | 2616 | 112 | | | 12 | 4.63 ms | 217 | 2604 | 113 | 114 | ## Latency/Throughput Comparison (INT8 vs FP16) 115 | 116 | | Model Name | Batch Size | Latency (99%) Change | Throughput (qps) Change | Total Inferences (IPS) Change | 117 | |---|---|---|---|---| 118 | | 
**INT8(SiLU)** vs **FP16** | 1 | -20.8% | +28.4% | +28.4% | 119 | | | 4 | -37.1% | +62.0% | +62.0% | 120 | | | 8 | -41.1% | +77.0% | +77.0% | 121 | | | 12 | -46.9% | +90.9% | +90.9% | 122 | 123 | 124 | ## QAT Training (Finetune) 125 | 126 | In this section, we'll outline the steps to perform Quantization-Aware Training (QAT) using fine-tuning.
**Please note that the supported quantization mode is fine-tuning only.**
The model should be trained using the original implementation's train.py; after training and reparameterizing the model, proceed with quantization. 127 | 128 | ### Steps: 129 | 130 | 1. **Train the Model Using [Training Session](https://github.com/WongKinYiu/yolov9/tree/main?tab=readme-ov-file#training):** 131 | - Utilize the original implementation train.py to train your YOLOv9 model with your dataset and desired configurations. 132 | - Follow the training instructions provided in the original YOLOv9 repository to ensure proper training. 133 | 134 | 2. **Reparameterize the Model [reparameterization.py](https://github.com/sunmooncode/yolov9/blob/main/tools/reparameterization.py):** 135 | - After completing the training, reparameterize the trained model to prepare it for quantization. This step is crucial for ensuring that the model's weights are in a suitable format for quantization. 136 | 137 | 3. **[Proceed with Quantization](#quantize-model):** 138 | - Once the model is reparameterized, proceed with the quantization process. This involves applying the Quantization-Aware Training technique to fine-tune the model's weights while taking the quantization effects into account. 139 | 140 | 4. **[Eval Pytorch](#evaluate-using-pytorch) / [Eval TensorRT](#evaluate-using-tensorrt):** 141 | - After quantization, it's crucial to validate the performance of the quantized model to ensure that it meets your requirements in terms of accuracy and efficiency. 142 | - Test the quantized model thoroughly at both stages: during the quantization phase using PyTorch and after QAT training using TensorRT. 143 | - Please note that different versions of TensorRT may yield varying results and performance. 144 | 145 | 5. **Export to ONNX:** 146 | - [Export ONNX](#export-onnx) 147 | - Once you are satisfied with the quantized model's performance, you can proceed to export it to ONNX format. 148 | 149 | 6. **Deploy with TensorRT:** 150 | - [Deployment with TensorRT](#deployment-with-tensorrt) 151 | - After exporting to ONNX, you can deploy the model using TensorRT for hardware-accelerated inference on platforms supporting TensorRT. 152 | 153 | 154 | 155 | 156 | By following these steps, you can successfully perform Quantization-Aware Training (QAT) using fine-tuning with your YOLOv9 model. 157 | 158 | ## How to Install and Train 159 | We suggest using a Docker environment. 160 | NVIDIA PyTorch image: `nvcr.io/nvidia/pytorch:24.10-py3` 161 | 162 | Release 24.10 is based on Ubuntu 22.04 and includes Python 3.10 and CUDA 12.6.2, which requires NVIDIA Driver release 560 or later. If you are running on a data center GPU, check the release notes: 163 | https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-10.html 164 | 165 | ## Installation 166 | ```bash 167 | 168 | docker pull nvcr.io/nvidia/pytorch:24.10-py3 169 | 170 | ## clone original yolov9 171 | git clone https://github.com/WongKinYiu/yolov9.git 172 | 173 | docker run --gpus all \ 174 | -it \ 175 | --net host \ 176 | --ipc=host \ 177 | -v $(pwd)/yolov9:/yolov9 \ 178 | -v $(pwd)/coco/:/yolov9/coco \ 179 | -v $(pwd)/runs:/yolov9/runs \ 180 | nvcr.io/nvidia/pytorch:24.10-py3 181 | 182 | ``` 183 | 184 | 1. Clone and apply patch (Inside Docker) 185 | ```bash 186 | cd / 187 | git clone https://github.com/levipereira/yolov9-qat.git 188 | cd /yolov9-qat 189 | ./patch_yolov9.sh /yolov9 190 | ``` 191 | 192 | 2.
Install dependencies 193 | 194 | - **This release upgrade TensorRT to 10.9** 195 | - `./install_dependencies.sh` 196 | 197 | ```bash 198 | cd /yolov9-qat 199 | ./install_dependencies.sh 200 | ``` 201 | 202 | 203 | 3. Download dataset and pretrained model 204 | ```bash 205 | $ cd /yolov9 206 | $ bash scripts/get_coco.sh 207 | $ wget https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-c-converted.pt 208 | ``` 209 | 210 | 211 | ## Usage 212 | 213 | ## Quantize Model 214 | 215 | To quantize a YOLOv9 model, run: 216 | 217 | ```bash 218 | python3 qat.py quantize --weights yolov9-c-converted.pt --name yolov9_qat --exist-ok 219 | 220 | python qat.py quantize --weights --data --hyp ... 221 | ``` 222 | ## Quantize Command Arguments 223 | 224 | ### Description 225 | This command is used to perform PTQ/QAT finetuning. 226 | 227 | ### Arguments 228 | 229 | - `--weights`: Path to the model weights (.pt). Default: ROOT/runs/models_original/yolov9-c.pt. 230 | - `--data`: Path to the dataset configuration file (data.yaml). Default: ROOT/data/coco.yaml. 231 | - `--hyp`: Path to the hyperparameters file (hyp.yaml). Default: ROOT/data/hyps/hyp.scratch-high.yaml. 232 | - `--device`: Device to use for training/evaluation (e.g., "cuda:0"). Default: "cuda:0". 233 | - `--batch-size`: Total batch size for training/evaluation. Default: 10. 234 | - `--imgsz`, `--img`, `--img-size`: Train/val image size (pixels). Default: 640. 235 | - `--project`: Directory to save the training/evaluation outputs. Default: ROOT/runs/qat. 236 | - `--name`: Name of the training/evaluation experiment. Default: 'exp'. 237 | - `--exist-ok`: Flag to indicate if existing project/name should be overwritten. 238 | - `--iters`: Iterations per epoch. Default: 200. 239 | - `--seed`: Global training seed. Default: 57. 240 | - `--supervision-stride`: Supervision stride. Default: 1. 241 | - `--no-eval-origin`: Disable eval for origin model. 242 | - `--no-eval-ptq`: Disable eval for ptq model. 243 | 244 | 245 | ## Sensitive Layer Analysis 246 | ```bash 247 | python qat.py sensitive --weights yolov9-c.pt --data data/coco.yaml --hyp hyp.scratch-high.yaml ... 248 | ``` 249 | 250 | ## Sensitive Command Arguments 251 | 252 | ### Description 253 | This command is used for sensitive layer analysis. 254 | 255 | ### Arguments 256 | 257 | - `--weights`: Path to the model weights (.pt). Default: ROOT/runs/models_original/yolov9-c.pt. 258 | - `--device`: Device to use for training/evaluation (e.g., "cuda:0"). Default: "cuda:0". 259 | - `--data`: Path to the dataset configuration file (data.yaml). Default: data/coco.yaml. 260 | - `--batch-size`: Total batch size for training/evaluation. Default: 10. 261 | - `--imgsz`, `--img`, `--img-size`: Train/val image size (pixels). Default: 640. 262 | - `--hyp`: Path to the hyperparameters file (hyp.yaml). Default: data/hyps/hyp.scratch-high.yaml. 263 | - `--project`: Directory to save the training/evaluation outputs. Default: ROOT/runs/qat_sentive. 264 | - `--name`: Name of the training/evaluation experiment. Default: 'exp'. 265 | - `--exist-ok`: Flag to indicate if existing project/name should be overwritten. 266 | - `--num-image`: Number of images to evaluate. Default: None. 
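
To make it easier to see what the `sensitive` command is measuring, here is a rough, self-contained sketch of the usual sensitivity-analysis idea (this is not the code inside qat.py): quantization is disabled for one layer at a time and the model is re-evaluated, so the layers whose exclusion recovers the most accuracy are the most quantization-sensitive. `evaluate_map` is a placeholder for your own COCO evaluation routine, and the loop assumes a YOLOv9 model whose layers have already been wrapped with pytorch-quantization fake-quant modules.

```python
import torch
from pytorch_quantization import nn as quant_nn


def set_layer_quant(layer: torch.nn.Module, enabled: bool):
    """Enable or disable every TensorQuantizer inside a single YOLO layer."""
    for m in layer.modules():
        if isinstance(m, quant_nn.TensorQuantizer):
            m.enable() if enabled else m.disable()


def sensitivity_report(model, evaluate_map):
    """Disable quantization one layer at a time and record the resulting mAP."""
    results = []
    for idx, layer in enumerate(model.model):  # model.model holds the YOLOv9 layer list
        if not any(isinstance(m, quant_nn.TensorQuantizer) for m in layer.modules()):
            continue  # layer carries no quantizers, nothing to measure
        set_layer_quant(layer, False)           # run this layer in full precision
        results.append((idx, evaluate_map(model)))
        set_layer_quant(layer, True)            # restore INT8 before the next layer
    # Layers whose exclusion yields the largest mAP are the most quantization-sensitive.
    return sorted(results, key=lambda r: r[1], reverse=True)
```

Layers that rank highest in such a report are the usual candidates to keep unquantized (left in FP16) if the fully quantized model loses too much accuracy.
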
267 | 268 | 269 | ## Evaluate QAT Model 270 | 271 | ### Evaluate using Pytorch 272 | ```bash 273 | python3 qat.py eval --weights runs/qat/yolov9_qat/weights/qat_best_yolov9-c-converted.pt --name eval_qat_yolov9 274 | ``` 275 | ## Evaluation Command Arguments 276 | 277 | ### Description 278 | This command is used to perform evaluation on QAT Models. 279 | 280 | ### Arguments 281 | 282 | - `--weights`: Path to the model weights (.pt). Default: ROOT/runs/models_original/yolov9-c.pt. 283 | - `--data`: Path to the dataset configuration file (data.yaml). Default: data/coco.yaml. 284 | - `--batch-size`: Total batch size for evaluation. Default: 10. 285 | - `--imgsz`, `--img`, `--img-size`: Validation image size (pixels). Default: 640. 286 | - `--device`: Device to use for evaluation (e.g., "cuda:0"). Default: "cuda:0". 287 | - `--conf-thres`: Confidence threshold for evaluation. Default: 0.001. 288 | - `--iou-thres`: NMS threshold for evaluation. Default: 0.7. 289 | - `--project`: Directory to save the evaluation outputs. Default: ROOT/runs/qat_eval. 290 | - `--name`: Name of the evaluation experiment. Default: 'exp'. 291 | - `--exist-ok`: Flag to indicate if existing project/name should be overwritten. 292 | 293 | ### Evaluate using TensorRT 294 | 295 | ```bash 296 | ./scripts/val_trt.sh 297 | 298 | ./scripts/val_trt.sh runs/qat/yolov9_qat/weights/qat_best_yolov9-c-converted.pt data/coco.yaml 640 299 | ``` 300 | 301 | ## Generate TensoRT Profiling and SVG image 302 | 303 | 304 | TensorRT Explorer can be installed by executing `./install_dependencies.sh --trex`.
This installation is necessary to enable the generation of Graph SV, allowing visualization of the profiling data for a TensorRT engine. 305 | 306 | ```bash 307 | ./scripts/val_trt.sh runs/qat/yolov9_qat/weights/qat_best_yolov9-c-converted.pt data/coco.yaml 640 --generate-graph 308 | ``` 309 | 310 | # Export ONNX 311 | The goal of exporting to ONNX is to deploy to TensorRT, not to ONNX runtime. So we only export fake quantized model into a form TensorRT will take. Fake quantization will be broken into a pair of QuantizeLinear/DequantizeLinear ONNX ops. TensorRT will take the generated ONNX graph, and execute it in int8 in the most optimized way to its capability. 312 | 313 | ## Export ONNX Model without End2End 314 | ```bash 315 | python3 export_qat.py --weights runs/qat/yolov9_qat/weights/qat_best_yolov9-c.pt --include onnx --dynamic --simplify --inplace 316 | ``` 317 | 318 | ## Export ONNX Model End2End 319 | ```bash 320 | python3 export_qat.py --weights runs/qat/yolov9_qat/weights/qat_best_yolov9-c.pt --include onnx_end2end 321 | ``` 322 | 323 | 324 | ## Deployment with Tensorrt 325 | ```bash 326 | /usr/src/tensorrt/bin/trtexec \ 327 | --onnx=runs/qat/yolov9_qat/weights/qat_best_yolov9-c-converted.onnx \ 328 | --int8 --fp16 \ 329 | --useCudaGraph \ 330 | --minShapes=images:1x3x640x640 \ 331 | --optShapes=images:4x3x640x640 \ 332 | --maxShapes=images:8x3x640x640 \ 333 | --saveEngine=runs/qat/yolov9_qat/weights/qat_best_yolov9-c-converted.engine 334 | ``` 335 | 336 | # Benchmark 337 | Note: To test FP16 Models (such as Origin) remove flag `--int8` 338 | ```bash 339 | # Set variable batch_size and model_path_no_ext 340 | export batch_size=4 341 | export filepath_no_ext=runs/qat/yolov9_qat/weights/qat_best_yolov9-c-converted 342 | trtexec \ 343 | --onnx=${filepath_no_ext}.onnx \ 344 | --fp16 \ 345 | --int8 \ 346 | --saveEngine=${filepath_no_ext}.engine \ 347 | --timingCacheFile=${filepath_no_ext}.engine.timing.cache \ 348 | --warmUp=500 \ 349 | --duration=10 \ 350 | --useCudaGraph \ 351 | --useSpinWait \ 352 | --noDataTransfers \ 353 | --minShapes=images:1x3x640x640 \ 354 | --optShapes=images:${batch_size}x3x640x640 \ 355 | --maxShapes=images:${batch_size}x3x640x640 356 | ``` 357 | 358 | ### Device 359 | ```bash 360 | === Device Information === 361 | Available Devices: 362 | Device 0: "NVIDIA GeForce RTX 4090" 363 | Selected Device: NVIDIA GeForce RTX 4090 364 | Selected Device ID: 0 365 | Compute Capability: 8.9 366 | SMs: 128 367 | Device Global Memory: 24207 MiB 368 | Shared Memory per SM: 100 KiB 369 | Memory Bus Width: 384 bits (ECC disabled) 370 | Application Compute Clock Rate: 2.58 GHz 371 | Application Memory Clock Rate: 10.501 GHz 372 | ``` 373 | 374 | ## Output Details 375 | - `Latency`: refers to the [min, max, mean, median, 99% percentile] of the engine latency measurements, when timing the engine w/o profiling layers. 376 | - `Throughput`: is measured in query (inference) per second (QPS). 
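
The latency and throughput figures below were collected with `trtexec`, but the engine built in the [Deployment with TensorRT](#deployment-with-tensorrt) step can also be run directly from Python. The snippet below is a minimal sketch (not part of this repository) of the TensorRT 10 Python API using the `execute_async_v3` (enqueueV3) path; the engine path and the `images` input name follow the commands above, and PyTorch CUDA tensors are used here only as a convenient way to own the device buffers.

```python
import tensorrt as trt
import torch

ENGINE_PATH = "runs/qat/yolov9_qat/weights/qat_best_yolov9-c-converted.engine"

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
with open(ENGINE_PATH, "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

# Fix the dynamic batch dimension, then allocate one CUDA buffer per I/O tensor.
batch = 1
context.set_input_shape("images", (batch, 3, 640, 640))
buffers = {}
for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    shape = tuple(context.get_tensor_shape(name))
    # Simplified dtype mapping: the non-End2End engine exported above only uses FP32/FP16 I/O.
    dtype = torch.float16 if engine.get_tensor_dtype(name) == trt.float16 else torch.float32
    buffers[name] = torch.empty(shape, dtype=dtype, device="cuda")
    context.set_tensor_address(name, buffers[name].data_ptr())

# Copy a preprocessed (letterboxed, normalized, CHW) batch into the input buffer here.
buffers["images"].zero_()

stream = torch.cuda.current_stream()
context.execute_async_v3(stream.cuda_stream)  # enqueueV3-style execution
stream.synchronize()

outputs = {n: t for n, t in buffers.items()
           if engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT}
print({n: tuple(t.shape) for n, t in outputs.items()})
```
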
377 | 378 | ## YOLOv9-C QAT (SiLU) 379 | ## Batch Size 1 380 | ```bash 381 | Throughput: 1026.71 qps 382 | Latency: min = 0.969727 ms, max = 0.975098 ms, mean = 0.972263 ms, median = 0.972656 ms, percentile(90%) = 0.973145 ms, percentile(95%) = 0.973633 ms, percentile(99%) = 0.974121 ms 383 | Enqueue Time: min = 0.00195312 ms, max = 0.0195312 ms, mean = 0.00228119 ms, median = 0.00219727 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00390625 ms 384 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 385 | GPU Compute Time: min = 0.969727 ms, max = 0.975098 ms, mean = 0.972263 ms, median = 0.972656 ms, percentile(90%) = 0.973145 ms, percentile(95%) = 0.973633 ms, percentile(99%) = 0.974121 ms 386 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 387 | Total Host Walltime: 10.0019 s 388 | Total GPU Compute Time: 9.98417 s 389 | ``` 390 | 391 | ## BatchSize 4 392 | ```bash 393 | === Performance summary === 394 | Throughput: 485.73 qps 395 | Latency: min = 2.05176 ms, max = 2.06152 ms, mean = 2.05712 ms, median = 2.05713 ms, percentile(90%) = 2.05908 ms, percentile(95%) = 2.05957 ms, percentile(99%) = 2.06055 ms 396 | Enqueue Time: min = 0.00195312 ms, max = 0.00708008 ms, mean = 0.00230195 ms, median = 0.00219727 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00415039 ms 397 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 398 | GPU Compute Time: min = 2.05176 ms, max = 2.06152 ms, mean = 2.05712 ms, median = 2.05713 ms, percentile(90%) = 2.05908 ms, percentile(95%) = 2.05957 ms, percentile(99%) = 2.06055 ms 399 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 400 | Total Host Walltime: 10.0035 s 401 | Total GPU Compute Time: 9.99553 s 402 | ``` 403 | 404 | 405 | ## BatchSize 8 406 | ```bash 407 | === Performance summary === 408 | Throughput: 271.107 qps 409 | Latency: min = 3.6792 ms, max = 3.69775 ms, mean = 3.68694 ms, median = 3.68652 ms, percentile(90%) = 3.69043 ms, percentile(95%) = 3.69141 ms, percentile(99%) = 3.69336 ms 410 | Enqueue Time: min = 0.00195312 ms, max = 0.0090332 ms, mean = 0.0023588 ms, median = 0.00231934 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00476074 ms 411 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 412 | GPU Compute Time: min = 3.6792 ms, max = 3.69775 ms, mean = 3.68694 ms, median = 3.68652 ms, percentile(90%) = 3.69043 ms, percentile(95%) = 3.69141 ms, percentile(99%) = 3.69336 ms 413 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 414 | Total Host Walltime: 10.0071 s 415 | Total GPU Compute Time: 10.0027 s 416 | ``` 417 | ## BatchSize 12 418 | ```bash 419 | === Performance summary === 420 | Throughput: 188.812 qps 421 | Latency: min = 5.25 ms, max = 5.37097 ms, mean = 5.2946 ms, median = 5.28906 ms, percentile(90%) = 5.32129 ms, percentile(95%) = 5.32593 ms, percentile(99%) = 5.36475 ms 422 | Enqueue Time: min = 0.00195312 ms, max = 0.0898438 ms, mean = 0.00248513 ms, median = 0.00244141 ms, 
percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00463867 ms 423 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 424 | GPU Compute Time: min = 5.25 ms, max = 5.37097 ms, mean = 5.2946 ms, median = 5.28906 ms, percentile(90%) = 5.32129 ms, percentile(95%) = 5.32593 ms, percentile(99%) = 5.36475 ms 425 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 426 | Total Host Walltime: 10.01 s 427 | Total GPU Compute Time: 10.0068 s 428 | ``` 429 | 430 | 431 | 432 | ## YOLOv9-C QAT (ReLU) 433 | ## Batch Size 1 434 | ```bash 435 | === Performance summary === 436 | Throughput: 1149.49 qps 437 | Latency: min = 0.866211 ms, max = 0.871094 ms, mean = 0.868257 ms, median = 0.868164 ms, percentile(90%) = 0.869385 ms, percentile(95%) = 0.869629 ms, percentile(99%) = 0.870117 ms 438 | Enqueue Time: min = 0.00195312 ms, max = 0.0180664 ms, mean = 0.00224214 ms, median = 0.00219727 ms, percentile(90%) = 0.00268555 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00390625 ms 439 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 440 | GPU Compute Time: min = 0.866211 ms, max = 0.871094 ms, mean = 0.868257 ms, median = 0.868164 ms, percentile(90%) = 0.869385 ms, percentile(95%) = 0.869629 ms, percentile(99%) = 0.870117 ms 441 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 442 | Total Host Walltime: 10.0018 s 443 | Total GPU Compute Time: 9.98235 s 444 | ``` 445 | 446 | ## BatchSize 4 447 | ```bash 448 | === Performance summary === 449 | Throughput: 561.857 qps 450 | Latency: min = 1.77344 ms, max = 1.78418 ms, mean = 1.77814 ms, median = 1.77832 ms, percentile(90%) = 1.77979 ms, percentile(95%) = 1.78076 ms, percentile(99%) = 1.78174 ms 451 | Enqueue Time: min = 0.00195312 ms, max = 0.0205078 ms, mean = 0.00233018 ms, median = 0.0022583 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00439453 ms 452 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 453 | GPU Compute Time: min = 1.77344 ms, max = 1.78418 ms, mean = 1.77814 ms, median = 1.77832 ms, percentile(90%) = 1.77979 ms, percentile(95%) = 1.78076 ms, percentile(99%) = 1.78174 ms 454 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 455 | Total Host Walltime: 10.0043 s 456 | Total GPU Compute Time: 9.99494 s 457 | ``` 458 | 459 | 460 | ## BatchSize 8 461 | ```bash 462 | === Performance summary === 463 | Throughput: 326.86 qps 464 | Latency: min = 3.04126 ms, max = 3.06934 ms, mean = 3.05773 ms, median = 3.05859 ms, percentile(90%) = 3.06152 ms, percentile(95%) = 3.0625 ms, percentile(99%) = 3.06396 ms 465 | Enqueue Time: min = 0.00195312 ms, max = 0.0209961 ms, mean = 0.00235826 ms, median = 0.00231934 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00463867 ms 466 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 467 | GPU Compute Time: min = 3.04126 ms, max = 3.06934 ms, mean = 3.05773 ms, median = 3.05859 ms, 
percentile(90%) = 3.06152 ms, percentile(95%) = 3.0625 ms, percentile(99%) = 3.06396 ms 468 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 469 | Total Host Walltime: 10.0043 s 470 | Total GPU Compute Time: 9.99877 s 471 | ``` 472 | ## BatchSize 12 473 | ```bash 474 | === Performance summary === 475 | Throughput: 216.441 qps 476 | Latency: min = 4.60742 ms, max = 4.63184 ms, mean = 4.61852 ms, median = 4.61816 ms, percentile(90%) = 4.62305 ms, percentile(95%) = 4.62439 ms, percentile(99%) = 4.62744 ms 477 | Enqueue Time: min = 0.00195312 ms, max = 0.0131836 ms, mean = 0.00250633 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00341797 ms, percentile(99%) = 0.00531006 ms 478 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 479 | GPU Compute Time: min = 4.60742 ms, max = 4.63184 ms, mean = 4.61852 ms, median = 4.61816 ms, percentile(90%) = 4.62305 ms, percentile(95%) = 4.62439 ms, percentile(99%) = 4.62744 ms 480 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 481 | Total Host Walltime: 10.0074 s 482 | Total GPU Compute Time: 10.0037 s 483 | ``` 484 | 485 | ## YOLOv9-C FP16 486 | ## Batch Size 1 487 | ```bash 488 | === Performance summary === 489 | Throughput: 802.984 qps 490 | Latency: min = 1.23901 ms, max = 1.25439 ms, mean = 1.24376 ms, median = 1.24316 ms, percentile(90%) = 1.24805 ms, percentile(95%) = 1.24902 ms, percentile(99%) = 1.24951 ms 491 | Enqueue Time: min = 0.00195312 ms, max = 0.00756836 ms, mean = 0.00240711 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00390625 ms 492 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 493 | GPU Compute Time: min = 1.23901 ms, max = 1.25439 ms, mean = 1.24376 ms, median = 1.24316 ms, percentile(90%) = 1.24805 ms, percentile(95%) = 1.24902 ms, percentile(99%) = 1.24951 ms 494 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 495 | Total Host Walltime: 10.0027 s 496 | Total GPU Compute Time: 9.98985 s 497 | ``` 498 | 499 | ## BatchSize 4 500 | ```bash 501 | === Performance summary === 502 | Throughput: 300.281 qps 503 | Latency: min = 3.30341 ms, max = 3.38025 ms, mean = 3.32861 ms, median = 3.3291 ms, percentile(90%) = 3.33594 ms, percentile(95%) = 3.34229 ms, percentile(99%) = 3.37 ms 504 | Enqueue Time: min = 0.00195312 ms, max = 0.00830078 ms, mean = 0.00244718 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00390625 ms 505 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 506 | GPU Compute Time: min = 3.30341 ms, max = 3.38025 ms, mean = 3.32861 ms, median = 3.3291 ms, percentile(90%) = 3.33594 ms, percentile(95%) = 3.34229 ms, percentile(99%) = 3.37 ms 507 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 508 | Total Host Walltime: 10.0073 s 509 | Total GPU Compute Time: 10.0025 s 510 | ``` 511 | 512 | 513 | ## BatchSize 8 514 | ```bash 515 | === Performance 
summary === 516 | Throughput: 153.031 qps 517 | Latency: min = 6.47882 ms, max = 6.64679 ms, mean = 6.53299 ms, median = 6.5332 ms, percentile(90%) = 6.55029 ms, percentile(95%) = 6.55762 ms, percentile(99%) = 6.59766 ms 518 | Enqueue Time: min = 0.00195312 ms, max = 0.0117188 ms, mean = 0.00248772 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00390625 ms 519 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 520 | GPU Compute Time: min = 6.47882 ms, max = 6.64679 ms, mean = 6.53299 ms, median = 6.5332 ms, percentile(90%) = 6.55029 ms, percentile(95%) = 6.55762 ms, percentile(99%) = 6.59766 ms 521 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 522 | Total Host Walltime: 10.011 s 523 | Total GPU Compute Time: 10.0085 s 524 | ``` 525 | 526 | ## BatchSize 8 527 | ```bash 528 | === Performance summary === 529 | Throughput: 99.3162 qps 530 | Latency: min = 10.0372 ms, max = 10.0947 ms, mean = 10.0672 ms, median = 10.0674 ms, percentile(90%) = 10.0781 ms, percentile(95%) = 10.0811 ms, percentile(99%) = 10.0859 ms 531 | Enqueue Time: min = 0.00195312 ms, max = 0.0078125 ms, mean = 0.00248219 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00390625 ms 532 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 533 | GPU Compute Time: min = 10.0372 ms, max = 10.0947 ms, mean = 10.0672 ms, median = 10.0674 ms, percentile(90%) = 10.0781 ms, percentile(95%) = 10.0811 ms, percentile(99%) = 10.0859 ms 534 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 535 | Total Host Walltime: 10.0286 s 536 | Total GPU Compute Time: 10.0269 s 537 | ``` 538 | 539 | 540 | # Segmentation 541 | 542 | ## FP16 543 | ### Batch Size 8 544 | 545 | ```bash 546 | === Performance summary === 547 | Throughput: 124.055 qps 548 | Latency: min = 8.00354 ms, max = 8.18585 ms, mean = 8.05924 ms, median = 8.05072 ms, percentile(90%) = 8.11499 ms, percentile(95%) = 8.1438 ms, percentile(99%) = 8.17456 ms 549 | Enqueue Time: min = 0.00219727 ms, max = 0.0200653 ms, mean = 0.00271174 ms, median = 0.00256348 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00317383 ms, percentile(99%) = 0.00466919 ms 550 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 551 | GPU Compute Time: min = 8.00354 ms, max = 8.18585 ms, mean = 8.05924 ms, median = 8.05072 ms, percentile(90%) = 8.11499 ms, percentile(95%) = 8.1438 ms, percentile(99%) = 8.17456 ms 552 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 553 | Total Host Walltime: 3.01478 s 554 | Total GPU Compute Time: 3.01415 s 555 | ``` 556 | 557 | ## INT8 / FP16 558 | ### Batch Size 8 559 | ```bash 560 | === Performance summary === 561 | Throughput: 223.63 qps 562 | Latency: min = 4.45544 ms, max = 4.71553 ms, mean = 4.47007 ms, median = 4.46777 ms, percentile(90%) = 4.47284 ms, percentile(95%) = 4.47388 ms, percentile(99%) = 4.47693 ms 563 | Enqueue Time: min = 0.00219727 ms, max = 0.00854492 ms, mean = 0.00258152 ms, median = 0.00244141 ms, 
percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00305176 ms, percentile(99%) = 0.00439453 ms 564 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 565 | GPU Compute Time: min = 4.45544 ms, max = 4.71553 ms, mean = 4.47007 ms, median = 4.46777 ms, percentile(90%) = 4.47284 ms, percentile(95%) = 4.47388 ms, percentile(99%) = 4.47693 ms 566 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms 567 | Total Host Walltime: 3.00944 s 568 | Total GPU Compute Time: 3.00836 s 569 | ``` -------------------------------------------------------------------------------- /draw-engine.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: MIT 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a 6 | # copy of this software and associated documentation files (the "Software"), 7 | # to deal in the Software without restriction, including without limitation 8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 9 | # and/or sell copies of the Software, and to permit persons to whom the 10 | # Software is furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in 13 | # all copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 21 | # DEALINGS IN THE SOFTWARE. 22 | ################################################################################ 23 | 24 | 25 | """ 26 | This script generates an SVG diagram of the input engine graph SVG file. 
27 | Note: 28 | THIS SCRIPT DEPENDS ON LIB: https://github.com/NVIDIA/TensorRT/tree/main/tools/experimental/trt-engine-explorer 29 | this script requires graphviz which can be installed: 30 | $ /yolov9-qat/install_dependencies.sh --trex 31 | """ 32 | 33 | import graphviz 34 | from trex import * 35 | from trex import graphing 36 | import argparse 37 | import shutil 38 | 39 | 40 | def draw_engine(engine_json_fname: str, engine_profile_fname: str): 41 | engine_name=engine_json_fname.replace('.layer.json', '') 42 | graphviz_is_installed = shutil.which("dot") is not None 43 | if not graphviz_is_installed: 44 | print("graphviz is required but it is not installed.\n") 45 | print("To install on Ubuntu:") 46 | print("$ /yolov9-qat/install_dependencies.sh --trex") 47 | exit() 48 | 49 | plan = EnginePlan(engine_json_fname, engine_profile_fname) 50 | formatter = graphing.layer_type_formatter 51 | display_regions = True 52 | expand_layer_details = False 53 | 54 | graph = graphing.to_dot(plan, formatter, 55 | display_regions=display_regions, 56 | expand_layer_details=expand_layer_details) 57 | graphing.render_dot(graph, engine_name, 'svg') 58 | 59 | graphing.render_dot(graph, engine_name, 'png') 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument('--layer', help="name of engine JSON file to draw" , required=True) 65 | parser.add_argument('--profile', help="name of profile JSON file to draw", required=True) 66 | args = parser.parse_args() 67 | draw_engine(engine_json_fname=args.layer,engine_profile_fname=args.profile) 68 | -------------------------------------------------------------------------------- /export_qat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import contextlib 3 | import json 4 | import os 5 | import platform 6 | import re 7 | import subprocess 8 | import sys 9 | import time 10 | import warnings 11 | from pathlib import Path 12 | 13 | import pandas as pd 14 | import torch 15 | from torch.utils.mobile_optimizer import optimize_for_mobile 16 | import onnx_graphsurgeon as gs 17 | 18 | import models.quantize as quantize 19 | 20 | from pytorch_quantization import nn as quant_nn 21 | 22 | FILE = Path(__file__).resolve() 23 | ROOT = FILE.parents[0] # YOLO root directory 24 | if str(ROOT) not in sys.path: 25 | sys.path.append(str(ROOT)) # add ROOT to PATH 26 | if platform.system() != 'Windows': 27 | ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative 28 | 29 | from models.experimental_trt import End2End_TRT 30 | from models.experimental import attempt_load 31 | from models.yolo import ClassificationModel, Detect, DDetect, DualDetect, DualDDetect, DetectionModel, SegmentationModel 32 | from utils.dataloaders import LoadImages 33 | from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version, 34 | check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save) 35 | from utils.torch_utils import select_device, smart_inference_mode 36 | from models.quantize import remove_redundant_qdq_model 37 | 38 | MACOS = platform.system() == 'Darwin' # macOS environment 39 | 40 | 41 | 42 | def export_formats(): 43 | # YOLO export formats 44 | x = [ 45 | ['PyTorch', '-', '.pt', True, True], 46 | ['TorchScript', 'torchscript', '.torchscript', True, True], 47 | ['ONNX', 'onnx', '.onnx', True, True], 48 | ['ONNX END2END', 'onnx_end2end', '_end2end.onnx', True, True], 49 | ['OpenVINO', 'openvino', '_openvino_model', True, 
False], 50 | ['TensorRT', 'engine', '.engine', False, True], 51 | ['CoreML', 'coreml', '.mlmodel', True, False], 52 | ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True], 53 | ['TensorFlow GraphDef', 'pb', '.pb', True, True], 54 | ['TensorFlow Lite', 'tflite', '.tflite', True, False], 55 | ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False], 56 | ['TensorFlow.js', 'tfjs', '_web_model', False, False], 57 | ['PaddlePaddle', 'paddle', '_paddle_model', True, True],] 58 | return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU']) 59 | 60 | 61 | def try_export(inner_func): 62 | # YOLO export decorator, i..e @try_export 63 | inner_args = get_default_args(inner_func) 64 | 65 | def outer_func(*args, **kwargs): 66 | prefix = inner_args['prefix'] 67 | try: 68 | with Profile() as dt: 69 | f, model = inner_func(*args, **kwargs) 70 | LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)') 71 | return f, model 72 | except Exception as e: 73 | LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}') 74 | return None, None 75 | 76 | return outer_func 77 | 78 | 79 | @try_export 80 | def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')): 81 | # YOLO TorchScript model export 82 | LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...') 83 | f = file.with_suffix('.torchscript') 84 | 85 | ts = torch.jit.trace(model, im, strict=False) 86 | d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names} 87 | extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap() 88 | if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html 89 | optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files) 90 | else: 91 | ts.save(str(f), _extra_files=extra_files) 92 | return f, None 93 | 94 | 95 | @try_export 96 | def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')): 97 | # YOLO ONNX export 98 | check_requirements('onnx') 99 | import onnx 100 | 101 | is_model_qat=False 102 | for i in range(0, len(model.model)): 103 | layer = model.model[i] 104 | if quantize.have_quantizer(layer): 105 | is_model_qat=True 106 | break 107 | 108 | LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...') 109 | f = file.with_suffix('.onnx') 110 | 111 | output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0'] 112 | if dynamic: 113 | dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640) 114 | if isinstance(model, SegmentationModel): 115 | dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85) 116 | dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160) 117 | elif isinstance(model, DetectionModel): 118 | dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85) 119 | 120 | if is_model_qat: 121 | warnings.filterwarnings("ignore") 122 | LOGGER.info(f'{prefix} Model QAT Detected ...') 123 | quant_nn.TensorQuantizer.use_fb_fake_quant = True 124 | model.eval() 125 | quantize.initialize() 126 | quantize.replace_custom_module_forward(model) 127 | with torch.no_grad(): 128 | torch.onnx.export( 129 | model, 130 | im, 131 | f, 132 | verbose=False, 133 | opset_version=13, 134 | do_constant_folding=True, 135 | input_names=['images'], 136 | output_names=output_names, 137 | dynamic_axes=dynamic) 138 | 139 | else: 140 | torch.onnx.export( 141 | model.cpu() if dynamic else model, # 
--dynamic only compatible with cpu 142 | im.cpu() if dynamic else im, 143 | f, 144 | verbose=False, 145 | opset_version=opset, 146 | do_constant_folding=True, 147 | input_names=['images'], 148 | output_names=output_names, 149 | dynamic_axes=dynamic or None) 150 | 151 | # Checks 152 | model_onnx = onnx.load(f) # load onnx model 153 | onnx.checker.check_model(model_onnx) # check onnx model 154 | 155 | # Metadata 156 | d = {'stride': int(max(model.stride)), 'names': model.names} 157 | for k, v in d.items(): 158 | meta = model_onnx.metadata_props.add() 159 | meta.key, meta.value = k, str(v) 160 | 161 | # Simplify 162 | if simplify: 163 | try: 164 | cuda = torch.cuda.is_available() 165 | check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1')) 166 | import onnxsim 167 | 168 | LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...') 169 | model_onnx, check = onnxsim.simplify(model_onnx) 170 | assert check, 'assert check failed' 171 | onnx.save(model_onnx, f) 172 | except Exception as e: 173 | LOGGER.info(f'{prefix} simplifier failure: {e}') 174 | 175 | LOGGER.info(f'{prefix} Removing redundant Q/DQ layer with onnx_graphsurgeon {gs.__version__}...') 176 | remove_redundant_qdq_model(model_onnx, f) 177 | model_onnx = onnx.load(f) 178 | return f, model_onnx 179 | 180 | 181 | 182 | @try_export 183 | def export_onnx_end2end(model, im, file, class_agnostic, simplify, topk_all, iou_thres, conf_thres, device, labels, mask_resolution, pooler_scale, sampling_ratio, prefix=colorstr('ONNX END2END:')): 184 | LOGGER.info(f'{prefix} Model type: {type(model)}') 185 | LOGGER.info(f'{prefix} Is DetectionModel: {isinstance(model, DetectionModel)}') 186 | LOGGER.info(f'{prefix} Is SegmentationModel: {isinstance(model, SegmentationModel)}') 187 | 188 | has_detection_capabilities = hasattr(model, 'model') and hasattr(model, 'names') and hasattr(model, 'stride') 189 | 190 | if not has_detection_capabilities: 191 | raise RuntimeError("Model not supported. 
Only Detection Models can be exported with End2End functionality.") 192 | 193 | LOGGER.info(f'{prefix} Model accepted for export.') 194 | 195 | is_det_model=True 196 | if isinstance(model, SegmentationModel): 197 | is_det_model=False 198 | 199 | env_is_det_model = os.getenv("MODEL_DET") 200 | if env_is_det_model == "0": 201 | is_det_model = False 202 | 203 | # YOLO ONNX export 204 | check_requirements('onnx') 205 | import onnx 206 | 207 | is_model_qat=False 208 | for i in range(0, len(model.model)): 209 | layer = model.model[i] 210 | if quantize.have_quantizer(layer): 211 | is_model_qat=True 212 | break 213 | 214 | LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...') 215 | f = os.path.splitext(file)[0] + "-end2end.onnx" 216 | batch_size = 'batch' 217 | d = { 218 | 'stride': int(max(model.stride)), 219 | 'names': model.names, 220 | 'model type' : 'Detection' if is_det_model else 'Segmentation', 221 | 'TRT Compatibility': '8.6 or above' if class_agnostic else '8.5 or above', 222 | 'TRT Plugins': 'EfficientNMS_TRT' if is_det_model else 'EfficientNMSX_TRT, ROIAlign' 223 | } 224 | 225 | 226 | dynamic_axes = {'images': {0 : 'batch', 2: 'height', 3:'width'}, } # variable length axes 227 | 228 | output_axes = { 229 | 'num_dets': {0: 'batch'}, 230 | 'det_boxes': {0: 'batch'}, 231 | 'det_scores': {0: 'batch'}, 232 | 'det_classes': {0: 'batch'}, 233 | } 234 | if is_det_model: 235 | output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes'] 236 | shapes = [ batch_size, 1, 237 | batch_size, topk_all, 4, 238 | batch_size, topk_all, 239 | batch_size, topk_all] 240 | 241 | else: 242 | output_axes['det_masks'] = {0: 'batch'} 243 | output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes', 'det_masks'] 244 | shapes = [ batch_size, 1, 245 | batch_size, topk_all, 4, 246 | batch_size, topk_all, 247 | batch_size, topk_all, 248 | batch_size, topk_all, mask_resolution * mask_resolution] 249 | 250 | dynamic_axes.update(output_axes) 251 | model = End2End_TRT(model, class_agnostic, topk_all, iou_thres, conf_thres, mask_resolution, pooler_scale, sampling_ratio, None ,device, labels, is_det_model ) 252 | 253 | 254 | if is_model_qat: 255 | warnings.filterwarnings("ignore") 256 | LOGGER.info(f'{prefix} Model QAT Detected ...') 257 | quant_nn.TensorQuantizer.use_fb_fake_quant = True 258 | model.eval() 259 | quantize.initialize() 260 | 261 | with torch.no_grad(): 262 | torch.onnx.export(model, 263 | im, 264 | f, 265 | verbose=False, 266 | export_params=True, # store the trained parameter weights inside the model file 267 | opset_version=16, 268 | do_constant_folding=True, # whether to execute constant folding for optimization 269 | input_names=['images'], 270 | output_names=output_names, 271 | dynamic_axes=dynamic_axes) 272 | quant_nn.TensorQuantizer.use_fb_fake_quant = False 273 | else: 274 | torch.onnx.export(model, 275 | im, 276 | f, 277 | verbose=False, 278 | export_params=True, # store the trained parameter weights inside the model file 279 | opset_version=16, 280 | do_constant_folding=True, # whether to execute constant folding for optimization 281 | input_names=['images'], 282 | output_names=output_names, 283 | dynamic_axes=dynamic_axes) 284 | 285 | # Checks 286 | model_onnx = onnx.load(f) # load onnx model 287 | onnx.checker.check_model(model_onnx) # check onnx model 288 | for k, v in d.items(): 289 | meta = model_onnx.metadata_props.add() 290 | meta.key, meta.value = k, str(v) 291 | 292 | for i in model_onnx.graph.output: 293 | for j in i.type.tensor_type.shape.dim: 294 
| j.dim_param = str(shapes.pop(0)) 295 | 296 | if simplify: 297 | try: 298 | import onnxsim 299 | 300 | print('\nStarting to simplify ONNX...') 301 | model_onnx, check = onnxsim.simplify(model_onnx) 302 | assert check, 'assert check failed' 303 | except Exception as e: 304 | print(f'Simplifier failure: {e}') 305 | 306 | # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model 307 | onnx.save(model_onnx,f) 308 | print('ONNX export success, saved as %s' % f) 309 | 310 | LOGGER.info(f'{prefix} Removing redundant Q/DQ layer with onnx_graphsurgeon {gs.__version__}...') 311 | remove_redundant_qdq_model(model_onnx, f) 312 | model_onnx = onnx.load(f) 313 | 314 | return f, model_onnx 315 | 316 | 317 | @try_export 318 | def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')): 319 | # YOLO OpenVINO export 320 | check_requirements('openvino-dev') # requires openvino-dev: https://pypi.org/project/openvino-dev/ 321 | import openvino.inference_engine as ie 322 | 323 | LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...') 324 | f = str(file).replace('.pt', f'_openvino_model{os.sep}') 325 | 326 | #cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}" 327 | #cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} {"--compress_to_fp16" if half else ""}" 328 | half_arg = "--compress_to_fp16" if half else "" 329 | cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} {half_arg}" 330 | subprocess.run(cmd.split(), check=True, env=os.environ) # export 331 | yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml 332 | return f, None 333 | 334 | 335 | @try_export 336 | def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')): 337 | # YOLO Paddle export 338 | check_requirements(('paddlepaddle', 'x2paddle')) 339 | import x2paddle 340 | from x2paddle.convert import pytorch2paddle 341 | 342 | LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...') 343 | f = str(file).replace('.pt', f'_paddle_model{os.sep}') 344 | 345 | pytorch2paddle(module=model, save_dir=f, jit_type='trace', input_examples=[im]) # export 346 | yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml 347 | return f, None 348 | 349 | 350 | @try_export 351 | def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')): 352 | # YOLO CoreML export 353 | check_requirements('coremltools') 354 | import coremltools as ct 355 | 356 | LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...') 357 | f = file.with_suffix('.mlmodel') 358 | 359 | ts = torch.jit.trace(model, im, strict=False) # TorchScript model 360 | ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])]) 361 | bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None) 362 | if bits < 32: 363 | if MACOS: # quantization only supported on macOS 364 | with warnings.catch_warnings(): 365 | warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning 366 | ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode) 367 | else: 368 | print(f'{prefix} quantization only supported on macOS, skipping...') 369 | ct_model.save(f) 370 | return f, ct_model 371 | 372 | 373 | @try_export 374 | def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, 
prefix=colorstr('TensorRT:')): 375 | # YOLO TensorRT export https://developer.nvidia.com/tensorrt 376 | assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`' 377 | try: 378 | import tensorrt as trt 379 | except Exception: 380 | if platform.system() == 'Linux': 381 | check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com') 382 | import tensorrt as trt 383 | 384 | if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012 385 | grid = model.model[-1].anchor_grid 386 | model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid] 387 | export_onnx(model, im, file, 12, dynamic, simplify) # opset 12 388 | model.model[-1].anchor_grid = grid 389 | else: # TensorRT >= 8 390 | check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0 391 | export_onnx(model, im, file, 12, dynamic, simplify) # opset 12 392 | onnx = file.with_suffix('.onnx') 393 | 394 | LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...') 395 | assert onnx.exists(), f'failed to export ONNX file: {onnx}' 396 | f = file.with_suffix('.engine') # TensorRT engine file 397 | logger = trt.Logger(trt.Logger.INFO) 398 | if verbose: 399 | logger.min_severity = trt.Logger.Severity.VERBOSE 400 | 401 | builder = trt.Builder(logger) 402 | config = builder.create_builder_config() 403 | config.max_workspace_size = workspace * 1 << 30 404 | # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice 405 | 406 | flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) 407 | network = builder.create_network(flag) 408 | parser = trt.OnnxParser(network, logger) 409 | if not parser.parse_from_file(str(onnx)): 410 | raise RuntimeError(f'failed to load ONNX file: {onnx}') 411 | 412 | inputs = [network.get_input(i) for i in range(network.num_inputs)] 413 | outputs = [network.get_output(i) for i in range(network.num_outputs)] 414 | for inp in inputs: 415 | LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}') 416 | for out in outputs: 417 | LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}') 418 | 419 | if dynamic: 420 | if im.shape[0] <= 1: 421 | LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument") 422 | profile = builder.create_optimization_profile() 423 | for inp in inputs: 424 | profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape) 425 | config.add_optimization_profile(profile) 426 | 427 | LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}') 428 | if builder.platform_has_fast_fp16 and half: 429 | config.set_flag(trt.BuilderFlag.FP16) 430 | with builder.build_engine(network, config) as engine, open(f, 'wb') as t: 431 | t.write(engine.serialize()) 432 | return f, None 433 | 434 | 435 | @try_export 436 | def export_saved_model(model, 437 | im, 438 | file, 439 | dynamic, 440 | tf_nms=False, 441 | agnostic_nms=False, 442 | topk_per_class=100, 443 | topk_all=100, 444 | iou_thres=0.45, 445 | conf_thres=0.25, 446 | keras=False, 447 | prefix=colorstr('TensorFlow SavedModel:')): 448 | # YOLO TensorFlow SavedModel export 449 | try: 450 | import tensorflow as tf 451 | except Exception: 452 | check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}") 453 | import tensorflow as tf 454 | from 
tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 455 | 456 | from models.tf import TFModel 457 | 458 | LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') 459 | f = str(file).replace('.pt', '_saved_model') 460 | batch_size, ch, *imgsz = list(im.shape) # BCHW 461 | 462 | tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz) 463 | im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow 464 | _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) 465 | inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size) 466 | outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres) 467 | keras_model = tf.keras.Model(inputs=inputs, outputs=outputs) 468 | keras_model.trainable = False 469 | keras_model.summary() 470 | if keras: 471 | keras_model.save(f, save_format='tf') 472 | else: 473 | spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype) 474 | m = tf.function(lambda x: keras_model(x)) # full model 475 | m = m.get_concrete_function(spec) 476 | frozen_func = convert_variables_to_constants_v2(m) 477 | tfm = tf.Module() 478 | tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x), [spec]) 479 | tfm.__call__(im) 480 | tf.saved_model.save(tfm, 481 | f, 482 | options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if check_version( 483 | tf.__version__, '2.6') else tf.saved_model.SaveOptions()) 484 | return f, keras_model 485 | 486 | 487 | @try_export 488 | def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')): 489 | # YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow 490 | import tensorflow as tf 491 | from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 492 | 493 | LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') 494 | f = file.with_suffix('.pb') 495 | 496 | m = tf.function(lambda x: keras_model(x)) # full model 497 | m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)) 498 | frozen_func = convert_variables_to_constants_v2(m) 499 | frozen_func.graph.as_graph_def() 500 | tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False) 501 | return f, None 502 | 503 | 504 | @try_export 505 | def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')): 506 | # YOLOv5 TensorFlow Lite export 507 | import tensorflow as tf 508 | 509 | LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...') 510 | batch_size, ch, *imgsz = list(im.shape) # BCHW 511 | f = str(file).replace('.pt', '-fp16.tflite') 512 | 513 | converter = tf.lite.TFLiteConverter.from_keras_model(keras_model) 514 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS] 515 | converter.target_spec.supported_types = [tf.float16] 516 | converter.optimizations = [tf.lite.Optimize.DEFAULT] 517 | if int8: 518 | from models.tf import representative_dataset_gen 519 | dataset = LoadImages(check_dataset(check_yaml(data))['train'], img_size=imgsz, auto=False) 520 | converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100) 521 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] 522 | converter.target_spec.supported_types 
= [] 523 | converter.inference_input_type = tf.uint8 # or tf.int8 524 | converter.inference_output_type = tf.uint8 # or tf.int8 525 | converter.experimental_new_quantizer = True 526 | f = str(file).replace('.pt', '-int8.tflite') 527 | if nms or agnostic_nms: 528 | converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS) 529 | 530 | tflite_model = converter.convert() 531 | open(f, "wb").write(tflite_model) 532 | return f, None 533 | 534 | 535 | @try_export 536 | def export_edgetpu(file, prefix=colorstr('Edge TPU:')): 537 | # YOLO Edge TPU export https://coral.ai/docs/edgetpu/models-intro/ 538 | cmd = 'edgetpu_compiler --version' 539 | help_url = 'https://coral.ai/docs/edgetpu/compiler/' 540 | assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}' 541 | if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0: 542 | LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}') 543 | sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system 544 | for c in ( 545 | 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -', 546 | 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list', 547 | 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'): 548 | subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True) 549 | ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1] 550 | 551 | LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...') 552 | f = str(file).replace('.pt', '-int8_edgetpu.tflite') # Edge TPU model 553 | f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model 554 | 555 | cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}" 556 | subprocess.run(cmd.split(), check=True) 557 | return f, None 558 | 559 | 560 | @try_export 561 | def export_tfjs(file, prefix=colorstr('TensorFlow.js:')): 562 | # YOLO TensorFlow.js export 563 | check_requirements('tensorflowjs') 564 | import tensorflowjs as tfjs 565 | 566 | LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...') 567 | f = str(file).replace('.pt', '_web_model') # js dir 568 | f_pb = file.with_suffix('.pb') # *.pb path 569 | f_json = f'{f}/model.json' # *.json path 570 | 571 | cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \ 572 | f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}' 573 | subprocess.run(cmd.split()) 574 | 575 | json = Path(f_json).read_text() 576 | with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order 577 | subst = re.sub( 578 | r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, ' 579 | r'"Identity.?.?": {"name": "Identity.?.?"}, ' 580 | r'"Identity.?.?": {"name": "Identity.?.?"}, ' 581 | r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, ' 582 | r'"Identity_1": {"name": "Identity_1"}, ' 583 | r'"Identity_2": {"name": "Identity_2"}, ' 584 | r'"Identity_3": {"name": "Identity_3"}}}', json) 585 | j.write(subst) 586 | return f, None 587 | 588 | 589 | def add_tflite_metadata(file, metadata, num_outputs): 590 | # Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata 591 | with contextlib.suppress(ImportError): 592 | # check_requirements('tflite_support') 593 | from tflite_support import flatbuffers 594 | from 
tflite_support import metadata as _metadata 595 | from tflite_support import metadata_schema_py_generated as _metadata_fb 596 | 597 | tmp_file = Path('/tmp/meta.txt') 598 | with open(tmp_file, 'w') as meta_f: 599 | meta_f.write(str(metadata)) 600 | 601 | model_meta = _metadata_fb.ModelMetadataT() 602 | label_file = _metadata_fb.AssociatedFileT() 603 | label_file.name = tmp_file.name 604 | model_meta.associatedFiles = [label_file] 605 | 606 | subgraph = _metadata_fb.SubGraphMetadataT() 607 | subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()] 608 | subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs 609 | model_meta.subgraphMetadata = [subgraph] 610 | 611 | b = flatbuffers.Builder(0) 612 | b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) 613 | metadata_buf = b.Output() 614 | 615 | populator = _metadata.MetadataPopulator.with_model_file(file) 616 | populator.load_metadata_buffer(metadata_buf) 617 | populator.load_associated_files([str(tmp_file)]) 618 | populator.populate() 619 | tmp_file.unlink() 620 | 621 | 622 | @smart_inference_mode() 623 | def run( 624 | data=ROOT / 'data/coco.yaml', # 'dataset.yaml path' 625 | weights=ROOT / 'yolo.pt', # weights path 626 | imgsz=(640, 640), # image (height, width) 627 | batch_size=1, # batch size 628 | device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu 629 | include=('torchscript', 'onnx'), # include formats 630 | class_agnostic=False, 631 | half=False, # FP16 half-precision export 632 | inplace=False, # set YOLO Detect() inplace=True 633 | keras=False, # use Keras 634 | optimize=False, # TorchScript: optimize for mobile 635 | int8=False, # CoreML/TF INT8 quantization 636 | dynamic=False, # ONNX/TF/TensorRT: dynamic axes 637 | simplify=False, # ONNX: simplify model 638 | opset=12, # ONNX: opset version 639 | verbose=False, # TensorRT: verbose log 640 | workspace=4, # TensorRT: workspace size (GB) 641 | nms=False, # TF: add NMS to model 642 | agnostic_nms=False, # TF: add agnostic NMS to model 643 | topk_per_class=100, # TF.js NMS: topk per class to keep 644 | topk_all=100, # TF.js NMS: topk for all classes to keep 645 | iou_thres=0.45, # TF.js NMS: IoU threshold 646 | conf_thres=0.25, # TF.js NMS: confidence threshold 647 | mask_resolution=56, 648 | pooler_scale=0.25, 649 | sampling_ratio=0, 650 | ): 651 | t = time.time() 652 | include = [x.lower() for x in include] # to lowercase 653 | fmts = tuple(export_formats()['Argument'][1:]) # --include arguments 654 | flags = [x in include for x in fmts] 655 | assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}' 656 | jit, onnx, onnx_end2end, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags # export booleans 657 | file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights 658 | 659 | # Load PyTorch model 660 | device = select_device(device) 661 | if half: 662 | assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0' 663 | assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both' 664 | model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model 665 | 666 | # Checks 667 | imgsz *= 2 if len(imgsz) == 1 else 1 # expand 668 | if optimize: 669 | assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. 
use --device cpu' 670 | 671 | # Input 672 | gs = int(max(model.stride)) # grid size (max stride) 673 | imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples 674 | im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection 675 | 676 | # Update model 677 | model.eval() 678 | for k, m in model.named_modules(): 679 | if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect)): 680 | m.inplace = inplace 681 | m.dynamic = dynamic 682 | m.export = True 683 | 684 | for _ in range(2): 685 | y = model(im) # dry runs 686 | if half and not coreml: 687 | im, model = im.half(), model.half() # to FP16 688 | shape = tuple((y[0] if isinstance(y, (tuple, list)) else y).shape) # model output shape 689 | metadata = {'stride': int(max(model.stride)), 'names': model.names} # model metadata 690 | LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)") 691 | 692 | # Exports 693 | f = [''] * len(fmts) # exported filenames 694 | warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning 695 | if jit: # TorchScript 696 | f[0], _ = export_torchscript(model, im, file, optimize) 697 | if engine: # TensorRT required before ONNX 698 | f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose) 699 | if onnx or xml: # OpenVINO requires ONNX 700 | f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify) 701 | if onnx_end2end: 702 | labels = model.names 703 | f[2], _ = export_onnx_end2end(model, im, file, class_agnostic, simplify, topk_all, iou_thres, conf_thres, device, len(labels), mask_resolution, pooler_scale, sampling_ratio ) 704 | if xml: # OpenVINO 705 | f[3], _ = export_openvino(file, metadata, half) 706 | if coreml: # CoreML 707 | f[4], _ = export_coreml(model, im, file, int8, half) 708 | if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats 709 | assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.' 710 | assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.' 
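        # Note (added): the TensorFlow exports below build on one another --
        # export_saved_model() produces the Keras/SavedModel object that
        # export_pb(), export_tflite() and export_edgetpu() consume, and
        # export_tfjs() then converts the frozen *.pb graph. An illustrative
        # (not prescriptive) invocation that exercises this chain might be:
        #   python export.py --weights yolov9-c.pt --include saved_model tflite
        # where the weights filename is only an example; the real flags are
        # defined by parse_opt() further below.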
711 | f[5], s_model = export_saved_model(model.cpu(), 712 | im, 713 | file, 714 | dynamic, 715 | tf_nms=nms or agnostic_nms or tfjs, 716 | agnostic_nms=agnostic_nms or tfjs, 717 | topk_per_class=topk_per_class, 718 | topk_all=topk_all, 719 | iou_thres=iou_thres, 720 | conf_thres=conf_thres, 721 | keras=keras) 722 | if pb or tfjs: # pb prerequisite to tfjs 723 | f[6], _ = export_pb(s_model, file) 724 | if tflite or edgetpu: 725 | f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms) 726 | if edgetpu: 727 | f[8], _ = export_edgetpu(file) 728 | add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs)) 729 | if tfjs: 730 | f[9], _ = export_tfjs(file) 731 | if paddle: # PaddlePaddle 732 | f[10], _ = export_paddle(model, im, file, metadata) 733 | 734 | # Finish 735 | f = [str(x) for x in f if x] # filter out '' and None 736 | if any(f): 737 | cls, det, seg = (isinstance(model, x) for x in (ClassificationModel, DetectionModel, SegmentationModel)) # type 738 | dir = Path('segment' if seg else 'classify' if cls else '') 739 | h = '--half' if half else '' # --half FP16 inference arg 740 | s = "# WARNING ⚠️ ClassificationModel not yet supported for PyTorch Hub AutoShape inference" if cls else \ 741 | "# WARNING ⚠️ SegmentationModel not yet supported for PyTorch Hub AutoShape inference" if seg else '' 742 | if onnx_end2end: 743 | LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)' 744 | f"\nResults saved to {colorstr('bold', file.parent.resolve())}" 745 | f"\nVisualize: https://netron.app") 746 | else: 747 | LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)' 748 | f"\nResults saved to {colorstr('bold', file.parent.resolve())}" 749 | f"\nDetect: python {dir / ('detect.py' if det else 'predict.py')} --weights {f[-1]} {h}" 750 | f"\nValidate: python {dir / 'val.py'} --weights {f[-1]} {h}" 751 | f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}') {s}" 752 | f"\nVisualize: https://netron.app") 753 | return f # return list of exported files/dirs 754 | 755 | 756 | def parse_opt(): 757 | parser = argparse.ArgumentParser() 758 | parser.add_argument('--data', type=str, default=ROOT / 'data/coco.yaml', help='dataset.yaml path') 759 | parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo.pt', help='model.pt path(s)') 760 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)') 761 | parser.add_argument('--batch-size', type=int, default=1, help='batch size') 762 | parser.add_argument('--device', default='cpu', help='cuda device, i.e. 
0 or 0,1,2,3 or cpu') 763 | parser.add_argument('--half', action='store_true', help='FP16 half-precision export') 764 | parser.add_argument('--inplace', action='store_true', help='set YOLO Detect() inplace=True') 765 | parser.add_argument('--keras', action='store_true', help='TF: use Keras') 766 | parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile') 767 | parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization') 768 | parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes') 769 | parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model') 770 | parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version') 771 | parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log') 772 | parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)') 773 | parser.add_argument('--nms', action='store_true', help='TF: add NMS to model') 774 | parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model') 775 | parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep') 776 | parser.add_argument('--topk-all', type=int, default=100, help='ONNX END2END/TF.js NMS: topk for all classes to keep') 777 | parser.add_argument('--iou-thres', type=float, default=0.45, help='ONNX END2END/TF.js NMS: IoU threshold') 778 | parser.add_argument('--conf-thres', type=float, default=0.25, help='ONNX END2END/TF.js NMS: confidence threshold') 779 | parser.add_argument('--class-agnostic', action='store_true', help='Agnostic NMS (single class)') 780 | parser.add_argument('--mask-resolution', type=int, default=160, help='Mask pooled output.') 781 | parser.add_argument('--pooler-scale', type=float, default=0.25, help='Multiplicative factor used to translate the ROI coordinates. ') 782 | parser.add_argument('--sampling-ratio', type=int, default=0, help='Number of sampling points in the interpolation. 
Allowed values are non-negative integers.') 783 | parser.add_argument( 784 | '--include', 785 | nargs='+', 786 | default=['torchscript'], 787 | help='torchscript, onnx, onnx_end2end, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle') 788 | opt = parser.parse_args() 789 | 790 | if 'onnx_end2end' in opt.include: 791 | opt.simplify = True 792 | opt.dynamic = True 793 | opt.inplace = True 794 | opt.half = False 795 | 796 | print_args(vars(opt)) 797 | return opt 798 | 799 | 800 | def main(opt): 801 | for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]): 802 | run(**vars(opt)) 803 | 804 | 805 | if __name__ == "__main__": 806 | opt = parse_opt() 807 | main(opt) 808 | -------------------------------------------------------------------------------- /install_dependencies.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Function to display usage message 4 | usage() { 5 | echo "Usage: $0 [--no-trex]" 1>&2 6 | exit 1 7 | } 8 | 9 | # Set default flags 10 | install_trex=true # TREx installation enabled by default 11 | install_base=true # Base dependencies always installed by default 12 | 13 | # Shared paths 14 | TENSORRT_REPO_PATH="/opt/nvidia/TensorRT" 15 | DOWNLOADS_PATH="/yolov9-qat/downloads" 16 | 17 | # Check command line options 18 | while [[ $# -gt 0 ]]; do 19 | case "$1" in 20 | --no-trex) 21 | install_trex=false 22 | ;; 23 | *) 24 | usage 25 | ;; 26 | esac 27 | shift 28 | done 29 | 30 | # Function to install system dependencies 31 | install_system_dependencies() { 32 | echo "Installing system dependencies..." 33 | apt-get update || return 1 34 | apt-get install -y zip htop screen libgl1-mesa-glx libfreetype6-dev || return 1 35 | return 0 36 | } 37 | 38 | # Function to upgrade TensorRT 39 | upgrade_tensorrt() { 40 | echo "Upgrading TensorRT..." 41 | local os="ubuntu2204" 42 | local trt_version="10.9.0" 43 | local cuda="cuda-12.8" 44 | local tensorrt_package="nv-tensorrt-local-repo-${os}-${trt_version}-${cuda}_1.0-1_amd64.deb" 45 | local download_path="${DOWNLOADS_PATH}/${tensorrt_package}" 46 | 47 | # Create downloads directory if it doesn't exist 48 | mkdir -p "$DOWNLOADS_PATH" || return 1 49 | 50 | # Check if the package already exists 51 | if [ ! -f "$download_path" ]; then 52 | echo "Downloading TensorRT package..." 53 | wget "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${trt_version}/local_repo/${tensorrt_package}" -O "$download_path" || return 1 54 | else 55 | echo "TensorRT package already exists at $download_path. Reusing existing file." 56 | fi 57 | 58 | # Install the package 59 | dpkg -i "$download_path" || return 1 60 | cp /var/nv-tensorrt-local-repo-${os}-${trt_version}-${cuda}/*keyring.gpg /usr/share/keyrings/ || return 1 61 | apt-get update || return 1 62 | apt-get install -y tensorrt || return 1 63 | apt-get purge "nv-tensorrt-local-repo*" -y || return 1 64 | 65 | # Keep the downloaded file for potential reuse 66 | echo "TensorRT package kept at $download_path for future use" 67 | return 0 68 | } 69 | 70 | # Function to install Python packages 71 | install_python_packages() { 72 | echo "Installing Python packages..." 
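    # Note (added): the pinned tensorrt==10.9.0.34 wheel below is assumed to
    # track the system TensorRT 10.9 package installed by upgrade_tensorrt();
    # if you bump one version, bump the other so the Python bindings and
    # libnvinfer stay in sync.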
73 | pip install --upgrade pip || return 1 74 | pip install --upgrade tensorrt==10.9.0.34 || return 1 75 | 76 | pip install seaborn \ 77 | thop \ 78 | "markdown-it-py>=2.2.0" \ 79 | "onnx-simplifier>=0.4.35" \ 80 | "onnxsim>=0.4.35" \ 81 | "onnxruntime>=1.16.3" \ 82 | "ujson>=5.9.0" \ 83 | "pycocotools>=2.0.7" \ 84 | "pycuda>=2025.1" || return 1 85 | 86 | pip install --upgrade onnx_graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com || return 1 87 | pip install pillow==9.5.0 --no-cache-dir --force-reinstall || return 1 88 | return 0 89 | } 90 | 91 | # Function to clone TensorRT repository once 92 | clone_tensorrt_repo() { 93 | echo "Cloning NVIDIA TensorRT repository..." 94 | 95 | if [ ! -d "$TENSORRT_REPO_PATH" ]; then 96 | # Create directory and clone repository 97 | mkdir -p "$(dirname "$TENSORRT_REPO_PATH")" || return 1 98 | git clone https://github.com/NVIDIA/TensorRT.git "$TENSORRT_REPO_PATH" || return 1 99 | cd "$TENSORRT_REPO_PATH" || return 1 100 | git checkout release/10.9 || return 1 101 | echo "TensorRT repository cloned successfully to $TENSORRT_REPO_PATH" 102 | else 103 | echo "TensorRT repository already exists at $TENSORRT_REPO_PATH" 104 | fi 105 | 106 | return 0 107 | } 108 | 109 | # Function to install PyTorch Quantization 110 | install_pytorch_quantization() { 111 | echo "Installing PyTorch Quantization..." 112 | 113 | # Navigate to PyTorch Quantization directory in TensorRT repo 114 | cd "$TENSORRT_REPO_PATH/tools/pytorch-quantization" || return 1 115 | 116 | # Install requirements and setup 117 | pip install -r requirements.txt || return 1 118 | python setup.py install || return 1 119 | 120 | echo "PyTorch Quantization installed successfully" 121 | return 0 122 | } 123 | 124 | # Function to install TREx 125 | install_trex_environment() { 126 | echo "Installing NVIDIA TREx environment..." 127 | # Check if TREx is not already installed 128 | if [ ! -d "/opt/nvidia_trex/env_trex" ]; then 129 | apt-get install -y graphviz || return 1 130 | pip install virtualenv "widgetsnbextension>=4.0.9" || return 1 131 | 132 | mkdir -p /opt/nvidia_trex || return 1 133 | cd /opt/nvidia_trex/ || return 1 134 | python3 -m virtualenv env_trex || return 1 135 | source env_trex/bin/activate || return 1 136 | pip install "Werkzeug>=2.2.2" "graphviz>=0.20.1" || return 1 137 | 138 | # Navigate to TREx directory in TensorRT repo 139 | cd "$TENSORRT_REPO_PATH/tools/experimental/trt-engine-explorer" || return 1 140 | 141 | source /opt/nvidia_trex/env_trex/bin/activate || return 1 142 | pip install -e . || return 1 143 | pip install jupyter_nbextensions_configurator notebook==6.4.12 ipywidgets || return 1 144 | jupyter nbextension enable widgetsnbextension --user --py || return 1 145 | deactivate || return 1 146 | else 147 | echo "NVIDIA TREx virtual environment already exists. Skipping installation." 148 | fi 149 | return 0 150 | } 151 | 152 | # Function to cleanup 153 | cleanup() { 154 | echo "Cleaning up..." 
155 | apt-get clean 156 | rm -rf /var/lib/apt/lists/* 157 | } 158 | 159 | # Main installation process 160 | main() { 161 | # Install base dependencies (always) 162 | if $install_base; then 163 | install_system_dependencies || { echo "Failed to install system dependencies"; exit 1; } 164 | upgrade_tensorrt || { echo "Failed to upgrade TensorRT"; exit 1; } 165 | install_python_packages || { echo "Failed to install Python packages"; exit 1; } 166 | fi 167 | 168 | # Clone TensorRT repository once 169 | clone_tensorrt_repo || { echo "Failed to clone TensorRT repository"; exit 1; } 170 | 171 | # Install TREx by default unless --no-trex flag is provided 172 | if $install_trex; then 173 | install_trex_environment || { echo "Failed to install TREx environment"; exit 1; } 174 | fi 175 | 176 | # Always install PyTorch Quantization 177 | install_pytorch_quantization || { echo "Failed to install PyTorch Quantization"; exit 1; } 178 | 179 | # Final cleanup 180 | cleanup 181 | 182 | echo "Installation completed successfully." 183 | return 0 184 | } 185 | 186 | # Execute main function 187 | main 188 | 189 | 190 | -------------------------------------------------------------------------------- /models/experimental_trt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class TRT_EfficientNMS_85(torch.autograd.Function): 5 | '''TensorRT NMS operation''' 6 | @staticmethod 7 | def forward( 8 | ctx, 9 | boxes, 10 | scores, 11 | background_class=-1, 12 | box_coding=1, 13 | iou_threshold=0.45, 14 | max_output_boxes=100, 15 | plugin_version="1", 16 | score_activation=0, 17 | score_threshold=0.25, 18 | ): 19 | 20 | batch_size, num_boxes, num_classes = scores.shape 21 | num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32) 22 | det_boxes = torch.randn(batch_size, max_output_boxes, 4) 23 | det_scores = torch.randn(batch_size, max_output_boxes) 24 | det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32) 25 | return num_det, det_boxes, det_scores, det_classes 26 | 27 | @staticmethod 28 | def symbolic(g, 29 | boxes, 30 | scores, 31 | background_class=-1, 32 | box_coding=1, 33 | iou_threshold=0.45, 34 | max_output_boxes=100, 35 | plugin_version="1", 36 | score_activation=0, 37 | score_threshold=0.25): 38 | out = g.op("TRT::EfficientNMS_TRT", 39 | boxes, 40 | scores, 41 | background_class_i=background_class, 42 | box_coding_i=box_coding, 43 | iou_threshold_f=iou_threshold, 44 | max_output_boxes_i=max_output_boxes, 45 | plugin_version_s=plugin_version, 46 | score_activation_i=score_activation, 47 | score_threshold_f=score_threshold, 48 | outputs=4) 49 | nums, boxes, scores, classes = out 50 | return nums, boxes, scores, classes 51 | 52 | class TRT_EfficientNMS(torch.autograd.Function): 53 | '''TensorRT NMS operation''' 54 | @staticmethod 55 | def forward( 56 | ctx, 57 | boxes, 58 | scores, 59 | background_class=-1, 60 | box_coding=1, 61 | iou_threshold=0.45, 62 | max_output_boxes=100, 63 | plugin_version="1", 64 | score_activation=0, 65 | score_threshold=0.25, 66 | class_agnostic=0, 67 | ): 68 | 69 | batch_size, num_boxes, num_classes = scores.shape 70 | num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32) 71 | det_boxes = torch.randn(batch_size, max_output_boxes, 4) 72 | det_scores = torch.randn(batch_size, max_output_boxes) 73 | det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32) 74 | return num_det, 
det_boxes, det_scores, det_classes 75 | 76 | @staticmethod 77 | def symbolic(g, 78 | boxes, 79 | scores, 80 | background_class=-1, 81 | box_coding=1, 82 | iou_threshold=0.45, 83 | max_output_boxes=100, 84 | plugin_version="1", 85 | score_activation=0, 86 | score_threshold=0.25, 87 | class_agnostic=0): 88 | out = g.op("TRT::EfficientNMS_TRT", 89 | boxes, 90 | scores, 91 | background_class_i=background_class, 92 | box_coding_i=box_coding, 93 | iou_threshold_f=iou_threshold, 94 | max_output_boxes_i=max_output_boxes, 95 | plugin_version_s=plugin_version, 96 | score_activation_i=score_activation, 97 | class_agnostic_i=class_agnostic, 98 | score_threshold_f=score_threshold, 99 | outputs=4) 100 | nums, boxes, scores, classes = out 101 | return nums, boxes, scores, classes 102 | 103 | class TRT_EfficientNMSX_85(torch.autograd.Function): 104 | '''TensorRT NMS operation''' 105 | @staticmethod 106 | def forward( 107 | ctx, 108 | boxes, 109 | scores, 110 | background_class=-1, 111 | box_coding=1, 112 | iou_threshold=0.45, 113 | max_output_boxes=100, 114 | plugin_version="1", 115 | score_activation=0, 116 | score_threshold=0.25 117 | ): 118 | 119 | batch_size, num_boxes, num_classes = scores.shape 120 | num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32) 121 | det_boxes = torch.randn(batch_size, max_output_boxes, 4) 122 | det_scores = torch.randn(batch_size, max_output_boxes) 123 | det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32) 124 | det_indices = torch.randint(0,num_boxes,(batch_size, max_output_boxes), dtype=torch.int32) 125 | return num_det, det_boxes, det_scores, det_classes, det_indices 126 | 127 | @staticmethod 128 | def symbolic(g, 129 | boxes, 130 | scores, 131 | background_class=-1, 132 | box_coding=1, 133 | iou_threshold=0.45, 134 | max_output_boxes=100, 135 | plugin_version="1", 136 | score_activation=0, 137 | score_threshold=0.25): 138 | out = g.op("TRT::EfficientNMSX_TRT", 139 | boxes, 140 | scores, 141 | background_class_i=background_class, 142 | box_coding_i=box_coding, 143 | iou_threshold_f=iou_threshold, 144 | max_output_boxes_i=max_output_boxes, 145 | plugin_version_s=plugin_version, 146 | score_activation_i=score_activation, 147 | score_threshold_f=score_threshold, 148 | outputs=5) 149 | nums, boxes, scores, classes, det_indices = out 150 | return nums, boxes, scores, classes, det_indices 151 | 152 | class TRT_EfficientNMSX(torch.autograd.Function): 153 | '''TensorRT NMS operation''' 154 | @staticmethod 155 | def forward( 156 | ctx, 157 | boxes, 158 | scores, 159 | background_class=-1, 160 | box_coding=1, 161 | iou_threshold=0.45, 162 | max_output_boxes=100, 163 | plugin_version="1", 164 | score_activation=0, 165 | score_threshold=0.25, 166 | class_agnostic=0, 167 | ): 168 | 169 | batch_size, num_boxes, num_classes = scores.shape 170 | num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32) 171 | det_boxes = torch.randn(batch_size, max_output_boxes, 4) 172 | det_scores = torch.randn(batch_size, max_output_boxes) 173 | det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32) 174 | det_indices = torch.randint(0,num_boxes,(batch_size, max_output_boxes), dtype=torch.int32) 175 | return num_det, det_boxes, det_scores, det_classes, det_indices 176 | 177 | @staticmethod 178 | def symbolic(g, 179 | boxes, 180 | scores, 181 | background_class=-1, 182 | box_coding=1, 183 | iou_threshold=0.45, 184 | max_output_boxes=100, 185 | plugin_version="1", 186 | 
score_activation=0, 187 | score_threshold=0.25, 188 | class_agnostic=0): 189 | out = g.op("TRT::EfficientNMSX_TRT", 190 | boxes, 191 | scores, 192 | background_class_i=background_class, 193 | box_coding_i=box_coding, 194 | iou_threshold_f=iou_threshold, 195 | max_output_boxes_i=max_output_boxes, 196 | plugin_version_s=plugin_version, 197 | score_activation_i=score_activation, 198 | class_agnostic_i=class_agnostic, 199 | score_threshold_f=score_threshold, 200 | outputs=5) 201 | nums, boxes, scores, classes, det_indices = out 202 | return nums, boxes, scores, classes, det_indices 203 | 204 | class TRT_ROIAlign(torch.autograd.Function): 205 | @staticmethod 206 | def forward( 207 | ctx, 208 | X, 209 | rois, 210 | batch_indices, 211 | coordinate_transformation_mode= 1, 212 | mode=1, # 1- avg pooling / 0 - max pooling 213 | output_height=160, 214 | output_width=160, 215 | sampling_ratio=0, 216 | spatial_scale=0.25, 217 | ): 218 | device = rois.device 219 | dtype = rois.dtype 220 | N, C, H, W = X.shape 221 | num_rois = rois.shape[0] 222 | return torch.randn((num_rois, C, output_height, output_width), device=device, dtype=dtype) 223 | 224 | @staticmethod 225 | def symbolic( 226 | g, 227 | X, 228 | rois, 229 | batch_indices, 230 | coordinate_transformation_mode=1, 231 | mode=1, 232 | output_height=160, 233 | output_width=160, 234 | sampling_ratio=0, 235 | spatial_scale=0.25, 236 | ): 237 | return g.op( 238 | "TRT::ROIAlign_TRT", 239 | X, 240 | rois, 241 | batch_indices, 242 | coordinate_transformation_mode_i=coordinate_transformation_mode, 243 | mode_i=mode, 244 | output_height_i=output_height, 245 | output_width_i=output_width, 246 | sampling_ratio_i=sampling_ratio, 247 | spatial_scale_f=spatial_scale, 248 | ) 249 | 250 | class ONNX_EfficientNMS_TRT(nn.Module): 251 | '''onnx module with TensorRT NMS operation.''' 252 | def __init__(self, class_agnostic=False, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80): 253 | super().__init__() 254 | assert max_wh is None 255 | self.device = device if device else torch.device('cpu') 256 | self.class_agnostic = 1 if class_agnostic else 0 257 | self.background_class = -1, 258 | self.box_coding = 1, 259 | self.iou_threshold = iou_thres 260 | self.max_obj = max_obj 261 | self.plugin_version = '1' 262 | self.score_activation = 0 263 | self.score_threshold = score_thres 264 | self.n_classes=n_classes 265 | 266 | 267 | def forward(self, x): 268 | if isinstance(x, list): 269 | x = x[1] 270 | x = x.permute(0, 2, 1) 271 | bboxes_x = x[..., 0:1] 272 | bboxes_y = x[..., 1:2] 273 | bboxes_w = x[..., 2:3] 274 | bboxes_h = x[..., 3:4] 275 | bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1) 276 | bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4] 277 | obj_conf = x[..., 4:] 278 | scores = obj_conf 279 | if self.class_agnostic == 1: 280 | num_det, det_boxes, det_scores, det_classes = TRT_EfficientNMS.apply(bboxes, scores, self.background_class, self.box_coding, 281 | self.iou_threshold, self.max_obj, 282 | self.plugin_version, self.score_activation, 283 | self.score_threshold, self.class_agnostic) 284 | else: 285 | num_det, det_boxes, det_scores, det_classes = TRT_EfficientNMS_85.apply(bboxes, scores, self.background_class, self.box_coding, 286 | self.iou_threshold, self.max_obj, 287 | self.plugin_version, self.score_activation, 288 | self.score_threshold) 289 | return num_det, det_boxes, det_scores, det_classes 290 | 291 | class ONNX_EfficientNMSX_TRT(nn.Module): 292 | '''onnx module 
with TensorRT NMS operation.''' 293 | def __init__(self, class_agnostic=False, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80): 294 | super().__init__() 295 | assert max_wh is None 296 | self.device = device if device else torch.device('cpu') 297 | self.class_agnostic = 1 if class_agnostic else 0 298 | self.background_class = -1, 299 | self.box_coding = 1, 300 | self.iou_threshold = iou_thres 301 | self.max_obj = max_obj 302 | self.plugin_version = '1' 303 | self.score_activation = 0 304 | self.score_threshold = score_thres 305 | self.n_classes=n_classes 306 | 307 | 308 | def forward(self, x): 309 | if isinstance(x, list): 310 | x = x[1] 311 | x = x.permute(0, 2, 1) 312 | bboxes_x = x[..., 0:1] 313 | bboxes_y = x[..., 1:2] 314 | bboxes_w = x[..., 2:3] 315 | bboxes_h = x[..., 3:4] 316 | bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1) 317 | bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4] 318 | obj_conf = x[..., 4:] 319 | scores = obj_conf 320 | if self.class_agnostic == 1: 321 | num_det, det_boxes, det_scores, det_classes, det_indices = TRT_EfficientNMSX.apply(bboxes, scores, self.background_class, self.box_coding, 322 | self.iou_threshold, self.max_obj, 323 | self.plugin_version, self.score_activation, 324 | self.score_threshold, self.class_agnostic) 325 | else: 326 | num_det, det_boxes, det_scores, det_classes, det_indices = TRT_EfficientNMSX_85.apply(bboxes, scores, self.background_class, self.box_coding, 327 | self.iou_threshold, self.max_obj, 328 | self.plugin_version, self.score_activation, 329 | self.score_threshold) 330 | return num_det, det_boxes, det_scores, det_classes, det_indices 331 | 332 | 333 | 334 | class End2End_TRT(nn.Module): 335 | '''export onnx or tensorrt model with NMS operation.''' 336 | def __init__(self, model, class_agnostic=False, max_obj=100, iou_thres=0.45, score_thres=0.25, mask_resolution=56, pooler_scale=0.25, sampling_ratio=0, max_wh=None, device=None, n_classes=80, is_det_model=True): 337 | super().__init__() 338 | device = device if device else torch.device('cpu') 339 | assert isinstance(max_wh,(int)) or max_wh is None 340 | self.model = model.to(device) 341 | self.model.model[-1].end2end = True 342 | if is_det_model: 343 | self.patch_model = ONNX_EfficientNMS_TRT 344 | self.end2end = self.patch_model(class_agnostic, max_obj, iou_thres, score_thres, max_wh, device, n_classes) 345 | else: 346 | self.patch_model = ONNX_End2End_MASK_TRT 347 | self.end2end = self.patch_model(class_agnostic, max_obj, iou_thres, score_thres, mask_resolution, pooler_scale, sampling_ratio, max_wh, device, n_classes) 348 | self.end2end.eval() 349 | 350 | def forward(self, x): 351 | x = self.model(x) 352 | x = self.end2end(x) 353 | return x 354 | 355 | 356 | class ONNX_End2End_MASK_TRT(nn.Module): 357 | """onnx module with ONNX-TensorRT NMS/ROIAlign operation.""" 358 | def __init__( 359 | self, 360 | class_agnostic=False, 361 | max_obj=100, 362 | iou_thres=0.45, 363 | score_thres=0.25, 364 | mask_resolution=160, 365 | pooler_scale=0.25, 366 | sampling_ratio=0, 367 | max_wh=None, 368 | device=None, 369 | n_classes=80 370 | ): 371 | super().__init__() 372 | assert isinstance(max_wh,(int)) or max_wh is None 373 | self.device = device if device else torch.device('cpu') 374 | self.class_agnostic = 1 if class_agnostic else 0 375 | self.max_obj = max_obj 376 | self.background_class = -1, 377 | self.box_coding = 1, 378 | self.iou_threshold = iou_thres 379 | self.max_obj = max_obj 380 | 
self.plugin_version = '1' 381 | self.score_activation = 0 382 | self.score_threshold = score_thres 383 | self.n_classes=n_classes 384 | self.mask_resolution = mask_resolution 385 | self.pooler_scale = pooler_scale 386 | self.sampling_ratio = sampling_ratio 387 | 388 | def forward(self, x): 389 | if isinstance(x, list): ## remove auxiliary branch 390 | x = x[1] 391 | det=x[0] 392 | proto=x[1] 393 | det = det.permute(0, 2, 1) 394 | 395 | bboxes_x = det[..., 0:1] 396 | bboxes_y = det[..., 1:2] 397 | bboxes_w = det[..., 2:3] 398 | bboxes_h = det[..., 3:4] 399 | bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1) 400 | bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4] 401 | scores = det[..., 4: 4 + self.n_classes] 402 | 403 | batch_size, nm, proto_h, proto_w = proto.shape 404 | total_object = batch_size * self.max_obj 405 | masks = det[..., 4 + self.n_classes : 4 + self.n_classes + nm] 406 | if self.class_agnostic == 1: 407 | num_det, det_boxes, det_scores, det_classes, det_indices = TRT_EfficientNMSX.apply(bboxes, scores, self.background_class, self.box_coding, 408 | self.iou_threshold, self.max_obj, 409 | self.plugin_version, self.score_activation, 410 | self.score_threshold,self.class_agnostic) 411 | else: 412 | num_det, det_boxes, det_scores, det_classes, det_indices = TRT_EfficientNMSX_85.apply(bboxes, scores, self.background_class, self.box_coding, 413 | self.iou_threshold, self.max_obj, 414 | self.plugin_version, self.score_activation, 415 | self.score_threshold) 416 | 417 | batch_indices = torch.ones_like(det_indices) * torch.arange(batch_size, device=self.device, dtype=torch.int32).unsqueeze(1) 418 | batch_indices = batch_indices.view(total_object).to(torch.long) 419 | det_indices = det_indices.view(total_object).to(torch.long) 420 | det_masks = masks[batch_indices, det_indices] 421 | 422 | 423 | pooled_proto = TRT_ROIAlign.apply( proto, 424 | det_boxes.view(total_object, 4), 425 | batch_indices, 426 | 1, 427 | 1, 428 | self.mask_resolution, 429 | self.mask_resolution, 430 | self.sampling_ratio, 431 | self.pooler_scale 432 | ) 433 | pooled_proto = pooled_proto.view( 434 | total_object, nm, self.mask_resolution * self.mask_resolution, 435 | ) 436 | 437 | det_masks = ( 438 | torch.matmul(det_masks.unsqueeze(dim=1), pooled_proto) 439 | .sigmoid() 440 | .view(batch_size, self.max_obj, self.mask_resolution * self.mask_resolution) 441 | ) 442 | 443 | return num_det, det_boxes, det_scores, det_classes, det_masks -------------------------------------------------------------------------------- /models/quantize.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: MIT 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a 6 | # copy of this software and associated documentation files (the "Software"), 7 | # to deal in the Software without restriction, including without limitation 8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 9 | # and/or sell copies of the Software, and to permit persons to whom the 10 | # Software is furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in 13 | # all copies or substantial portions of the Software. 
14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 21 | # DEALINGS IN THE SOFTWARE. 22 | ################################################################################ 23 | 24 | 25 | import os 26 | import re 27 | from typing import List, Callable, Union, Dict 28 | from tqdm import tqdm 29 | from copy import deepcopy 30 | 31 | # PyTorch 32 | import torch 33 | import torch.optim as optim 34 | from torch.cuda import amp 35 | import torch.nn.functional as F 36 | 37 | # Pytorch Quantization 38 | from pytorch_quantization import nn as quant_nn 39 | from pytorch_quantization.nn.modules import _utils as quant_nn_utils 40 | from pytorch_quantization import calib 41 | from pytorch_quantization.tensor_quant import QuantDescriptor 42 | from pytorch_quantization import quant_modules 43 | from absl import logging as quant_logging 44 | 45 | import onnx_graphsurgeon as gs 46 | from utils.general import (check_requirements, LOGGER,colorstr) 47 | from models.quantize_rules import find_quantizer_pairs 48 | 49 | 50 | class QuantAdd(torch.nn.Module, quant_nn_utils.QuantMixin): 51 | def __init__(self, quantization): 52 | super().__init__() 53 | 54 | if quantization: 55 | self._input0_quantizer = quant_nn.TensorQuantizer(QuantDescriptor()) 56 | self._input1_quantizer = quant_nn.TensorQuantizer(QuantDescriptor()) 57 | self.quantization = quantization 58 | 59 | def forward(self, x, y): 60 | if self.quantization: 61 | return self._input0_quantizer(x) + self._input1_quantizer(y) 62 | return x + y 63 | 64 | 65 | class QuantADownAvgChunk(torch.nn.Module): 66 | def __init__(self): 67 | super().__init__() 68 | self._chunk_quantizer = quant_nn.TensorQuantizer(QuantDescriptor()) 69 | self._chunk_quantizer._calibrator._torch_hist = True 70 | self.avg_pool2d = torch.nn.AvgPool2d(2, 1, 0, False, True) 71 | 72 | def forward(self, x): 73 | x = self.avg_pool2d(x) 74 | x = self._chunk_quantizer(x) 75 | return x.chunk(2, 1) 76 | 77 | class QuantAConvAvgChunk(torch.nn.Module): 78 | def __init__(self): 79 | super().__init__() 80 | self._chunk_quantizer = quant_nn.TensorQuantizer(QuantDescriptor(num_bits=8, calib_method="histogram")) 81 | self._chunk_quantizer._calibrator._torch_hist = True 82 | self.avg_pool2d = torch.nn.AvgPool2d(2, 1, 0, False, True) 83 | 84 | def forward(self, x): 85 | x = self.avg_pool2d(x) 86 | x = self._chunk_quantizer(x) 87 | return x 88 | 89 | class QuantRepNCSPELAN4Chunk(torch.nn.Module): 90 | def __init__(self, c): 91 | super().__init__() 92 | self._input0_quantizer = quant_nn.TensorQuantizer(QuantDescriptor()) 93 | self.c = c 94 | def forward(self, x, chunks, dims): 95 | return torch.split(self._input0_quantizer(x), (self.c, self.c), dims) 96 | 97 | class QuantUpsample(torch.nn.Module): 98 | def __init__(self, size, scale_factor, mode): 99 | super().__init__() 100 | self.size = size 101 | self.scale_factor = scale_factor 102 | self.mode = mode 103 | self._input_quantizer = quant_nn.TensorQuantizer(QuantDescriptor()) 104 | 105 | def forward(self, x): 106 | return F.interpolate(self._input_quantizer(x), self.size, self.scale_factor, self.mode) 107 | 108 | 109 | class 
QuantConcat(torch.nn.Module): 110 | def __init__(self, dim): 111 | super().__init__() 112 | self._input0_quantizer = quant_nn.TensorQuantizer(QuantDescriptor()) 113 | self._input1_quantizer = quant_nn.TensorQuantizer(QuantDescriptor()) 114 | self.dim = dim 115 | 116 | def forward(self, x, dim): 117 | x_0 = self._input0_quantizer(x[0]) 118 | x_1 = self._input1_quantizer(x[1]) 119 | return torch.cat((x_0, x_1), self.dim) 120 | 121 | 122 | class disable_quantization: 123 | def __init__(self, model): 124 | self.model = model 125 | 126 | def apply(self, disabled=True): 127 | for name, module in self.model.named_modules(): 128 | if isinstance(module, quant_nn.TensorQuantizer): 129 | module._disabled = disabled 130 | 131 | def __enter__(self): 132 | self.apply(True) 133 | 134 | def __exit__(self, *args, **kwargs): 135 | self.apply(False) 136 | 137 | 138 | class enable_quantization: 139 | def __init__(self, model): 140 | self.model = model 141 | 142 | def apply(self, enabled=True): 143 | for name, module in self.model.named_modules(): 144 | if isinstance(module, quant_nn.TensorQuantizer): 145 | module._disabled = not enabled 146 | 147 | def __enter__(self): 148 | self.apply(True) 149 | return self 150 | 151 | def __exit__(self, *args, **kwargs): 152 | self.apply(False) 153 | 154 | 155 | def have_quantizer(module): 156 | for name, module in module.named_modules(): 157 | if isinstance(module, quant_nn.TensorQuantizer): 158 | return True 159 | 160 | 161 | # Initialize PyTorch Quantization 162 | def initialize(): 163 | quant_modules.initialize( ) 164 | quant_desc_input = QuantDescriptor(calib_method="histogram") 165 | quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input) 166 | quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input) 167 | quant_nn.QuantAvgPool2d.set_default_quant_desc_input(quant_desc_input) 168 | quant_nn.QuantMaxPool2d.set_default_quant_desc_input(quant_desc_input) 169 | 170 | quant_logging.set_verbosity(quant_logging.ERROR) 171 | 172 | 173 | def remove_redundant_qdq_model(onnx_model, f): 174 | check_requirements('onnx') 175 | import onnx 176 | 177 | graph = gs.import_onnx(onnx_model) 178 | nodes = graph.nodes 179 | 180 | mul_nodes = [node for node in nodes if node.op == "Mul" and node.i(0).op == "Conv" and node.i(1).op == "Sigmoid"] 181 | many_outputs_mul_nodes = [] 182 | 183 | for node in mul_nodes: 184 | try: 185 | for i in range(99): 186 | node.o(i) 187 | except: 188 | if i > 1: 189 | mul_nodename_outnum = {"node": node, "out_num": i} 190 | many_outputs_mul_nodes.append(mul_nodename_outnum) 191 | 192 | for node_dict in many_outputs_mul_nodes: 193 | if node_dict["out_num"] == 2: 194 | if node_dict["node"].o(0).op == "QuantizeLinear" and node_dict["node"].o(1).op == "QuantizeLinear": 195 | if node_dict["node"].o(1).o(0).o(0).op == "Concat": 196 | concat_dq_out_name = node_dict["node"].o(1).o(0).outputs[0].name 197 | for i, concat_input in enumerate(node_dict["node"].o(1).o(0).o(0).inputs): 198 | if concat_input.name == concat_dq_out_name: 199 | node_dict["node"].o(1).o(0).o(0).inputs[i] = node_dict["node"].o(0).o(0).outputs[0] 200 | else: 201 | node_dict["node"].o(1).o(0).o(0).inputs[0] = node_dict["node"].o(0).o(0).outputs[0] 202 | 203 | 204 | # elif node_dict["node"].o(0).op == "QuantizeLinear" and node_dict["node"].o(1).op == "Concat": 205 | # concat_dq_out_name = node_dict["node"].outputs[0].outputs[0].inputs[0].name 206 | # for i, concat_input in enumerate(node_dict["node"].outputs[0].outputs[1].inputs): 207 | # if concat_input.name == 
concat_dq_out_name: 208 | # #print("elif", concat_input.name, concat_dq_out_name ) 209 | # #print("will-be", node_dict["node"].outputs[0].outputs[1].inputs[i], node_dict["node"].outputs[0].outputs[0].o().outputs[0] ) 210 | # node_dict["node"].outputs[0].outputs[1].inputs[i] = node_dict["node"].outputs[0].outputs[0].o().outputs[0] 211 | 212 | 213 | # add_nodes = [node for node in nodes if node.op == "Add"] 214 | # many_outputs_add_nodes = [] 215 | # for node in add_nodes: 216 | # try: 217 | # for i in range(99): 218 | # node.o(i) 219 | # except: 220 | # if i > 1 and node.o().op == "QuantizeLinear": 221 | # add_nodename_outnum = {"node": node, "out_num": i} 222 | # many_outputs_add_nodes.append(add_nodename_outnum) 223 | 224 | 225 | # for node_dict in many_outputs_add_nodes: 226 | # if node_dict["node"].outputs[0].outputs[0].op == "QuantizeLinear" and node_dict["node"].outputs[0].outputs[1].op == "Concat": 227 | # concat_dq_out_name = node_dict["node"].outputs[0].outputs[0].inputs[0].name 228 | # for i, concat_input in enumerate(node_dict["node"].outputs[0].outputs[1].inputs): 229 | # if concat_input.name == concat_dq_out_name: 230 | # node_dict["node"].outputs[0].outputs[1].inputs[i] = node_dict["node"].outputs[0].outputs[0].o().outputs[0] 231 | 232 | onnx.save(gs.export_onnx(graph), f) 233 | 234 | 235 | def transfer_torch_to_quantization(nninstance : torch.nn.Module, quantmodule): 236 | quant_instance = quantmodule.__new__(quantmodule) 237 | for k, val in vars(nninstance).items(): 238 | setattr(quant_instance, k, val) 239 | 240 | def __init__(self): 241 | if self.__class__.__name__ == 'QuantAvgPool2d': 242 | self.__init__(nninstance.kernel_size, nninstance.stride, nninstance.padding, nninstance.ceil_mode, nninstance.count_include_pad) 243 | elif isinstance(self, quant_nn_utils.QuantInputMixin): 244 | quant_desc_input, quant_desc_weight = quant_nn_utils.pop_quant_desc_in_kwargs(self.__class__) 245 | self.init_quantizer(quant_desc_input) 246 | 247 | # Turn on torch_hist to enable higher calibration speeds 248 | if isinstance(self._input_quantizer._calibrator, calib.HistogramCalibrator): 249 | self._input_quantizer._calibrator._torch_hist = True 250 | else: 251 | quant_desc_input, quant_desc_weight = quant_nn_utils.pop_quant_desc_in_kwargs(self.__class__) 252 | self.init_quantizer(quant_desc_input, quant_desc_weight) 253 | # Turn on torch_hist to enable higher calibration speeds 254 | if isinstance(self._input_quantizer._calibrator, calib.HistogramCalibrator): 255 | self._input_quantizer._calibrator._torch_hist = True 256 | self._weight_quantizer._calibrator._torch_hist = True 257 | 258 | __init__(quant_instance) 259 | return quant_instance 260 | 261 | 262 | def quantization_ignore_match(ignore_policy : Union[str, List[str], Callable], path : str) -> bool: 263 | 264 | if ignore_policy is None: return False 265 | if isinstance(ignore_policy, Callable): 266 | return ignore_policy(path) 267 | 268 | if isinstance(ignore_policy, str) or isinstance(ignore_policy, List): 269 | 270 | if isinstance(ignore_policy, str): 271 | ignore_policy = [ignore_policy] 272 | 273 | if path in ignore_policy: return True 274 | for item in ignore_policy: 275 | if re.match(item, path): 276 | return True 277 | return False 278 | 279 | 280 | def set_module(model, submodule_key, module): 281 | tokens = submodule_key.split('.') 282 | sub_tokens = tokens[:-1] 283 | cur_mod = model 284 | for s in sub_tokens: 285 | cur_mod = getattr(cur_mod, s) 286 | setattr(cur_mod, tokens[-1], module) 287 | 288 | 289 | def 
replace_to_quantization_module(model : torch.nn.Module, ignore_policy : Union[str, List[str], Callable] = None, prefixx=colorstr('QAT:')):  # 'prefixx' deliberately avoids shadowing the recursive 'prefix' path argument below
290 | 
291 |     module_dict = {}
292 |     for entry in quant_modules._DEFAULT_QUANT_MAP:
293 |         module = getattr(entry.orig_mod, entry.mod_name)
294 |         module_dict[id(module)] = entry.replace_mod
295 | 
296 |     def recursive_and_replace_module(module, prefix=""):
297 |         for name in module._modules:
298 |             submodule = module._modules[name]
299 |             path = name if prefix == "" else prefix + "." + name
300 |             recursive_and_replace_module(submodule, path)
301 | 
302 |             submodule_id = id(type(submodule))
303 |             if submodule_id in module_dict:
304 |                 ignored = quantization_ignore_match(ignore_policy, path)
305 |                 if ignored:
306 |                     LOGGER.info(f'{prefixx} Quantization: {path} has been ignored.')
307 |                     continue
308 | 
309 |                 module._modules[name] = transfer_torch_to_quantization(submodule, module_dict[submodule_id])
310 | 
311 |     recursive_and_replace_module(model)
312 | 
313 | 
314 | def get_attr_with_path(m, path):
315 |     def sub_attr(m, names):
316 |         name = names[0]
317 |         value = getattr(m, name)
318 |         if len(names) == 1:
319 |             return value
320 |         return sub_attr(value, names[1:])
321 |     return sub_attr(m, path.split("."))
322 | 
323 | def repncspelan4_qaunt_forward(self, x):
324 |     if hasattr(self, "repncspelan4chunkop"):
325 |         y = list(self.repncspelan4chunkop(self.cv1(x), 2, 1))
326 |         y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
327 |         return self.cv4(torch.cat(y, 1))
328 |     else:
329 |         y = list(self.cv1(x).split((self.c, self.c), 1))
330 |         y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
331 |         return self.cv4(torch.cat(y, 1))
332 | 
333 | def repbottleneck_quant_forward(self, x):
334 |     if hasattr(self, "addop"):
335 |         return self.addop(x, self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x))
336 |     return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
337 | 
338 | def upsample_quant_forward(self, x):
339 |     if hasattr(self, "upsampleop"):
340 |         return self.upsampleop(x)
341 |     return F.interpolate(x, self.size, self.scale_factor, self.mode)  # fall back to the original nn.Upsample arguments
342 | 
343 | def concat_quant_forward(self, x):
344 |     if hasattr(self, "concatop"):
345 |         return self.concatop(x, self.d)
346 |     return torch.cat(x, self.d)
347 | 
348 | def adown_quant_forward(self, x):
349 |     if hasattr(self, "adownchunkop"):  # attached by replace_custom_module_forward()
350 |         x1, x2 = self.adownchunkop(x)
351 |         x1 = self.cv1(x1)
352 |         x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
353 |         x2 = self.cv2(x2)
354 |         return torch.cat((x1, x2), 1)
355 | 
356 | def aconv_quant_forward(self, x):
357 |     if hasattr(self, "aconvchunkop"):  # attached by replace_custom_module_forward()
358 |         x = self.aconvchunkop(x)
359 |         return self.cv1(x)
360 | 
361 | def apply_custom_rules_to_quantizer(model : torch.nn.Module, export_onnx : Callable):
362 |     export_onnx(model, "quantization-custom-rules-temp.onnx")
363 |     pairs = find_quantizer_pairs("quantization-custom-rules-temp.onnx")
364 |     for major, sub in pairs:
365 |         print(f"Rules: {sub} matched to {major}")
366 |         get_attr_with_path(model, sub)._input_quantizer = get_attr_with_path(model, major)._input_quantizer
367 |     os.remove("quantization-custom-rules-temp.onnx")
368 | 
369 |     for name, module in model.named_modules():
370 |         if module.__class__.__name__ == "RepNBottleneck":
371 |             if module.add:
372 |                 print(f"Rules: {name}.add matched to {name}.cv1")
373 |                 major = module.cv1.conv._input_quantizer
374 |                 module.addop._input0_quantizer = major
375 |                 module.addop._input1_quantizer = major
376 | 
377 |         if isinstance(module, torch.nn.MaxPool2d):
378 |             quant_conv_desc_input = QuantDescriptor(num_bits=8, calib_method='histogram')
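            # Note (added comment): the lines that follow swap this plain nn.MaxPool2d for
            # quant_nn.QuantMaxPool2d, carrying over kernel_size/stride/padding/dilation/ceil_mode
            # and giving it an 8-bit, histogram-calibrated input quantizer so the pooling input is
            # explicitly fake-quantized like the surrounding convolutions.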
379 | quant_maxpool2d = quant_nn.QuantMaxPool2d(module.kernel_size, 380 | module.stride, 381 | module.padding, 382 | module.dilation, 383 | module.ceil_mode, 384 | quant_desc_input = quant_conv_desc_input) 385 | set_module(model, name, quant_maxpool2d) 386 | 387 | if module.__class__.__name__ == 'ADown': 388 | module.cv1.conv._input_quantizer = module.adownchunkop._chunk_quantizer 389 | if module.__class__.__name__ == 'AConv': 390 | module.cv1.conv._input_quantizer = module.aconvchunkop._chunk_quantizer 391 | 392 | def replace_custom_module_forward(model): 393 | for name, module in model.named_modules(): 394 | # if module.__class__.__name__ == "RepNCSPELAN4": 395 | # if not hasattr(module, "repncspelan4chunkop"): 396 | # print(f"Add RepNCSPELAN4QuantChunk to {name}") 397 | # module.repncspelan4chunkop = QuantRepNCSPELAN4Chunk(module.c) 398 | # module.__class__.forward = repncspelan4_qaunt_forward 399 | 400 | if module.__class__.__name__ == "ADown": 401 | if not hasattr(module, "adownchunkop"): 402 | print(f"Add ADownQuantChunk to {name}") 403 | module.adownchunkop = QuantADownAvgChunk() 404 | module.__class__.forward = adown_quant_forward 405 | 406 | if module.__class__.__name__ == "AConv": 407 | if not hasattr(module, "aconvchunkop"): 408 | print(f"Add AConvQuantChunk to {name}") 409 | module.aconvchunkop = QuantAConvAvgChunk() 410 | module.__class__.forward = aconv_quant_forward 411 | 412 | if module.__class__.__name__ == "RepNBottleneck": 413 | if module.add: 414 | if not hasattr(module, "addop"): 415 | print(f"Add QuantAdd to {name}") 416 | module.addop = QuantAdd(module.add) 417 | module.__class__.forward = repbottleneck_quant_forward 418 | 419 | if module.__class__.__name__ == "Concat": 420 | if not hasattr(module, "concatop"): 421 | print(f"Add QuantConcat to {name}") 422 | module.concatop = QuantConcat(module.d) 423 | module.__class__.forward = concat_quant_forward 424 | 425 | if module.__class__.__name__ == "Upsample": 426 | if not hasattr(module, "upsampleop"): 427 | print(f"Add QuantUpsample to {name}") 428 | module.upsampleop = QuantUpsample(module.size, module.scale_factor, module.mode) 429 | module.__class__.forward = upsample_quant_forward 430 | 431 | def calibrate_model(model : torch.nn.Module, dataloader, device, num_batch=25): 432 | 433 | def compute_amax(model, **kwargs): 434 | for name, module in model.named_modules(): 435 | if isinstance(module, quant_nn.TensorQuantizer): 436 | if module._calibrator is not None: 437 | if isinstance(module._calibrator, calib.MaxCalibrator): 438 | module.load_calib_amax() 439 | else: 440 | module.load_calib_amax(**kwargs) 441 | 442 | module._amax = module._amax.to(device) 443 | 444 | def collect_stats(model, data_loader, device, num_batch=200): 445 | """Feed data to the network and collect statistics""" 446 | # Enable calibrators 447 | model.eval() 448 | for name, module in model.named_modules(): 449 | 450 | if isinstance(module, quant_nn.TensorQuantizer): 451 | if module._calibrator is not None: 452 | module.disable_quant() 453 | module.enable_calib() 454 | else: 455 | module.disable() 456 | 457 | # Feed data to the network for collecting stats 458 | with torch.no_grad(): 459 | for i, datas in tqdm(enumerate(data_loader), total=num_batch, desc="Collect stats for calibrating"): 460 | imgs = datas[0].to(device, non_blocking=True).float() / 255.0 461 | model(imgs) 462 | 463 | if i >= num_batch: 464 | break 465 | 466 | # Disable calibrators 467 | for name, module in model.named_modules(): 468 | if isinstance(module, 
quant_nn.TensorQuantizer): 469 | if module._calibrator is not None: 470 | module.enable_quant() 471 | module.disable_calib() 472 | else: 473 | module.enable() 474 | 475 | with torch.no_grad(): 476 | collect_stats(model, dataloader, device, num_batch=num_batch) 477 | #compute_amax(model, method="percentile", percentile=99.99, strict=True) # strict=False avoid Exception when some quantizer are never used 478 | compute_amax(model, method="mse") 479 | 480 | 481 | 482 | def finetune( 483 | model : torch.nn.Module, train_dataloader, per_epoch_callback : Callable = None, preprocess : Callable = None, 484 | nepochs=10, early_exit_batchs_per_epoch=1000, lrschedule : Dict = None, fp16=True, learningrate=1e-5, 485 | supervision_policy : Callable = None, prefix=colorstr('QAT:') 486 | ): 487 | origin_model = deepcopy(model).eval() 488 | disable_quantization(origin_model).apply() 489 | 490 | model.train() 491 | model.requires_grad_(True) 492 | 493 | scaler = amp.GradScaler(enabled=fp16) 494 | optimizer = optim.Adam(model.parameters(), learningrate) 495 | quant_lossfn = torch.nn.MSELoss() 496 | device = next(model.parameters()).device 497 | 498 | 499 | if lrschedule is None: 500 | lrschedule = { 501 | 0: 1e-6, 502 | 6: 1e-5, 503 | 7: 1e-6 504 | } 505 | 506 | 507 | def make_layer_forward_hook(l): 508 | def forward_hook(m, input, output): 509 | l.append(output) 510 | return forward_hook 511 | 512 | supervision_module_pairs = [] 513 | for ((mname, ml), (oriname, ori)) in zip(model.named_modules(), origin_model.named_modules()): 514 | if isinstance(ml, quant_nn.TensorQuantizer): continue 515 | 516 | if supervision_policy: 517 | if not supervision_policy(mname, ml): 518 | continue 519 | 520 | supervision_module_pairs.append([ml, ori]) 521 | 522 | 523 | for iepoch in range(nepochs): 524 | 525 | if iepoch in lrschedule: 526 | learningrate = lrschedule[iepoch] 527 | for g in optimizer.param_groups: 528 | g["lr"] = learningrate 529 | 530 | model_outputs = [] 531 | origin_outputs = [] 532 | remove_handle = [] 533 | 534 | 535 | 536 | for ml, ori in supervision_module_pairs: 537 | remove_handle.append(ml.register_forward_hook(make_layer_forward_hook(model_outputs))) 538 | remove_handle.append(ori.register_forward_hook(make_layer_forward_hook(origin_outputs))) 539 | 540 | model.train() 541 | pbar = tqdm(train_dataloader, desc="QAT", total=early_exit_batchs_per_epoch) 542 | for ibatch, imgs in enumerate(pbar): 543 | 544 | if ibatch >= early_exit_batchs_per_epoch: 545 | break 546 | 547 | if preprocess: 548 | imgs = preprocess(imgs) 549 | 550 | 551 | imgs = imgs.to(device) 552 | with amp.autocast(enabled=fp16): 553 | model(imgs) 554 | 555 | with torch.no_grad(): 556 | origin_model(imgs) 557 | 558 | quant_loss = 0 559 | for mo, fo in zip(model_outputs, origin_outputs): 560 | for m, f in zip(mo, fo): 561 | quant_loss += quant_lossfn(m, f) 562 | 563 | model_outputs.clear() 564 | origin_outputs.clear() 565 | 566 | if fp16: 567 | scaler.scale(quant_loss).backward() 568 | scaler.step(optimizer) 569 | scaler.update() 570 | else: 571 | quant_loss.backward() 572 | optimizer.step() 573 | optimizer.zero_grad() 574 | pbar.set_description(f"QAT Finetuning {iepoch + 1} / {nepochs}, Loss: {quant_loss.detach().item():.5f}, LR: {learningrate:g}") 575 | 576 | # You must remove hooks during onnx export or torch.save 577 | for rm in remove_handle: 578 | rm.remove() 579 | 580 | if per_epoch_callback: 581 | if per_epoch_callback(model, iepoch, learningrate): 582 | break 583 | 584 | 585 | def export_onnx(model, input, file, *args, **kwargs): 
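    # Note (added comment): with use_fb_fake_quant enabled, pytorch-quantization's TensorQuantizer
    # modules run through torch's native fake-quantize ops during tracing, so torch.onnx.export
    # emits explicit QuantizeLinear/DequantizeLinear pairs that TensorRT can consume; the flag is
    # restored to False after export so eager-mode fake quantization behaves as before.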
586 | quant_nn.TensorQuantizer.use_fb_fake_quant = True 587 | 588 | model.eval() 589 | with torch.no_grad(): 590 | torch.onnx.export(model, input, file, *args, **kwargs) 591 | 592 | quant_nn.TensorQuantizer.use_fb_fake_quant = False 593 | -------------------------------------------------------------------------------- /models/quantize_rules.py: -------------------------------------------------------------------------------- 1 | ################################################################################ 2 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: MIT 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a 6 | # copy of this software and associated documentation files (the "Software"), 7 | # to deal in the Software without restriction, including without limitation 8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense, 9 | # and/or sell copies of the Software, and to permit persons to whom the 10 | # Software is furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in 13 | # all copies or substantial portions of the Software. 14 | # 15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 21 | # DEALINGS IN THE SOFTWARE. 
22 | ################################################################################ 23 | 24 | import onnx 25 | 26 | def find_with_input_node(model, name): 27 | for node in model.graph.node: 28 | if len(node.input) > 0 and name in node.input: 29 | return node 30 | 31 | def find_all_with_input_node(model, name): 32 | all = [] 33 | for node in model.graph.node: 34 | if len(node.input) > 0 and name in node.input: 35 | all.append(node) 36 | return all 37 | 38 | def find_with_output_node(model, name): 39 | for node in model.graph.node: 40 | if len(node.output) > 0 and name in node.output: 41 | return node 42 | 43 | """ 44 | def find_with_no_change_parent_node(model, node): 45 | parent = find_with_output_node(model, node.input[0]) 46 | if parent is not None: 47 | print("Parent:", parent.op_type) 48 | if parent.op_type in ["Concat", "MaxPool", "AveragePool", "Slice"]: 49 | return find_with_no_change_parent_node(model, parent) 50 | return parent 51 | """ 52 | 53 | def find_quantizelinear_conv(model, qnode): 54 | dq = find_with_input_node(model, qnode.output[0]) 55 | conv = find_with_input_node(model, dq.output[0]) 56 | return conv 57 | 58 | 59 | def find_quantize_conv_name(model, weight_qname): 60 | dq = find_with_output_node(model, weight_qname) 61 | q = find_with_output_node(model, dq.input[0]) 62 | return ".".join(q.input[0].split(".")[:-1]) 63 | 64 | def find_quantizer_pairs(onnx_file): 65 | 66 | model = onnx.load(onnx_file) 67 | match_pairs = [] 68 | for node in model.graph.node: 69 | if node.op_type == "Concat": 70 | qnodes = find_all_with_input_node(model, node.output[0]) 71 | major = None 72 | for qnode in qnodes: 73 | if qnode.op_type != "QuantizeLinear": 74 | continue 75 | conv = find_quantizelinear_conv(model, qnode) 76 | if major is None: 77 | major = find_quantize_conv_name(model, conv.input[1]) 78 | else: 79 | match_pairs.append([major, find_quantize_conv_name(model, conv.input[1])]) 80 | 81 | for subnode in model.graph.node: 82 | if len(subnode.input) > 0 and subnode.op_type == "QuantizeLinear" and subnode.input[0] in node.input: 83 | subconv = find_quantizelinear_conv(model, subnode) 84 | match_pairs.append([major, find_quantize_conv_name(model, subconv.input[1])]) 85 | 86 | 87 | elif node.op_type == "MaxPool": 88 | qnode = find_with_input_node(model, node.output[0]) 89 | if not (qnode and qnode.op_type == "QuantizeLinear"): 90 | continue 91 | major = find_quantizelinear_conv(model, qnode) 92 | major = find_quantize_conv_name(model, major.input[1]) 93 | same_input_nodes = find_all_with_input_node(model, node.input[0]) 94 | 95 | for same_input_node in same_input_nodes: 96 | if same_input_node.op_type == "QuantizeLinear": 97 | subconv = find_quantizelinear_conv(model, same_input_node) 98 | match_pairs.append([major, find_quantize_conv_name(model, subconv.input[1])]) 99 | 100 | return match_pairs 101 | -------------------------------------------------------------------------------- /patch_yolov9.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if the correct number of arguments were provided 4 | if [ "$#" -ne 1 ]; then 5 | echo "Usage: $0 " 6 | exit 1 7 | fi 8 | 9 | # YOLOv9 directory 10 | yolov9_dir="$1" 11 | 12 | # Check if the YOLOv9 directory exists 13 | if [ ! -d "$yolov9_dir" ]; then 14 | echo "Error: Directory '$yolov9_dir' not found." 15 | exit 1 16 | fi 17 | 18 | if [ ! -f "$yolov9_dir/models/experimental.py" ]; then 19 | echo "Error: '$yolov9_dir' does not appear to contain a valid YOLOv9 repository." 
20 | echo "Please make sure '$yolov9_dir' is the root directory of a YOLOv9 repository." 21 | exit 1 22 | fi 23 | 24 | # Copy files to the YOLOv9 directory 25 | cp val_trt.py "$yolov9_dir/val_trt.py" && echo "qat.py patched successfully." 26 | cp qat.py "$yolov9_dir/qat.py" && echo "qat.py patched successfully." 27 | cp export_qat.py "$yolov9_dir/export_qat.py" && echo "export_qat.py patched successfully." 28 | cp models/quantize_rules.py "$yolov9_dir/models/quantize_rules.py" && echo "quantize_rules.py patched successfully." 29 | cp models/quantize.py "$yolov9_dir/models/quantize.py" && echo "quantize.py patched successfully." 30 | cp scripts/generate_trt_engine.sh "$yolov9_dir/scripts/generate_trt_engine.sh" && echo "generate_trt_engine.sh patched successfully." 31 | cp scripts/val_trt.sh "$yolov9_dir/scripts/val_trt.sh" && echo "val_trt.sh patched successfully." 32 | cp draw-engine.py "$yolov9_dir/draw-engine.py" && echo "draw-engine.py patched successfully." 33 | 34 | cp models/experimental_trt.py "$yolov9_dir/models/experimental_trt.py" && echo "experimental_trt.py patched successfully." 35 | cp segment/qat_seg.py "$yolov9_dir/segment/qat_seg.py" && echo "qat_seg.py patched successfully." 36 | 37 | echo "Patch applied successfully to YOLOv9 directory: $yolov9_dir" 38 | -------------------------------------------------------------------------------- /qat.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import yaml 5 | import argparse 6 | import json 7 | from copy import deepcopy 8 | from pathlib import Path 9 | import warnings 10 | 11 | # PyTorch 12 | import torch 13 | import torch.nn as nn 14 | 15 | import val as validate 16 | from models.yolo import Model 17 | from models.common import Conv 18 | from utils.dataloaders import create_dataloader 19 | from utils.downloads import attempt_download 20 | 21 | from models.yolo import Detect, DDetect, DualDetect, DualDDetect, DetectionModel, SegmentationModel 22 | import models.quantize as quantize 23 | 24 | from utils.general import (LOGGER, check_dataset, check_requirements, check_img_size, colorstr, init_seeds,increment_path,file_size) 25 | from utils.torch_utils import (torch_distributed_zero_first) 26 | 27 | warnings.filterwarnings("ignore") 28 | 29 | FILE = Path(__file__).resolve() 30 | ROOT = FILE.parents[0] # YOLO root directory 31 | if str(ROOT) not in sys.path: 32 | sys.path.append(str(ROOT)) # add ROOT to PATH 33 | ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative 34 | 35 | LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html 36 | RANK = int(os.getenv('RANK', -1)) 37 | WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) 38 | GIT_INFO = None 39 | 40 | 41 | class ReportTool: 42 | def __init__(self, file): 43 | self.file = file 44 | if os.path.exists(self.file): 45 | open(self.file, 'w').close() 46 | self.data = [] 47 | 48 | def load_data(self): 49 | try: 50 | return json.load(open(self.file, "r")) 51 | except FileNotFoundError: 52 | return [] 53 | 54 | def append(self, item): 55 | self.data.append(item) 56 | self.save_data() 57 | 58 | def update(self, item): 59 | for i, data_item in enumerate(self.data): 60 | if data_item[0] == item[0]: 61 | self.data[i] = item 62 | break 63 | else: 64 | # Se não encontrar, adiciona como um novo item 65 | self.append(item) 66 | self.save_data() 67 | 68 | def save_data(self): 69 | json.dump(self.data, open(self.file, "w"), indent=4) 70 | 71 | 72 | def load_model(weights, device) -> 
Model: 73 | with torch_distributed_zero_first(LOCAL_RANK): 74 | attempt_download(weights) 75 | model = torch.load(weights, map_location=device)["model"] 76 | for m in model.modules(): 77 | if type(m) is nn.Upsample: 78 | m.recompute_scale_factor = None # torch 1.11.0 compatibility 79 | elif type(m) is Conv: 80 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 81 | model.float() 82 | model.eval() 83 | with torch.no_grad(): 84 | model.fuse() 85 | return model 86 | 87 | 88 | 89 | def create_train_dataloader(train_path, imgsz, batch_size, single_cls, stride, hyp_path): 90 | with open(hyp_path) as f: 91 | hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps 92 | loader = create_dataloader( 93 | train_path, 94 | imgsz=imgsz, 95 | batch_size=batch_size, 96 | single_cls=single_cls, 97 | augment=True, hyp=hyp, rect=False, cache=False, stride=stride, pad=0.0, image_weights=False)[0] 98 | return loader 99 | 100 | 101 | 102 | def create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride, keep_images=None): 103 | loader = create_dataloader( 104 | test_path, 105 | imgsz=imgsz, 106 | batch_size=batch_size, 107 | single_cls=single_cls, 108 | augment=False, hyp=None, rect=True, cache=False,stride=stride,pad=0.5, image_weights=False)[0] 109 | 110 | def subclass_len(self): 111 | if keep_images is not None: 112 | return keep_images 113 | return len(self.img_files) 114 | 115 | loader.dataset.__len__ = subclass_len 116 | return loader 117 | 118 | def evaluate_dataset(model_eval, val_loader, imgsz, data_dict, single_cls, save_dir, is_coco, conf_thres=0.001 , iou_thres=0.7 ): 119 | return validate.run(data_dict, 120 | model=model_eval, 121 | imgsz=imgsz, 122 | single_cls=single_cls, 123 | half=True, 124 | task='val', 125 | verbose=True, 126 | conf_thres=conf_thres, 127 | iou_thres=iou_thres, 128 | save_dir=save_dir, 129 | save_json=is_coco, 130 | dataloader=val_loader, 131 | )[0][:4] 132 | 133 | 134 | def export_onnx(model, file, im, opset=12, dynamic=False, prefix=colorstr('QAT ONNX:')): 135 | check_requirements('onnx') 136 | import onnx 137 | 138 | file = Path(file) 139 | LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...') 140 | 141 | f = file.with_suffix('.onnx') 142 | output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0'] 143 | model.eval() 144 | for k, m in model.named_modules(): 145 | # print(m) 146 | if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect)): 147 | m.inplace = True 148 | m.dynamic = dynamic 149 | m.export = True 150 | dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640) 151 | if isinstance(model, SegmentationModel): 152 | dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85) 153 | dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160) 154 | elif isinstance(model, DetectionModel): 155 | dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85) 156 | 157 | quantize.export_onnx(model, im, file, opset_version=13, 158 | input_names=["images"], output_names=output_names, 159 | dynamic_axes=dynamic or None 160 | ) 161 | 162 | for k, m in model.named_modules(): 163 | if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect)): 164 | m.inplace = True 165 | m.dynamic = False 166 | m.export = False 167 | 168 | 169 | def run_quantize(weights, data, imgsz, batch_size, hyp, device, save_dir, supervision_stride, iters, no_eval_origin, no_eval_ptq, prefix=colorstr('QAT:')): 170 | 171 | if not Path(weights).exists(): 172 
| LOGGER.info(f'{prefix} Weight file not found "{weights}" ❌') 173 | exit(1) 174 | 175 | quantize.initialize() 176 | 177 | with torch_distributed_zero_first(LOCAL_RANK): 178 | data_dict = check_dataset(data) 179 | 180 | w = save_dir / 'weights' # weights dir 181 | w.mkdir(parents=True, exist_ok=True) # make dir 182 | 183 | is_coco = isinstance(data_dict.get('val'), str) and data_dict['val'].endswith(f'val2017.txt') # COCO dataset 184 | 185 | nc = int(data_dict['nc']) # number of classes 186 | single_cls = False if nc > 1 else True 187 | names = data_dict['names'] # class names 188 | assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data_dict) # check 189 | train_path = data_dict['train'] 190 | test_path = data_dict['val'] 191 | 192 | result_eval_origin=None 193 | result_eval_ptq=None 194 | result_eval_qat_best=None 195 | 196 | device = torch.device(device) 197 | model = load_model(weights, device) 198 | 199 | if not isinstance(model, DetectionModel): 200 | model_name=model.__class__.__name__ 201 | LOGGER.info(f'{prefix} {model_name} model is not supported. Only DetectionModel is supported. ❌') 202 | exit(1) 203 | 204 | stride = max(int(model.stride.max()), 32) # grid size (max stride) 205 | imgsz = check_img_size(imgsz, s=stride) # check image size 206 | 207 | # conf onnx export 208 | exp_imgsz=[imgsz,imgsz] 209 | gs = int(max(model.stride)) # grid size (max stride) 210 | exp_imgsz = [check_img_size(x, gs) for x in exp_imgsz] # verify img_size are gs-multiples 211 | im = torch.zeros(batch_size, 3, *exp_imgsz).to(device) # image size(1,3,320,192) BCHW iDetection 212 | 213 | 214 | train_dataloader = create_train_dataloader(train_path, imgsz, batch_size, single_cls, stride, hyp) 215 | val_dataloader = create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride) 216 | 217 | ### This rule is disabled - This allow user disable qat per Layers ### 218 | # This rule has been disabled, but it remains in the code to maintain compatibility or future implementation. 219 | """ 220 | ignore_layer=-1 221 | if ignore_layer > -1: 222 | ignore_policy=f"model\.{ignore_layer}\.cv\d+\.\d+\.\d+(\.conv)?" 223 | else: 224 | ignore_policy=f"model\.9999999999\.cv\d+\.\d+\.\d+(\.conv)?" 
225 | """ 226 | ### End ####### 227 | 228 | quantize.replace_custom_module_forward(model) 229 | quantize.replace_to_quantization_module(model, ignore_policy="disabled") ## disabled because was not implemented 230 | quantize.apply_custom_rules_to_quantizer(model, lambda model, file: export_onnx(model, file, im)) 231 | quantize.calibrate_model(model, train_dataloader, device) 232 | 233 | report_file = os.path.join(save_dir, "report.json") 234 | report = ReportTool(report_file) 235 | 236 | if no_eval_origin: 237 | LOGGER.info(f'\n{prefix} Evaluating Origin...') 238 | model_eval = deepcopy(model).eval() 239 | with quantize.disable_quantization(model_eval): 240 | result_eval_origin = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco ) 241 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in result_eval_origin) 242 | LOGGER.info(f'\n{prefix} Eval Origin - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}') 243 | report.append(["Origin", str(weights), eval_map, eval_map50,eval_mp, eval_mr ]) 244 | 245 | if no_eval_ptq: 246 | 247 | LOGGER.info(f'\n{prefix} Evaluating PTQ...') 248 | model_eval = deepcopy(model).eval() 249 | 250 | result_eval_ptq = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco ) 251 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in result_eval_ptq) 252 | LOGGER.info(f'\n{prefix} Eval PTQ - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}') 253 | ptq_weights = w / f'ptq_ap_{eval_map}_{os.path.basename(weights)}' 254 | torch.save({"model": model_eval},f'{ptq_weights}') 255 | LOGGER.info(f'\n{prefix} PTQ, weights saved as {ptq_weights} ({file_size(ptq_weights):.1f} MB)') 256 | report.append(["PTQ", str(ptq_weights), eval_map, eval_map50,eval_mp, eval_mr ]) 257 | 258 | best_map = 0 259 | 260 | def per_epoch(model, epoch, lr): 261 | nonlocal best_map , result_eval_qat_best 262 | 263 | epoch +=1 264 | model_eval = deepcopy(model).eval() 265 | with torch.no_grad(): 266 | eval_result = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco ) 267 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result) 268 | qat_weights = w / f'qat_ep_{epoch}_ap_{eval_map}_{os.path.basename(weights)}' 269 | torch.save({"model": model_eval},f'{qat_weights}') 270 | LOGGER.info(f'\n{prefix} Epoch-{epoch}, weights saved as {qat_weights} ({file_size(qat_weights):.1f} MB)') 271 | report.append([f"QAT-{epoch}", str(qat_weights), eval_map, eval_map50,eval_mp, eval_mr ]) 272 | 273 | if eval_map > best_map: 274 | best_map = eval_map 275 | result_eval_qat_best=eval_result 276 | qat_weights = w / f'qat_best_{os.path.basename(weights)}' 277 | torch.save({"model": model_eval}, f'{qat_weights}') 278 | LOGGER.info(f'{prefix} QAT Best, weights saved as {qat_weights} ({file_size(qat_weights):.1f} MB)') 279 | report.update(["QAT-Best", str(qat_weights), eval_map, eval_map50,eval_mp, eval_mr ]) 280 | 281 | eval_results = [result_eval_origin, result_eval_ptq, result_eval_qat_best] 282 | 283 | LOGGER.info(f'\n\nEval Model | {"AP":<8} | {"AP50":<8} | {"Precision":<10} | {"Recall":<8}') 284 | LOGGER.info('-' * 55) 285 | for idx, eval_r in enumerate(eval_results): 286 | if eval_r is not None: 287 | eval_mp, eval_mr, eval_map50, eval_map = tuple(round(x, 3) for x in eval_r) 288 | if idx == 0: 289 | LOGGER.info(f'Origin | {eval_map:<8} | {eval_map50:<8} | {eval_mp:<10} | {eval_mr:<8}') 290 | if idx == 1: 
291 | LOGGER.info(f'PTQ | {eval_map:<8} | {eval_map50:<8} | {eval_mp:<10} | {eval_mr:<8}') 292 | if idx == 2: 293 | LOGGER.info(f'QAT - Best | {eval_map:<8} | {eval_map50:<8} | {eval_mp:<10} | {eval_mr:<8}\n') 294 | 295 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result) 296 | LOGGER.info(f'\n{prefix} Eval - Epoch {epoch} | AP: {eval_map} | AP50: {eval_map50} | Precision: {eval_mp} | Recall: {eval_mr}\n') 297 | 298 | def preprocess(datas): 299 | return datas[0].to(device).float() / 255.0 300 | 301 | def supervision_policy(): 302 | supervision_list = [] 303 | for item in model.model: 304 | supervision_list.append(id(item)) 305 | 306 | keep_idx = list(range(0, len(model.model) - 1, supervision_stride)) 307 | keep_idx.append(len(model.model) - 2) 308 | def impl(name, module): 309 | if id(module) not in supervision_list: return False 310 | idx = supervision_list.index(id(module)) 311 | if idx in keep_idx: 312 | print(f"Supervision: {name} will compute loss with origin model during QAT training") 313 | else: 314 | print(f"Supervision: {name} no compute loss during QAT training, that is unsupervised only and doesn't mean don't learn") 315 | return idx in keep_idx 316 | return impl 317 | 318 | quantize.finetune( 319 | model, train_dataloader, per_epoch, early_exit_batchs_per_epoch=iters, 320 | preprocess=preprocess, supervision_policy=supervision_policy()) 321 | 322 | def run_sensitive_analysis(weights, device, data, imgsz, batch_size, hyp, save_dir, num_image, prefix=colorstr('QAT ANALYSIS:')): 323 | 324 | if not Path(weights).exists(): 325 | LOGGER.info(f'{prefix} Weight file not found "{weights}" ❌') 326 | exit(1) 327 | 328 | save_dir = Path(save_dir) 329 | # Create the directory if it doesn't exist 330 | save_dir.mkdir(parents=True, exist_ok=opt.exist_ok) 331 | 332 | with torch_distributed_zero_first(LOCAL_RANK): 333 | data_dict = check_dataset(data) 334 | 335 | is_coco=False 336 | 337 | nc = int(data_dict['nc']) # number of classes 338 | single_cls = False if nc > 1 else True 339 | names = data_dict['names'] # class names 340 | assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data_dict) # check 341 | train_path = data_dict['train'] 342 | test_path = data_dict['val'] 343 | 344 | device = torch.device(device) 345 | model = load_model(weights, device) 346 | 347 | if not isinstance(model, DetectionModel) or isinstance(model, SegmentationModel): 348 | LOGGER.info(f'{prefix} " Model not supported. Only Detection Models is supported. ❌') 349 | exit(1) 350 | 351 | is_model_qat=False 352 | for i in range(0, len(model.model)): 353 | layer = model.model[i] 354 | if quantize.have_quantizer(layer): 355 | is_model_qat=True 356 | break 357 | 358 | if is_model_qat: 359 | LOGGER.info(f'{prefix} This model already quantized. Only not quantized models is allowed. 
❌') 360 | exit(1) 361 | 362 | stride = max(int(model.stride.max()), 32) # grid size (max stride) 363 | imgsz = check_img_size(imgsz, s=stride) # check image size 364 | 365 | exp_imgsz=[imgsz,imgsz] 366 | gs = int(max(model.stride)) # grid size (max stride) 367 | exp_imgsz = [check_img_size(x, gs) for x in exp_imgsz] # verify img_size are gs-multiples 368 | im = torch.zeros(batch_size, 3, *exp_imgsz).to(device) # image size(1,3,320,192) BCHW iDetection 369 | 370 | 371 | train_dataloader = create_train_dataloader(train_path, imgsz, batch_size, single_cls, stride, hyp) 372 | val_dataloader = create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride) 373 | quantize.initialize() 374 | quantize.replace_custom_module_forward(model) 375 | quantize.replace_to_quantization_module(model, ignore_policy="disabled") ## disabled because was not implemented 376 | quantize.calibrate_model(model, train_dataloader, device) 377 | 378 | report_file=os.path.join(save_dir , "summary-sensitive-analysis.json") 379 | report = ReportTool(report_file) 380 | 381 | model_eval = deepcopy(model).eval() 382 | LOGGER.info(f'\n{prefix} Evaluating PTQ...') 383 | 384 | eval_result = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco ) 385 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result) 386 | 387 | LOGGER.info(f'\n{prefix} Eval PTQ - QAT enabled on All Layers - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}') 388 | report.append([eval_map, "PTQ"]) 389 | LOGGER.info(f'{prefix} Sensitive analysis by each layer. Layers Detected: {len(model.model)}') 390 | 391 | for i in range(0, len(model.model)): 392 | layer = model.model[i] 393 | if quantize.have_quantizer(layer): 394 | LOGGER.info(f'{prefix} QAT disabled on Layer model.{i}') 395 | quantize.disable_quantization(layer).apply() 396 | model_eval = deepcopy(model).eval() 397 | eval_result = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco ) 398 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result) 399 | LOGGER.info(f'\n{prefix} Eval PTQ - QAT disabled on Layer model.{i} - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}\n') 400 | report.append([eval_map, f"model.{i}"]) 401 | quantize.enable_quantization(layer).apply() 402 | else: 403 | LOGGER.info(f'{prefix} Ignored Layer model.{i} because it is {type(layer)}') 404 | 405 | report = sorted(report.data, key=lambda x:x[0], reverse=True) 406 | print("Sensitive summary:") 407 | for n, (ap, name) in enumerate(report[:10]): 408 | print(f"Top{n}: Using fp16 {name}, ap = {ap:.5f}") 409 | 410 | 411 | def run_eval(weights, device, data, imgsz, batch_size, save_dir, conf_thres, iou_thres, prefix=colorstr('QAT TEST:')): 412 | 413 | if not Path(weights).exists(): 414 | LOGGER.info(f'{prefix} Weight file not found "{weights}" ❌') 415 | exit(1) 416 | 417 | quantize.initialize() 418 | 419 | save_dir = Path(save_dir) 420 | # Create the directory if it doesn't exist 421 | save_dir.mkdir(parents=True, exist_ok=opt.exist_ok) 422 | 423 | with torch_distributed_zero_first(LOCAL_RANK): 424 | data_dict = check_dataset(data) 425 | 426 | 427 | device = torch.device(device) 428 | model = load_model(weights, device) 429 | 430 | if not isinstance(model, DetectionModel): 431 | model_name=model.__class__.__name__ 432 | LOGGER.info(f'{prefix} {model_name} model is not supported. Only DetectionModel is supported. 
❌') 433 | exit(1) 434 | 435 | is_model_qat=False 436 | for i in range(0, len(model.model)): 437 | layer = model.model[i] 438 | if quantize.have_quantizer(layer): 439 | is_model_qat=True 440 | break 441 | 442 | if not is_model_qat: 443 | LOGGER.info(f'{prefix} This model was not Quantized. ❌') 444 | exit(1) 445 | 446 | is_coco = isinstance(data_dict.get('val'), str) and data_dict['val'].endswith(f'val2017.txt') # COCO dataset 447 | 448 | stride = max(int(model.stride.max()), 32) # grid size (max stride) 449 | imgsz = check_img_size(imgsz, s=stride) # check image size 450 | 451 | nc = int(data_dict['nc']) # number of classes 452 | single_cls = False if nc > 1 else True 453 | names = data_dict['names'] # class names 454 | assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data_dict) # check 455 | 456 | test_path = data_dict['val'] 457 | val_dataloader = create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride) 458 | 459 | 460 | LOGGER.info(f'\n{prefix} Evaluating ...') 461 | model_eval = deepcopy(model).eval() 462 | 463 | result_eval = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, conf_thres=conf_thres, iou_thres=iou_thres ) 464 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in result_eval) 465 | LOGGER.info(f'\n{prefix} Eval Result - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}') 466 | LOGGER.info(f'\n{prefix} Eval Result, saved on {save_dir}') 467 | 468 | 469 | 470 | if __name__ == "__main__": 471 | parser = argparse.ArgumentParser(prog='qat.py') 472 | subps = parser.add_subparsers(dest="cmd") 473 | qat = subps.add_parser("quantize", help="PTQ/QAT finetune ...") 474 | 475 | qat.add_argument('--weights', type=str, default=ROOT / 'runs/models_original/yolov9-c.pt', help='weights path') 476 | qat.add_argument('--data', type=str, default=ROOT / 'data/coco.yaml', help='dataset.yaml path') 477 | qat.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-high.yaml', help='hyperparameters path') 478 | qat.add_argument("--device", type=str, default="cuda:0", help="device") 479 | qat.add_argument('--batch-size', type=int, default=10, help='total batch size') 480 | qat.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)') 481 | qat.add_argument('--project', default=ROOT / 'runs/qat', help='save to project/name') 482 | qat.add_argument('--name', default='exp', help='save to project/name') 483 | qat.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') 484 | qat.add_argument("--iters", type=int, default=200, help="iters per epoch") 485 | qat.add_argument('--seed', type=int, default=57, help='Global training seed') 486 | qat.add_argument("--supervision-stride", type=int, default=1, help="supervision stride") 487 | qat.add_argument("--no-eval-origin", action="store_false", help="Disable eval for origin model") 488 | qat.add_argument("--no-eval-ptq", action="store_false", help="Disable eval for ptq model") 489 | 490 | sensitive = subps.add_parser("sensitive", help="Sensitive layer analysis") 491 | sensitive.add_argument('--weights', type=str, default=ROOT / 'runs/models_original/yolov9-c.pt', help='Weights path (.pt)') 492 | sensitive.add_argument("--device", type=str, default="cuda:0", help="device") 493 | sensitive.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path') 494 | sensitive.add_argument('--batch-size', type=int, 
default=10, help='total batch size') 495 | sensitive.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)') 496 | sensitive.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch-high.yaml', help='hyperparameters path') 497 | sensitive.add_argument('--project', default=ROOT / 'runs/qat_sentive', help='save to project/name') 498 | sensitive.add_argument('--name', default='exp', help='save to project/name') 499 | sensitive.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') 500 | sensitive.add_argument("--num-image", type=int, default=None, help="number of image to evaluate") 501 | 502 | testcmd = subps.add_parser("eval", help="Do evaluate") 503 | testcmd.add_argument('--weights', type=str, default=ROOT / 'runs/models_original/yolov9-c.pt', help='Weights path (.pt)') 504 | testcmd.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path') 505 | testcmd.add_argument('--batch-size', type=int, default=10, help='total batch size') 506 | testcmd.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='val image size (pixels)') 507 | testcmd.add_argument("--device", type=str, default="cuda:0", help="device") 508 | testcmd.add_argument("--conf-thres", type=float, default=0.001, help="confidence threshold") 509 | testcmd.add_argument("--iou-thres", type=float, default=0.7, help="nms threshold") 510 | testcmd.add_argument('--project', default=ROOT / 'runs/qat_eval', help='save to project/name') 511 | testcmd.add_argument('--name', default='exp', help='save to project/name') 512 | testcmd.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') 513 | 514 | opt = parser.parse_args() 515 | if opt.cmd == "quantize": 516 | print(opt) 517 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) 518 | init_seeds(opt.seed + 1 + RANK, deterministic=False) 519 | 520 | run_quantize( 521 | opt.weights, opt.data, opt.imgsz, opt.batch_size, 522 | opt.hyp, opt.device, Path(opt.save_dir), 523 | opt.supervision_stride, opt.iters, 524 | opt.no_eval_origin, opt.no_eval_ptq 525 | ) 526 | 527 | elif opt.cmd == "sensitive": 528 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) 529 | print(opt) 530 | run_sensitive_analysis(opt.weights, opt.device, opt.data, 531 | opt.imgsz, opt.batch_size, opt.hyp, 532 | opt.save_dir, opt.num_image 533 | ) 534 | elif opt.cmd == "eval": 535 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) 536 | print(opt) 537 | run_eval(opt.weights, opt.device, opt.data, 538 | opt.imgsz, opt.batch_size, opt.save_dir, 539 | opt.conf_thres, opt.iou_thres 540 | ) 541 | else: 542 | parser.print_help() -------------------------------------------------------------------------------- /scripts/generate_trt_engine.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Function to display usage message 4 | usage() { 5 | echo "Usage: $0 [--generate-graph]" 6 | exit 1 7 | } 8 | 9 | # Check if the correct number of arguments are provided 10 | if [ "$#" -lt 2 ] || [ "$#" -gt 3 ]; then 11 | usage 12 | fi 13 | 14 | function get_free_gpu_memory() { 15 | # Get the total memory and used memory from nvidia-smi for GPU 0 16 | local total_memory=$(nvidia-smi --id=0 --query-gpu=memory.total --format=csv,noheader,nounits | awk '{print $1}') 17 | local 
used_memory=$(nvidia-smi --id=0 --query-gpu=memory.used --format=csv,noheader,nounits | awk '{print $1}') 18 | 19 | # Calculate free memory 20 | local free_memory=$((total_memory - used_memory)) 21 | echo "$free_memory" 22 | } 23 | 24 | workspace=$(get_free_gpu_memory) 25 | 26 | # Set default values 27 | generate_graph=false 28 | 29 | # Parse command line arguments 30 | onnx="$1" 31 | image_size="$2" 32 | stride=32 33 | network_size=$((image_size + stride)) 34 | shape=1x3x${network_size}x${network_size} 35 | 36 | file_no_ext="${onnx%.*}" 37 | 38 | # Generate engine and graph file paths 39 | trt_engine="$file_no_ext.engine" 40 | graph="$trt_engine.layer.json" 41 | profile="$trt_engine.profile.json" 42 | timing="$trt_engine.timing.json" 43 | timing_cache="$trt_engine.timing.cache" 44 | 45 | 46 | # Check if optional flag --generate-graph is provided 47 | if [ "$3" == "--generate-graph" ]; then 48 | generate_graph=true 49 | fi 50 | 51 | # Run trtexec command to generate engine and graph files 52 | if [ "$generate_graph" = true ]; then 53 | trtexec --onnx="${onnx}" \ 54 | --saveEngine="${trt_engine}" \ 55 | --fp16 --int8 \ 56 | --useCudaGraph \ 57 | --separateProfileRun \ 58 | --useSpinWait \ 59 | --profilingVerbosity=detailed \ 60 | --minShapes=images:$shape \ 61 | --optShapes=images:$shape \ 62 | --maxShapes=images:$shape \ 63 | --memPoolSize=workspace:${workspace}MiB \ 64 | --dumpLayerInfo \ 65 | --exportTimes="${timing}" \ 66 | --exportLayerInfo="${graph}" \ 67 | --exportProfile="${profile}" \ 68 | --timingCacheFile="${timing_cache}" 69 | 70 | # Profiling affects the performance of your kernel! 71 | # Always run and time without profiling. 72 | 73 | else 74 | trtexec --onnx="${onnx}" \ 75 | --saveEngine="${trt_engine}" \ 76 | --fp16 --int8 \ 77 | --useCudaGraph \ 78 | --minShapes=images:$shape \ 79 | --optShapes=images:$shape \ 80 | --maxShapes=images:$shape \ 81 | --memPoolSize=workspace:${workspace}MiB \ 82 | --timingCacheFile="${timing_cache}" 83 | fi 84 | 85 | # Check if trtexec command was successful 86 | if [ $? -eq 0 ]; then 87 | echo "Engine file generated successfully: ${trt_engine}" 88 | if [ "$generate_graph" = true ]; then 89 | echo "Graph file generated successfully: ${graph}" 90 | echo "Profile file generated successfully: ${profile}" 91 | fi 92 | else 93 | echo "Failed to generate engine file." 94 | exit 1 95 | fi 96 | 97 | -------------------------------------------------------------------------------- /scripts/val_trt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | usage() { 5 | echo "Usage: $0 --generate-graph" 6 | exit 1 7 | } 8 | 9 | # Check if all required arguments are provided 10 | if [ "$#" -lt 3 -o "$#" -gt 4 ]; then 11 | usage 12 | exit 1 13 | fi 14 | 15 | weight=$1 16 | data=$2 17 | img_size=$3 18 | generate_graph=false 19 | 20 | file_no_ext="${weight%.*}" 21 | 22 | # Generate engine and graph file paths 23 | onnx_file="$file_no_ext.onnx" 24 | trt_engine="$file_no_ext.engine" 25 | graph="$trt_engine.layer.json" 26 | profile="$trt_engine.profile.json" 27 | 28 | 29 | # Check if optional flag --generate-graph is provided 30 | if [ "$4" == "--generate-graph" ]; then 31 | generate_graph=true 32 | fi 33 | 34 | # Check if weight file exists 35 | if [ ! -f "$weight" ]; then 36 | echo "Error: Weight file '$weight' not found." 
37 | exit 1 38 | fi 39 | 40 | # Run the script 41 | python3 export_qat.py --weights "$weight" --include onnx --inplace --dynamic --simplify 42 | 43 | echo -e "\n" 44 | 45 | # Check if ONNX file was successfully generated 46 | if [ ! -f $onnx_file ]; then 47 | echo "Error: ONNX file $onnx_file not generated." 48 | exit 1 49 | fi 50 | 51 | # Run the script 52 | if [ "$generate_graph" = true ]; then 53 | bash scripts/generate_trt_engine.sh $onnx_file $img_size --generate-graph 54 | else 55 | bash scripts/generate_trt_engine.sh $onnx_file $img_size 56 | fi 57 | 58 | if [ $? -ne 0 ]; then 59 | exit 1 60 | fi 61 | 62 | if [ "$generate_graph" = true ]; then 63 | # Check if Graph file exists 64 | if [ ! -f "$graph" ]; then 65 | echo "Error: Graph file $graph not found." 66 | exit 1 67 | fi 68 | 69 | # Check if Graph file exists 70 | if [ ! -f "$profile" ]; then 71 | echo "Error: Graph file $profile not found." 72 | exit 1 73 | fi 74 | 75 | # Run the script 76 | /bin/bash -c "source /opt/nvidia_trex/env_trex/bin/activate && python3 draw-engine.py --layer $graph --profile $profile" 77 | fi 78 | 79 | # Run the script 80 | python3 val_trt.py --engine-file $trt_engine --data $data 81 | 82 | 83 | -------------------------------------------------------------------------------- /segment/qat_seg.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | import yaml 5 | import argparse 6 | import json 7 | from copy import deepcopy 8 | from pathlib import Path 9 | import warnings 10 | 11 | # PyTorch 12 | import torch 13 | import torch.nn as nn 14 | 15 | import val as validate 16 | from models.yolo import Model 17 | from models.common import Conv 18 | from utils.segment.dataloaders import create_dataloader 19 | from utils.downloads import attempt_download 20 | 21 | from models.yolo import Detect, DDetect, DualDetect, DualDDetect, DetectionModel, SegmentationModel 22 | import models.quantize as quantize 23 | 24 | from utils.general import (LOGGER, check_dataset, check_requirements, check_img_size, colorstr, init_seeds,increment_path,file_size) 25 | from utils.torch_utils import (torch_distributed_zero_first) 26 | 27 | warnings.filterwarnings("ignore") 28 | 29 | FILE = Path(__file__).resolve() 30 | ROOT = FILE.parents[0] # YOLO root directory 31 | if str(ROOT) not in sys.path: 32 | sys.path.append(str(ROOT)) # add ROOT to PATH 33 | ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative 34 | 35 | LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html 36 | RANK = int(os.getenv('RANK', -1)) 37 | WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1)) 38 | GIT_INFO = None 39 | 40 | 41 | class ReportTool: 42 | def __init__(self, file): 43 | self.file = file 44 | if os.path.exists(self.file): 45 | open(self.file, 'w').close() 46 | self.data = [] 47 | 48 | def load_data(self): 49 | try: 50 | return json.load(open(self.file, "r")) 51 | except FileNotFoundError: 52 | return [] 53 | 54 | def append(self, item): 55 | self.data.append(item) 56 | self.save_data() 57 | 58 | def update(self, item): 59 | for i, data_item in enumerate(self.data): 60 | if data_item[0] == item[0]: 61 | self.data[i] = item 62 | break 63 | else: 64 | # Se não encontrar, adiciona como um novo item 65 | self.append(item) 66 | self.save_data() 67 | 68 | def save_data(self): 69 | json.dump(self.data, open(self.file, "w"), indent=4) 70 | 71 | 72 | def load_model(weights, device) -> Model: 73 | with torch_distributed_zero_first(LOCAL_RANK): 74 | 
attempt_download(weights) 75 | model = torch.load(weights, map_location=device)["model"] 76 | for m in model.modules(): 77 | if type(m) is nn.Upsample: 78 | m.recompute_scale_factor = None # torch 1.11.0 compatibility 79 | elif type(m) is Conv: 80 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility 81 | model.float() 82 | model.eval() 83 | with torch.no_grad(): 84 | model.fuse() 85 | return model 86 | 87 | 88 | 89 | def create_train_dataloader(train_path, imgsz, batch_size, single_cls, stride, hyp_path, mask_ratio, overlap): 90 | with open(hyp_path) as f: 91 | hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps 92 | loader = create_dataloader( 93 | train_path, 94 | imgsz=imgsz, 95 | batch_size=batch_size, 96 | single_cls=single_cls, 97 | augment=True, 98 | hyp=hyp, 99 | rect=False, 100 | cache=False, 101 | stride=stride, 102 | pad=0.0, 103 | image_weights=False, 104 | shuffle=True, 105 | mask_downsample_ratio=mask_ratio, 106 | overlap_mask=overlap)[0] 107 | return loader 108 | 109 | 110 | 111 | 112 | def create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride, keep_images, mask_ratio,overlap ): 113 | loader = create_dataloader( 114 | test_path, 115 | imgsz=imgsz, 116 | batch_size=batch_size, 117 | single_cls=single_cls, 118 | augment=False, 119 | hyp=None, 120 | rect=True, 121 | cache=False, 122 | stride=stride,pad=0.5, 123 | image_weights=False, 124 | mask_downsample_ratio=mask_ratio, 125 | overlap_mask=overlap)[0] 126 | 127 | def subclass_len(self): 128 | if keep_images is not None: 129 | return keep_images 130 | return len(self.img_files) 131 | 132 | loader.dataset.__len__ = subclass_len 133 | return loader 134 | 135 | def evaluate_dataset(model_eval, val_loader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap, conf_thres=0.001 , iou_thres=0.65 ): 136 | return validate.run(data_dict, 137 | model=model_eval, 138 | imgsz=imgsz, 139 | single_cls=single_cls, 140 | half=True, 141 | task='val', 142 | verbose=True, 143 | conf_thres=conf_thres, 144 | iou_thres=iou_thres, 145 | save_dir=save_dir, 146 | save_json=is_coco, 147 | dataloader=val_loader, 148 | mask_downsample_ratio=mask_ratio, 149 | overlap=overlap 150 | )[0][:4] 151 | 152 | 153 | def export_onnx(model, file, im, opset=12, dynamic=False, prefix=colorstr('QAT ONNX:')): 154 | check_requirements('onnx') 155 | import onnx 156 | 157 | file = Path(file) 158 | LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...') 159 | 160 | f = file.with_suffix('.onnx') 161 | output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0'] 162 | model.eval() 163 | for k, m in model.named_modules(): 164 | # print(m) 165 | if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect)): 166 | m.inplace = True 167 | m.dynamic = dynamic 168 | m.export = True 169 | dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640) 170 | if isinstance(model, SegmentationModel): 171 | dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85) 172 | dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160) 173 | elif isinstance(model, DetectionModel): 174 | dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85) 175 | 176 | quantize.export_onnx(model, im, file, opset_version=13, 177 | input_names=["images"], output_names=output_names, 178 | dynamic_axes=dynamic or None 179 | ) 180 | 181 | for k, m in model.named_modules(): 182 | if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect)): 183 
| m.inplace = True 184 | m.dynamic = False 185 | m.export = False 186 | 187 | 188 | def run_quantize(weights, data, imgsz, batch_size, hyp, device, save_dir, supervision_stride, iters, no_eval_origin, no_eval_ptq, mask_ratio, overlap, prefix=colorstr('QAT:')): 189 | 190 | if not Path(weights).exists(): 191 | LOGGER.info(f'{prefix} Weight file not found "{weights}" ❌') 192 | exit(1) 193 | 194 | quantize.initialize() 195 | 196 | with torch_distributed_zero_first(LOCAL_RANK): 197 | data_dict = check_dataset(data) 198 | 199 | w = save_dir / 'weights' # weights dir 200 | w.mkdir(parents=True, exist_ok=True) # make dir 201 | 202 | is_coco = isinstance(data_dict.get('val'), str) and data_dict['val'].endswith(f'val2017.txt') # COCO dataset 203 | 204 | 205 | nc = int(data_dict['nc']) # number of classes 206 | single_cls = False if nc > 1 else True 207 | names = data_dict['names'] # class names 208 | assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data_dict) # check 209 | train_path = data_dict['train'] 210 | test_path = data_dict['val'] 211 | 212 | result_eval_origin=None 213 | result_eval_ptq=None 214 | result_eval_qat_best=None 215 | 216 | device = torch.device(device) 217 | model = load_model(weights, device) 218 | 219 | if not isinstance(model, SegmentationModel): 220 | model_name=model.__class__.__name__ 221 | LOGGER.info(f'{prefix} {model_name} model is not supported. Only DetectionModel is supported. ❌') 222 | exit(1) 223 | 224 | stride = max(int(model.stride.max()), 32) # grid size (max stride) 225 | imgsz = check_img_size(imgsz, s=stride) # check image size 226 | 227 | # conf onnx export 228 | exp_imgsz=[imgsz,imgsz] 229 | gs = int(max(model.stride)) # grid size (max stride) 230 | exp_imgsz = [check_img_size(x, gs) for x in exp_imgsz] # verify img_size are gs-multiples 231 | im = torch.zeros(batch_size, 3, *exp_imgsz).to(device) # image size(1,3,320,192) BCHW iDetection 232 | 233 | 234 | train_dataloader = create_train_dataloader(train_path, imgsz, batch_size, single_cls, stride, hyp, mask_ratio, overlap) 235 | val_dataloader = create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride, None, mask_ratio, overlap) 236 | 237 | ### This rule is disabled - This allow user disable qat per Layers ### 238 | # This rule has been disabled, but it remains in the code to maintain compatibility or future implementation. 239 | """ 240 | ignore_layer=-1 241 | if ignore_layer > -1: 242 | ignore_policy=f"model\.{ignore_layer}\.cv\d+\.\d+\.\d+(\.conv)?" 243 | else: 244 | ignore_policy=f"model\.9999999999\.cv\d+\.\d+\.\d+(\.conv)?" 
245 | """ 246 | ### End ####### 247 | 248 | quantize.replace_custom_module_forward(model) 249 | quantize.replace_to_quantization_module(model, ignore_policy="disabled") ## disabled because was not implemented 250 | quantize.apply_custom_rules_to_quantizer(model, lambda model, file: export_onnx(model, file, im)) 251 | quantize.calibrate_model(model, train_dataloader, device) 252 | 253 | report_file = os.path.join(save_dir, "report.json") 254 | report = ReportTool(report_file) 255 | 256 | if no_eval_origin: 257 | LOGGER.info(f'\n{prefix} Evaluating Origin...') 258 | model_eval = deepcopy(model).eval() 259 | with quantize.disable_quantization(model_eval): 260 | result_eval_origin = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap ) 261 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in result_eval_origin) 262 | LOGGER.info(f'\n{prefix} Eval Origin - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}') 263 | report.append(["Origin", str(weights), eval_map, eval_map50,eval_mp, eval_mr ]) 264 | 265 | if no_eval_ptq: 266 | 267 | LOGGER.info(f'\n{prefix} Evaluating PTQ...') 268 | model_eval = deepcopy(model).eval() 269 | 270 | result_eval_ptq = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap ) 271 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in result_eval_ptq) 272 | LOGGER.info(f'\n{prefix} Eval PTQ - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}') 273 | ptq_weights = w / f'ptq_ap_{eval_map}_{os.path.basename(weights)}' 274 | torch.save({"model": model_eval},f'{ptq_weights}') 275 | LOGGER.info(f'\n{prefix} PTQ, weights saved as {ptq_weights} ({file_size(ptq_weights):.1f} MB)') 276 | report.append(["PTQ", str(ptq_weights), eval_map, eval_map50,eval_mp, eval_mr ]) 277 | 278 | best_map = 0 279 | 280 | def per_epoch(model, epoch, lr): 281 | nonlocal best_map , result_eval_qat_best 282 | 283 | epoch +=1 284 | model_eval = deepcopy(model).eval() 285 | with torch.no_grad(): 286 | eval_result = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap ) 287 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result) 288 | qat_weights = w / f'qat_ep_{epoch}_ap_{eval_map}_{os.path.basename(weights)}' 289 | torch.save({"model": model_eval},f'{qat_weights}') 290 | LOGGER.info(f'\n{prefix} Epoch-{epoch}, weights saved as {qat_weights} ({file_size(qat_weights):.1f} MB)') 291 | report.append([f"QAT-{epoch}", str(qat_weights), eval_map, eval_map50,eval_mp, eval_mr ]) 292 | 293 | if eval_map > best_map: 294 | best_map = eval_map 295 | result_eval_qat_best=eval_result 296 | qat_weights = w / f'qat_best_{os.path.basename(weights)}' 297 | torch.save({"model": model_eval}, f'{qat_weights}') 298 | LOGGER.info(f'{prefix} QAT Best, weights saved as {qat_weights} ({file_size(qat_weights):.1f} MB)') 299 | report.update(["QAT-Best", str(qat_weights), eval_map, eval_map50,eval_mp, eval_mr ]) 300 | 301 | eval_results = [result_eval_origin, result_eval_ptq, result_eval_qat_best] 302 | 303 | LOGGER.info(f'\n\nEval Model | {"AP":<8} | {"AP50":<8} | {"Precision":<10} | {"Recall":<8}') 304 | LOGGER.info('-' * 55) 305 | for idx, eval_r in enumerate(eval_results): 306 | if eval_r is not None: 307 | eval_mp, eval_mr, eval_map50, eval_map = tuple(round(x, 3) for x in eval_r) 308 | if idx == 0: 309 | LOGGER.info(f'Origin | {eval_map:<8} | 
{eval_map50:<8} | {eval_mp:<10} | {eval_mr:<8}') 310 | if idx == 1: 311 | LOGGER.info(f'PTQ | {eval_map:<8} | {eval_map50:<8} | {eval_mp:<10} | {eval_mr:<8}') 312 | if idx == 2: 313 | LOGGER.info(f'QAT - Best | {eval_map:<8} | {eval_map50:<8} | {eval_mp:<10} | {eval_mr:<8}\n') 314 | 315 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result) 316 | LOGGER.info(f'\n{prefix} Eval - Epoch {epoch} | AP: {eval_map} | AP50: {eval_map50} | Precision: {eval_mp} | Recall: {eval_mr}\n') 317 | 318 | def preprocess(datas): 319 | return datas[0].to(device).float() / 255.0 320 | 321 | def supervision_policy(): 322 | supervision_list = [] 323 | for item in model.model: 324 | supervision_list.append(id(item)) 325 | 326 | keep_idx = list(range(0, len(model.model) - 1, supervision_stride)) 327 | keep_idx.append(len(model.model) - 2) 328 | def impl(name, module): 329 | if id(module) not in supervision_list: return False 330 | idx = supervision_list.index(id(module)) 331 | if idx in keep_idx: 332 | print(f"Supervision: {name} will compute loss with origin model during QAT training") 333 | else: 334 | print(f"Supervision: {name} will not compute loss with the origin model during QAT training; it is left unsupervised, which does not mean it stops learning") 335 | return idx in keep_idx 336 | return impl 337 | 338 | quantize.finetune( 339 | model, train_dataloader, per_epoch, early_exit_batchs_per_epoch=iters, 340 | preprocess=preprocess, supervision_policy=supervision_policy()) 341 | 342 | def run_sensitive_analysis(weights, device, data, imgsz, batch_size, hyp, save_dir, num_image, mask_ratio=4, overlap=True, prefix=colorstr('QAT ANALYSIS:')): # mask_ratio/overlap defaults added; evaluate_dataset requires them 343 | 344 | if not Path(weights).exists(): 345 | LOGGER.info(f'{prefix} Weight file not found "{weights}" ❌') 346 | exit(1) 347 | 348 | save_dir = Path(save_dir) 349 | # Create the directory if it doesn't exist 350 | save_dir.mkdir(parents=True, exist_ok=opt.exist_ok) 351 | 352 | with torch_distributed_zero_first(LOCAL_RANK): 353 | data_dict = check_dataset(data) 354 | 355 | is_coco=False 356 | 357 | nc = int(data_dict['nc']) # number of classes 358 | single_cls = False if nc > 1 else True 359 | names = data_dict['names'] # class names 360 | assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data_dict) # check 361 | train_path = data_dict['train'] 362 | test_path = data_dict['val'] 363 | 364 | device = torch.device(device) 365 | model = load_model(weights, device) 366 | 367 | if not isinstance(model, DetectionModel) or isinstance(model, SegmentationModel): 368 | LOGGER.info(f'{prefix} Model not supported. Only detection models are supported. ❌') 369 | exit(1) 370 | 371 | is_model_qat=False 372 | for i in range(0, len(model.model)): 373 | layer = model.model[i] 374 | if quantize.have_quantizer(layer): 375 | is_model_qat=True 376 | break 377 | 378 | if is_model_qat: 379 | LOGGER.info(f'{prefix} This model is already quantized. Only non-quantized models are allowed. 
❌') 380 | exit(1) 381 | 382 | stride = max(int(model.stride.max()), 32) # grid size (max stride) 383 | imgsz = check_img_size(imgsz, s=stride) # check image size 384 | 385 | exp_imgsz=[imgsz,imgsz] 386 | gs = int(max(model.stride)) # grid size (max stride) 387 | exp_imgsz = [check_img_size(x, gs) for x in exp_imgsz] # verify img_size are gs-multiples 388 | im = torch.zeros(batch_size, 3, *exp_imgsz).to(device) # image size(1,3,320,192) BCHW iDetection 389 | 390 | 391 | train_dataloader = create_train_dataloader(train_path, imgsz, batch_size, single_cls, stride, hyp) 392 | val_dataloader = create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride) 393 | quantize.initialize() 394 | quantize.replace_custom_module_forward(model) 395 | quantize.replace_to_quantization_module(model, ignore_policy="disabled") ## disabled because it was not implemented 396 | quantize.calibrate_model(model, train_dataloader, device) 397 | 398 | report_file=os.path.join(save_dir , "summary-sensitive-analysis.json") 399 | report = ReportTool(report_file) 400 | 401 | model_eval = deepcopy(model).eval() 402 | LOGGER.info(f'\n{prefix} Evaluating PTQ...') 403 | 404 | eval_result = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap ) 405 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result) 406 | 407 | LOGGER.info(f'\n{prefix} Eval PTQ - QAT enabled on All Layers - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}') 408 | report.append([eval_map, "PTQ"]) 409 | LOGGER.info(f'{prefix} Sensitive analysis by each layer. Layers Detected: {len(model.model)}') 410 | 411 | for i in range(0, len(model.model)): 412 | layer = model.model[i] 413 | if quantize.have_quantizer(layer): 414 | LOGGER.info(f'{prefix} QAT disabled on Layer model.{i}') 415 | quantize.disable_quantization(layer).apply() 416 | model_eval = deepcopy(model).eval() 417 | eval_result = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap ) 418 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result) 419 | LOGGER.info(f'\n{prefix} Eval PTQ - QAT disabled on Layer model.{i} - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}\n') 420 | report.append([eval_map, f"model.{i}"]) 421 | quantize.enable_quantization(layer).apply() 422 | else: 423 | LOGGER.info(f'{prefix} Ignored Layer model.{i} because it is {type(layer)}') 424 | 425 | report = sorted(report.data, key=lambda x:x[0], reverse=True) 426 | print("Sensitive summary:") 427 | for n, (ap, name) in enumerate(report[:10]): 428 | print(f"Top{n}: Using fp16 {name}, ap = {ap:.5f}") 429 | 430 | 431 | def run_eval(weights, device, data, imgsz, batch_size, save_dir, conf_thres, iou_thres, mask_ratio=4, overlap=True, prefix=colorstr('QAT TEST:')): # mask_ratio/overlap defaults added; evaluate_dataset requires them 432 | 433 | if not Path(weights).exists(): 434 | LOGGER.info(f'{prefix} Weight file not found "{weights}" ❌') 435 | exit(1) 436 | 437 | quantize.initialize() 438 | 439 | save_dir = Path(save_dir) 440 | # Create the directory if it doesn't exist 441 | save_dir.mkdir(parents=True, exist_ok=opt.exist_ok) 442 | 443 | with torch_distributed_zero_first(LOCAL_RANK): 444 | data_dict = check_dataset(data) 445 | 446 | 447 | device = torch.device(device) 448 | model = load_model(weights, device) 449 | 450 | if not isinstance(model, DetectionModel): 451 | model_name=model.__class__.__name__ 452 | LOGGER.info(f'{prefix} {model_name} model is not supported. Only DetectionModel is supported. 
❌') 453 | exit(1) 454 | 455 | is_model_qat=False 456 | for i in range(0, len(model.model)): 457 | layer = model.model[i] 458 | if quantize.have_quantizer(layer): 459 | is_model_qat=True 460 | break 461 | 462 | if not is_model_qat: 463 | LOGGER.info(f'{prefix} This model was not quantized. ❌') 464 | exit(1) 465 | 466 | is_coco = isinstance(data_dict.get('val'), str) and data_dict['val'].endswith(f'val2017.txt') # COCO dataset 467 | 468 | stride = max(int(model.stride.max()), 32) # grid size (max stride) 469 | imgsz = check_img_size(imgsz, s=stride) # check image size 470 | 471 | nc = int(data_dict['nc']) # number of classes 472 | single_cls = False if nc > 1 else True 473 | names = data_dict['names'] # class names 474 | assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data_dict) # check 475 | 476 | test_path = data_dict['val'] 477 | val_dataloader = create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride) 478 | 479 | 480 | LOGGER.info(f'\n{prefix} Evaluating ...') 481 | model_eval = deepcopy(model).eval() 482 | 483 | result_eval = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap ,conf_thres=conf_thres, iou_thres=iou_thres ) 484 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in result_eval) 485 | LOGGER.info(f'\n{prefix} Eval Result - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}') 486 | LOGGER.info(f'\n{prefix} Eval Result saved to {save_dir}') 487 | 488 | 489 | 490 | if __name__ == "__main__": 491 | parser = argparse.ArgumentParser(prog='qat.py') 492 | subps = parser.add_subparsers(dest="cmd") 493 | qat = subps.add_parser("quantize", help="PTQ/QAT finetune ...") 494 | 495 | qat.add_argument('--weights', type=str, default=ROOT / 'runs/models_original/yolov9-c.pt', help='weights path') 496 | qat.add_argument('--data', type=str, default=ROOT / 'data/coco.yaml', help='dataset.yaml path') 497 | qat.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-high.yaml', help='hyperparameters path') 498 | qat.add_argument("--device", type=str, default="cuda:0", help="device") 499 | qat.add_argument('--batch-size', type=int, default=10, help='total batch size') 500 | qat.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)') 501 | qat.add_argument('--project', default=ROOT / 'runs/qat', help='save to project/name') 502 | qat.add_argument('--name', default='exp', help='save to project/name') 503 | qat.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') 504 | qat.add_argument("--iters", type=int, default=200, help="iters per epoch") 505 | qat.add_argument('--seed', type=int, default=57, help='Global training seed') 506 | qat.add_argument("--supervision-stride", type=int, default=1, help="supervision stride") 507 | qat.add_argument("--no-eval-origin", action="store_false", help="Disable eval for origin model") 508 | qat.add_argument("--no-eval-ptq", action="store_false", help="Disable eval for ptq model") 509 | qat.add_argument('--mask-ratio', type=int, default=4, help='Downsample the truth masks to save memory') 510 | qat.add_argument('--no-overlap', action='store_true', help='Disable overlapping masks (overlapping masks train faster at slightly lower mAP)') 511 | 512 | sensitive = subps.add_parser("sensitive", help="Sensitive layer analysis") 513 | sensitive.add_argument('--weights', type=str, default=ROOT / 'runs/models_original/yolov9-c.pt', help='Weights 
path (.pt)') 514 | sensitive.add_argument("--device", type=str, default="cuda:0", help="device") 515 | sensitive.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path') 516 | sensitive.add_argument('--batch-size', type=int, default=10, help='total batch size') 517 | sensitive.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)') 518 | sensitive.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch-high.yaml', help='hyperparameters path') 519 | sensitive.add_argument('--project', default=ROOT / 'runs/qat_sensitive', help='save to project/name') 520 | sensitive.add_argument('--name', default='exp', help='save to project/name') 521 | sensitive.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') 522 | sensitive.add_argument("--num-image", type=int, default=None, help="number of images to evaluate") 523 | 524 | testcmd = subps.add_parser("eval", help="Run evaluation") 525 | testcmd.add_argument('--weights', type=str, default=ROOT / 'runs/models_original/yolov9-c.pt', help='Weights path (.pt)') 526 | testcmd.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path') 527 | testcmd.add_argument('--batch-size', type=int, default=10, help='total batch size') 528 | testcmd.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='val image size (pixels)') 529 | testcmd.add_argument("--device", type=str, default="cuda:0", help="device") 530 | testcmd.add_argument("--conf-thres", type=float, default=0.001, help="confidence threshold") 531 | testcmd.add_argument("--iou-thres", type=float, default=0.7, help="NMS IoU threshold") 532 | testcmd.add_argument('--project', default=ROOT / 'runs/qat_eval', help='save to project/name') 533 | testcmd.add_argument('--name', default='exp', help='save to project/name') 534 | testcmd.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') 535 | 536 | opt = parser.parse_args() 537 | if opt.cmd == "quantize": 538 | print(opt) 539 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) 540 | init_seeds(opt.seed + 1 + RANK, deterministic=False) 541 | 542 | run_quantize( 543 | opt.weights, opt.data, opt.imgsz, opt.batch_size, 544 | opt.hyp, opt.device, Path(opt.save_dir), 545 | opt.supervision_stride, opt.iters, 546 | opt.no_eval_origin, opt.no_eval_ptq, opt.mask_ratio, not opt.no_overlap # overlap masks unless --no-overlap is passed 547 | ) 548 | 549 | elif opt.cmd == "sensitive": 550 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) 551 | print(opt) 552 | run_sensitive_analysis(opt.weights, opt.device, opt.data, 553 | opt.imgsz, opt.batch_size, opt.hyp, 554 | opt.save_dir, opt.num_image 555 | ) 556 | elif opt.cmd == "eval": 557 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok)) 558 | print(opt) 559 | run_eval(opt.weights, opt.device, opt.data, 560 | opt.imgsz, opt.batch_size, opt.save_dir, 561 | opt.conf_thres, opt.iou_thres 562 | ) 563 | else: 564 | parser.print_help() -------------------------------------------------------------------------------- /val_trt.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch 9 | from tqdm import tqdm 10 | 11 | FILE = Path(__file__).resolve() 12 | ROOT = FILE.parents[0] # YOLO root directory 13 | if 
str(ROOT) not in sys.path: 14 | sys.path.append(str(ROOT)) # add ROOT to PATH 15 | ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative 16 | 17 | from utils.callbacks import Callbacks 18 | from utils.dataloaders import create_dataloader 19 | from utils.general import (LOGGER, TQDM_BAR_FORMAT, Profile, check_dataset, check_requirements, 20 | check_yaml, coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, 21 | print_args, scale_boxes, xywh2xyxy, xyxy2xywh) 22 | from utils.metrics import ConfusionMatrix, ap_per_class, box_iou 23 | from utils.plots import output_to_target, plot_images, plot_val_study 24 | from utils.torch_utils import smart_inference_mode 25 | 26 | 27 | import pycuda.autoinit 28 | import pycuda.driver as cuda 29 | import tensorrt as trt 30 | 31 | class HostDeviceMem(object): 32 | def __init__(self, host_mem, device_mem): 33 | self.host = host_mem 34 | self.device = device_mem 35 | 36 | def __str__(self): 37 | return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) 38 | 39 | def __repr__(self): 40 | return self.__str__() 41 | 42 | ## https://docs.nvidia.com/deeplearning/tensorrt/migration-guide/index.html 43 | 44 | def do_inference(context, inputs, outputs, stream): 45 | # Transfer input data to the GPU. 46 | [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] 47 | 48 | # Run inference. 49 | context.execute_async_v3(stream_handle=stream.handle) 50 | 51 | # Transfer predictions back from the GPU. 52 | [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs] 53 | 54 | # Synchronize the stream 55 | stream.synchronize() 56 | 57 | # Return only the host outputs. 58 | return [out.host for out in outputs] 59 | 60 | 61 | def allocate_buffers(engine): 62 | ''' 63 | Allocates all buffers required for an engine, i.e. host/device inputs/outputs. 64 | ''' 65 | inputs = [] 66 | outputs = [] 67 | bindings = [] 68 | stream = cuda.Stream() 69 | 70 | for i in range(engine.num_io_tensors): 71 | tensor_name = engine.get_tensor_name(i) 72 | size = trt.volume(engine.get_tensor_shape(tensor_name)) 73 | dtype = trt.nptype(engine.get_tensor_dtype(tensor_name)) 74 | 75 | # Allocate host and device buffers 76 | host_mem = cuda.pagelocked_empty(size, dtype) # page-locked memory buffer (won't swapped to disk) 77 | device_mem = cuda.mem_alloc(host_mem.nbytes) 78 | 79 | # Append the device buffer address to device bindings. 80 | # When cast to int, it's a linear index into the context's memory (like memory address). 81 | bindings.append(int(device_mem)) 82 | 83 | # Append to the appropriate input/output list. 
84 | if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT: 85 | inputs.append(HostDeviceMem(host_mem, device_mem)) 86 | else: 87 | outputs.append(HostDeviceMem(host_mem, device_mem)) 88 | 89 | return inputs, outputs, bindings, stream 90 | 91 | def save_one_txt(predn, save_conf, shape, file): 92 | # Save one txt result 93 | gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh 94 | for *xyxy, conf, cls in predn.tolist(): 95 | xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh 96 | line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format 97 | with open(file, 'a') as f: 98 | f.write(('%g ' * len(line)).rstrip() % line + '\n') 99 | 100 | 101 | def save_one_json(predn, jdict, path, class_map): 102 | # Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236} 103 | image_id = int(path.stem) if path.stem.isnumeric() else path.stem 104 | box = xyxy2xywh(predn[:, :4]) # xywh 105 | box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner 106 | for p, b in zip(predn.tolist(), box.tolist()): 107 | jdict.append({ 108 | 'image_id': image_id, 109 | 'category_id': class_map[int(p[5])], 110 | 'bbox': [round(x, 3) for x in b], 111 | 'score': round(p[4], 5)}) 112 | 113 | 114 | def process_batch(detections, labels, iouv): 115 | """ 116 | Return correct prediction matrix 117 | Arguments: 118 | detections (array[N, 6]), x1, y1, x2, y2, conf, class 119 | labels (array[M, 5]), class, x1, y1, x2, y2 120 | Returns: 121 | correct (array[N, 10]), for 10 IoU levels 122 | """ 123 | correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool) 124 | iou = box_iou(labels[:, 1:], detections[:, :4]) 125 | correct_class = labels[:, 0:1] == detections[:, 5] 126 | for i in range(len(iouv)): 127 | x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match 128 | if x[0].shape[0]: 129 | matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detect, iou] 130 | if x[0].shape[0] > 1: 131 | matches = matches[matches[:, 2].argsort()[::-1]] 132 | matches = matches[np.unique(matches[:, 1], return_index=True)[1]] 133 | # matches = matches[matches[:, 2].argsort()[::-1]] 134 | matches = matches[np.unique(matches[:, 0], return_index=True)[1]] 135 | correct[matches[:, 1].astype(int), i] = True 136 | return torch.tensor(correct, dtype=torch.bool, device=iouv.device) 137 | 138 | 139 | @smart_inference_mode() 140 | def run( 141 | data, 142 | engine_file=None, # model.pt path(s) 143 | batch_size=1, # batch size 144 | imgsz=640, # inference size (pixels) 145 | conf_thres=0.001, # confidence threshold 146 | iou_thres=0.7, # NMS IoU threshold 147 | max_det=300, # maximum detections per image 148 | task='val', # train, val, test, speed or study 149 | device='', # cuda device, i.e. 
0 or 0,1,2,3 or cpu 150 | workers=8, # max dataloader workers (per RANK in DDP mode) 151 | single_cls=False, # treat as single-class dataset 152 | augment=False, # augmented inference 153 | verbose=False, # verbose output 154 | save_txt=False, # save results to *.txt 155 | save_hybrid=False, # save label+prediction hybrid results to *.txt 156 | save_conf=False, # save confidences in --save-txt labels 157 | save_json=False, # save a COCO-JSON results file 158 | project=ROOT / 'runs/val', # save to project/name 159 | name='exp', # save to project/name 160 | exist_ok=False, # existing project/name ok, do not increment 161 | half=True, # use FP16 half-precision inference 162 | dnn=False, # use OpenCV DNN for ONNX inference 163 | min_items=0, # Experimental 164 | model=None, 165 | dataloader=None, 166 | save_dir=Path(''), 167 | plots=True, 168 | callbacks=Callbacks(), 169 | compute_loss=None, 170 | prefix=colorstr('VAL-TRT:') 171 | ): 172 | stride=32 173 | logger = trt.Logger(trt.Logger.INFO) 174 | runtime = trt.Runtime(logger) 175 | engine = runtime.deserialize_cuda_engine(open(engine_file, 'rb').read()) 176 | inputs, outputs, bindings, stream = allocate_buffers(engine) 177 | 178 | inputshape = [engine.get_tensor_shape(binding) for binding in engine][0] 179 | imgsz = inputshape[3] - stride ## derive imgsz from the engine input, leaving one stride of margin for the rect dataloader padding 180 | outputshape = [engine.get_tensor_shape(binding) for binding in engine][1] 181 | 182 | context = engine.create_execution_context() 183 | 184 | for i in range(engine.num_io_tensors): 185 | context.set_tensor_address(engine.get_tensor_name(i), bindings[i]) 186 | 187 | # Initialize/load model and set device 188 | training = False 189 | device = torch.device("cpu") 190 | gs = 32 191 | 192 | # Directories 193 | save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run 194 | (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir 195 | data = check_dataset(data) # check 196 | # Configure 197 | is_coco = isinstance(data.get('val'), str) and data['val'].endswith(f'val2017.txt') # COCO dataset 198 | nc = 1 if single_cls else int(data['nc']) # number of classes 199 | iouv = torch.linspace(0.5, 0.95, 10, device=device) # iou vector for mAP@0.5:0.95 200 | niou = iouv.numel() 201 | 202 | # Dataloader 203 | if not training: 204 | task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images 205 | dataloader = create_dataloader(data[task], 206 | imgsz, 207 | batch_size, 208 | stride, 209 | single_cls, 210 | pad=0.5, 211 | rect=True, 212 | workers=workers, 213 | min_items=0, 214 | prefix=colorstr(f'{task}: '))[0] 215 | 216 | seen = 0 217 | confusion_matrix = ConfusionMatrix(nc=nc) 218 | #names = model.names if hasattr(model, 'names') else model.module.names # get class names 219 | names=data['names'] 220 | 221 | if isinstance(names, (list, tuple)): # old format 222 | names = dict(enumerate(names)) 223 | class_map = coco80_to_coco91_class() if is_coco else list(range(1000)) 224 | s = ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'P', 'R', 'mAP50', 'mAP50-95') 225 | tp, fp, p, r, f1, mp, mr, map50, ap50, map = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 226 | dt = Profile(), Profile(), Profile() # profiling times 227 | loss = torch.zeros(3, device=device) 228 | jdict, stats, ap, ap_class = [], [], [], [] 229 | callbacks.run('on_val_start') 230 | pbar = tqdm(dataloader, desc=s, bar_format=TQDM_BAR_FORMAT) # progress bar 231 | 232 | for batch_i, (im, targets, paths, shapes) in enumerate(pbar): 
233 | callbacks.run('on_val_batch_start') 234 | with dt[0]: 235 | im = (im / 255.0).float() # 0 - 255 to 0.0 - 1.0 236 | # dataloader is setup pad=0.5 237 | input_image = torch.full((inputshape[1], inputshape[2], inputshape[3]), 114 / 255.0, dtype=torch.float32) 238 | im_cp=im.squeeze(0) 239 | input_image[:, :im_cp.size(1), :im_cp.size(2)] = im_cp 240 | targets = targets.to(device) 241 | nb, _, height, width = im.shape # batch size, channels, height, width 242 | # Run model 243 | with dt[1]: 244 | inputs[0].host = input_image.data.numpy() 245 | trt_outputs = do_inference(context, inputs=inputs, outputs=outputs, stream = stream) 246 | preds = torch.Tensor(trt_outputs[0].reshape(outputshape)) 247 | 248 | 249 | # Run NMS 250 | targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels 251 | lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling 252 | 253 | with dt[2]: 254 | preds = non_max_suppression(preds, 255 | conf_thres, 256 | iou_thres, 257 | labels=lb, 258 | multi_label=True, 259 | agnostic=single_cls, 260 | max_det=max_det) 261 | # Metrics 262 | for si, pred in enumerate(preds): 263 | labels = targets[targets[:, 0] == si, 1:] 264 | nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions 265 | path, shape = Path(paths[si]), shapes[si][0] 266 | correct = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init 267 | seen += 1 268 | 269 | if npr == 0: 270 | if nl: 271 | stats.append((correct, *torch.zeros((2, 0), device=device), labels[:, 0])) 272 | if plots: 273 | confusion_matrix.process_batch(detections=None, labels=labels[:, 0]) 274 | continue 275 | 276 | # Predictions 277 | if single_cls: 278 | pred[:, 5] = 0 279 | predn = pred.clone() 280 | scale_boxes(im[si].shape[1:], predn[:, :4], shape, shapes[si][1]) # native-space pred 281 | 282 | # Evaluate 283 | if nl: 284 | tbox = xywh2xyxy(labels[:, 1:5]) # target boxes 285 | scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1]) # native-space labels 286 | labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels 287 | correct = process_batch(predn, labelsn, iouv) 288 | if plots: 289 | confusion_matrix.process_batch(predn, labelsn) 290 | stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0])) # (correct, conf, pcls, tcls) 291 | 292 | # Save/log 293 | if save_txt: 294 | save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt') 295 | if save_json: 296 | save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary 297 | callbacks.run('on_val_image_end', pred, predn, path, names, im[si]) 298 | 299 | # Plot images 300 | if plots and batch_i < 3: 301 | plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) # labels 302 | plot_images(im, output_to_target(preds), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred 303 | 304 | callbacks.run('on_val_batch_end', batch_i, im, targets, paths, shapes, preds) 305 | 306 | # Compute metrics 307 | stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy 308 | if len(stats) and stats[0].any(): 309 | tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names) 310 | ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95 311 | mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean() 312 | nt = np.bincount(stats[3].astype(int), minlength=nc) # number of targets per class 313 | 314 | # Print results 315 | pf = '%22s' + '%11i' * 2 + '%11.3g' * 4 # print 
format 316 | LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map)) 317 | if nt.sum() == 0: 318 | LOGGER.warning(f'WARNING ⚠️ no labels found in {task} set, can not compute metrics without labels') 319 | 320 | # Print results per class 321 | if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats): 322 | for i, c in enumerate(ap_class): 323 | LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i])) 324 | 325 | # Print speeds 326 | t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image 327 | if not training: 328 | shape = (batch_size, 3, imgsz, imgsz) 329 | LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t) 330 | 331 | # Plots 332 | if plots: 333 | confusion_matrix.plot(save_dir=save_dir, names=list(names.values())) 334 | callbacks.run('on_val_end', nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix) 335 | 336 | # Save JSON 337 | if save_json and len(jdict): 338 | w = Path(engine_file[0] if isinstance(engine_file, list) else engine_file).stem if engine_file is not None else '' # weights 339 | anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json') # annotations json 340 | pred_json = str(save_dir / f"{w}_predictions.json") # predictions json 341 | LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...') 342 | with open(pred_json, 'w') as f: 343 | json.dump(jdict, f) 344 | 345 | try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb 346 | check_requirements('pycocotools') 347 | from pycocotools.coco import COCO 348 | from pycocotools.cocoeval import COCOeval 349 | 350 | anno = COCO(anno_json) # init annotations api 351 | pred = anno.loadRes(pred_json) # init predictions api 352 | eval = COCOeval(anno, pred, 'bbox') 353 | if is_coco: 354 | eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.im_files] # image IDs to evaluate 355 | eval.evaluate() 356 | eval.accumulate() 357 | eval.summarize() 358 | map, map50 = eval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5) 359 | except Exception as e: 360 | LOGGER.info(f'pycocotools unable to run: {e}') 361 | 362 | # Return results 363 | if not training: 364 | s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' 365 | LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}") 366 | maps = np.zeros(nc) + map 367 | for i, c in enumerate(ap_class): 368 | maps[c] = ap[i] 369 | 370 | eval_mp, eval_mr, eval_map50, eval_map = round(mp, 4), round(mr, 4), round(map50, 4), round(map, 4) 371 | LOGGER.info(f'\n{prefix} Eval TRT - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}') 372 | 373 | return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t 374 | 375 | 376 | def parse_opt(): 377 | parser = argparse.ArgumentParser() 378 | parser.add_argument('--data', type=str, default=ROOT / 'data/coco.yaml', help='dataset.yaml path') 379 | parser.add_argument('--engine-file', type=str, default=ROOT / 'yolo.engine', help='model path(s)') 380 | parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold') 381 | parser.add_argument('--iou-thres', type=float, default=0.7, help='NMS IoU threshold') 382 | parser.add_argument('--max-det', type=int, default=300, help='maximum detections per image') 383 | parser.add_argument('--task', default='val', help='train, val, test, speed or study') 384 | parser.add_argument('--device', default='0', help='cuda 
device, i.e. 0 or 0,1,2,3 ') 385 | parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)') 386 | parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset') 387 | parser.add_argument('--verbose', action='store_true', help='report mAP by class') 388 | parser.add_argument('--save-txt', action='store_true', help='save results to *.txt') 389 | parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt') 390 | parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels') 391 | parser.add_argument('--save-json', action='store_true', help='save a COCO-JSON results file') 392 | parser.add_argument('--project', default=ROOT / 'runs/val_trt', help='save to project/name') 393 | parser.add_argument('--name', default='exp', help='save to project/name') 394 | parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment') 395 | opt = parser.parse_args() 396 | opt.data = check_yaml(opt.data) # check YAML 397 | opt.save_json |= opt.data.endswith('coco.yaml') 398 | opt.save_txt |= opt.save_hybrid 399 | print_args(vars(opt)) 400 | return opt 401 | 402 | 403 | def main(opt): 404 | #check_requirements(exclude=('tensorboard', 'thop')) 405 | 406 | if opt.task in ('train', 'val', 'test'): # run normally 407 | if opt.conf_thres > 0.001: # https://github.com/ultralytics/yolov5/issues/1466 408 | LOGGER.info(f'WARNING ⚠️ confidence threshold {opt.conf_thres} > 0.001 produces invalid results') 409 | if opt.save_hybrid: 410 | LOGGER.info('WARNING ⚠️ --save-hybrid will return high mAP from hybrid labels, not from predictions alone') 411 | run(**vars(opt)) 412 | 413 | else: 414 | engine_file = opt.engine_file if isinstance(opt.engine_file, list) else [opt.engine_file] 415 | opt.half = torch.cuda.is_available() and opt.device != 'cpu' # FP16 for fastest results 416 | if opt.task == 'speed': # speed benchmarks 417 | # python val.py --task speed --data coco.yaml --batch 1 --weights yolo.pt... 418 | opt.conf_thres, opt.iou_thres, opt.save_json = 0.25, 0.45, False 419 | for opt.engine_file in engine_file: 420 | run(**vars(opt), plots=False) 421 | 422 | elif opt.task == 'study': # speed vs mAP benchmarks 423 | # python val.py --task study --data coco.yaml --iou 0.7 --weights yolo.pt... 424 | for opt.engine_file in engine_file: 425 | f = f'study_{Path(opt.data).stem}_{Path(opt.engine_file).stem}.txt' # filename to save to 426 | x, y = list(range(256, 1536 + 128, 128)), [] # x axis (image sizes), y axis 427 | for opt.imgsz in x: # img-size 428 | LOGGER.info(f'\nRunning {f} --imgsz {opt.imgsz}...') 429 | r, _, t = run(**vars(opt), plots=False) 430 | y.append(r + t) # results and times 431 | np.savetxt(f, y, fmt='%10.4g') # save 432 | os.system('zip -r study.zip study_*.txt') 433 | plot_val_study(x=x) # plot 434 | 435 | 436 | if __name__ == "__main__": 437 | opt = parse_opt() 438 | main(opt) 439 | --------------------------------------------------------------------------------