├── LICENSE
├── README.md
├── draw-engine.py
├── export_qat.py
├── install_dependencies.sh
├── models
│   ├── experimental_trt.py
│   ├── quantize.py
│   └── quantize_rules.py
├── patch_yolov9.sh
├── qat.py
├── scripts
│   ├── generate_trt_engine.sh
│   └── val_trt.sh
├── segment
│   └── qat_seg.py
└── val_trt.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # YOLOv9 QAT for TensorRT 10.9 Detection / Segmentation
2 |
3 | This repository contains an implementation of YOLOv9 with Quantization-Aware Training (QAT), specifically designed for deployment on platforms utilizing TensorRT for hardware-accelerated inference.
4 | This implementation aims to provide an efficient, low-latency version of YOLOv9 for real-time detection applications.
5 | If you do not intend to deploy your model using TensorRT, it is recommended not to proceed with this implementation.
6 |
7 | - The files in this repository represent a patch that adds QAT functionality to the original [YOLOv9 repository](https://github.com/WongKinYiu/yolov9/).
8 | - This patch is intended to be applied to the main YOLOv9 repository to incorporate the ability to train with QAT.
9 | - The implementation is optimized to work efficiently with TensorRT, an inference library that leverages hardware acceleration to enhance inference performance.
10 | - Users interested in implementing object detection using YOLOv9 with QAT on TensorRT platforms can benefit from this repository as it provides a ready-to-use solution.
11 |
12 |
 13 | We use [TensorRT's PyTorch quantization tool](https://github.com/NVIDIA/TensorRT/tree/main/tools/pytorch-quantization) to fine-tune YOLOv9 with QAT starting from the pre-trained weights, then export the model to ONNX and deploy it with TensorRT. Accuracy and performance can be found in the tables below.
14 |
15 | For those who are not familiar with QAT, I highly recommend watching this video:
[Quantization explained with PyTorch - Post-Training Quantization, Quantization-Aware Training](https://www.youtube.com/watch?v=0VdNflU08yA)
16 |
17 | **Important**
18 | Evaluation of the segmentation model using TensorRT is currently under development. Once I have more available time, I will complete and release this work.
19 |
 20 | 🌟 There are still plenty of Q/DQ node placements to improve, and we rely on the community's contributions to enhance this project, benefiting us all. Let's collaborate and make it even better! 🚀
21 |
22 | ## Release Highlights
 23 | - This release includes an upgrade from TensorRT 8 to TensorRT 10, ensuring compatibility with the CUDA versions supported by the latest NVIDIA Ada Lovelace GPUs.
 24 | - Inference has been upgraded to use `enqueueV3` instead of `enqueueV2`.
25 | - To maintain legacy support for TensorRT 8, a [dedicated branch](https://github.com/levipereira/yolov9-qat/tree/TensorRT-8) has been created. **Outdated**
 26 | - We've added a new option, `val_trt.sh --generate-graph`, which enables [Graph Rendering](#generate-tensorrt-profiling-and-svg-image) functionality. This feature facilitates the creation of graphical representations of the engine plan in SVG image format.
27 |
28 |
 29 | # Performance / Accuracy
30 | [Full Report](#benchmark)
31 |
32 |
33 | ## Accuracy Report
34 |
35 | **YOLOv9-C**
36 |
37 | ### Evaluation Results
38 |
39 | ## Detection
40 | #### Activation SiLU
41 |
42 | | Eval Model | AP | AP50 | Precision | Recall |
43 | |------------|--------|--------|-----------|--------|
44 | | **Origin (Pytorch)** | 0.529 | 0.699 | 0.743 | 0.634 |
45 | | **INT8 (Pytorch)** | 0.529 | 0.702 | 0.742 | 0.63 |
46 | | **INT8 (TensorRT)** | 0.529 | 0.696 | 0.739 | 0.635 |
47 |
48 |
49 | #### Activation ReLU
50 |
51 | | Eval Model | AP | AP50 | Precision | Recall |
52 | |------------|--------|--------|-----------|--------|
53 | | **Origin (Pytorch)** | 0.519 | 0.69 | 0.719 | 0.629 |
54 | | **INT8 (Pytorch)** | 0.518 | 0.69 | 0.726 | 0.625 |
55 | | **INT8 (TensorRT)** | 0.517 | 0.685 | 0.723 | 0.626 |
56 |
57 | ### Evaluation Comparison
58 |
59 | #### Activation SiLU
60 | | Eval Model | AP | AP50 | Precision | Recall |
61 | |----------------------|------|------|-----------|--------|
 62 | | **INT8 (TensorRT)** vs **Origin (Pytorch)** | 0.000 | -0.003 | -0.004 | +0.001 |
 63 |
64 |
65 | #### Activation ReLU
66 | | Eval Model | AP | AP50 | Precision | Recall |
67 | |----------------------|------|------|-----------|--------|
 68 | | **INT8 (TensorRT)** vs **Origin (Pytorch)** | -0.002 | -0.005 | +0.004 | -0.003 |
 69 |
70 |
71 | ## Segmentation
 72 | | Model | Box P | Box R | Box mAP50 | Box mAP50-95 | Mask P | Mask R | Mask mAP50 | Mask mAP50-95 |
 73 | |--------|-------|-------|-----------|--------------|--------|--------|------------|---------------|
 74 | | Origin | 0.729 | 0.632 | 0.691 | 0.521 | 0.717 | 0.611 | 0.657 | 0.423 |
 75 | | PTQ | 0.729 | 0.626 | 0.688 | 0.520 | 0.717 | 0.604 | 0.654 | 0.421 |
 76 | | QAT | 0.725 | 0.631 | 0.689 | 0.521 | 0.714 | 0.609 | 0.655 | 0.421 |
 77 |
78 |
79 |
80 | ## Latency/Throughput Report - TensorRT
81 |
82 | 
83 |
84 | ## Device
85 | | **GPU** | |
86 | |---------------------------|------------------------------|
87 | | Device | **NVIDIA GeForce RTX 4090** |
88 | | Compute Capability | 8.9 |
89 | | SMs | 128 |
90 | | Device Global Memory | 24207 MiB |
91 | | Application Compute Clock Rate | 2.58 GHz |
92 | | Application Memory Clock Rate | 10.501 GHz |
93 |
94 |
95 | ### Latency/Throughput
96 |
97 | | Model Name | Batch Size | Latency (99%) | Throughput (qps) | Total Inferences (IPS) |
98 | |-----------------|------------|----------------|------------------|------------------------|
99 | | **(FP16) SiLU** | 1 | 1.25 ms | 803 | 803 |
100 | | | 4 | 3.37 ms | 300 | 1200 |
101 | | | 8 | 6.6 ms | 153 | 1224 |
102 | | | 12 | 10 ms | 99 | 1188 |
103 | | | | | | |
104 | | **INT8 (SiLU)** | 1 | 0.97 ms | 1030 | 1030 |
105 | | | 4 | 2.06 ms | 486 | 1944 |
106 | | | 8 | 3.69 ms | 271 | 2168 |
107 | | | 12 | 5.36 ms | 189 | 2268 |
108 | | | | | | |
109 | | **INT8 (ReLU)** | 1 | 0.87 ms | 1150 | 1150 |
110 | | | 4 | 1.78 ms | 562 | 2248 |
111 | | | 8 | 3.06 ms | 327 | 2616 |
112 | | | 12 | 4.63 ms | 217 | 2604 |
113 |
114 | ## Latency/Throughput Comparison (INT8 vs FP16)
115 |
116 | | Model Name | Batch Size | Latency (99%) Change | Throughput (qps) Change | Total Inferences (IPS) Change |
117 | |---|---|---|---|---|
118 | | **INT8(SiLU)** vs **FP16** | 1 | -20.8% | +28.4% | +28.4% |
119 | | | 4 | -37.1% | +62.0% | +62.0% |
120 | | | 8 | -41.1% | +77.0% | +77.0% |
121 | | | 12 | -46.9% | +90.9% | +90.9% |
122 |
123 |
124 | ## QAT Training (Finetune)
125 |
126 | In this section, we'll outline the steps to perform Quantization-Aware Training (QAT) using fine-tuning.
**Please note that the supported quantization mode is fine-tuning only.**
The model should be trained using the original implementation train.py, and after training and reparameterization of the model, the user should proceed with quantization.
127 |
128 | ### Steps:
129 |
130 | 1. **Train the Model Using [Training Session](https://github.com/WongKinYiu/yolov9/tree/main?tab=readme-ov-file#training):**
131 | - Utilize the original implementation train.py to train your YOLOv9 model with your dataset and desired configurations.
132 | - Follow the training instructions provided in the original YOLOv9 repository to ensure proper training.
133 |
134 | 2. **Reparameterize the Model [reparameterization.py](https://github.com/sunmooncode/yolov9/blob/main/tools/reparameterization.py):**
135 | - After completing the training, reparameterize the trained model to prepare it for quantization. This step is crucial for ensuring that the model's weights are in a suitable format for quantization.
136 |
137 | 3. **[Proceed with Quantization](#quantize-model):**
138 | - Once the model is reparameterized, proceed with the quantization process. This involves applying the Quantization-Aware Training technique to fine-tune the model's weights, taking into account the quantization effects.
139 |
140 | 4. **[Eval Pytorch](#evaluate-using-pytorch) / [Eval TensorRT](#evaluate-using-tensorrt):**
141 | - After quantization, it's crucial to validate the performance of the quantized model to ensure that it meets your requirements in terms of accuracy and efficiency.
142 | - Test the quantized model thoroughly at both stages: during the quantization phase using PyTorch and after training using TensorRT.
143 | - Please note that different versions of TensorRT may yield varying results and performance.
144 |
145 | 5. **Export to ONNX:**
146 | - [Export ONNX](#export-onnx)
147 | - Once you are satisfied with the quantized model's performance, you can proceed to export it to ONNX format.
148 |
149 | 6. **Deploy with TensorRT:**
150 | - [Deployment with TensorRT](#deployment-with-tensorrt)
151 | - After exporting to ONNX, you can deploy the model using TensorRT for hardware-accelerated inference on platforms supporting TensorRT.
152 |
153 |
154 |
155 |
156 | By following these steps, you can successfully perform Quantization-Aware Training (QAT) using fine-tuning with your YOLOv9 model.
157 |
158 | ## How to Install and Training
159 | We suggest using a Docker environment.
160 | NVIDIA PyTorch image (`nvcr.io/nvidia/pytorch:24.10-py3`)
161 |
162 | Release 24.10 is based on Ubuntu 22.04 and includes Python 3.10 and CUDA 12.6.2, which requires NVIDIA driver release 560 or later. If you are running on a data center GPU, check the docs:
163 | https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-10.html
164 |
165 | ## Installation
166 | ```bash
167 |
168 | docker pull nvcr.io/nvidia/pytorch:24.10-py3
169 |
170 | ## clone original yolov9
171 | git clone https://github.com/WongKinYiu/yolov9.git
172 |
173 | docker run --gpus all \
174 | -it \
175 | --net host \
176 | --ipc=host \
177 | -v $(pwd)/yolov9:/yolov9 \
178 | -v $(pwd)/coco/:/yolov9/coco \
179 | -v $(pwd)/runs:/yolov9/runs \
180 | nvcr.io/nvidia/pytorch:24.10-py3
181 |
182 | ```
183 |
184 | 1. Clone and apply patch (Inside Docker)
185 | ```bash
186 | cd /
187 | git clone https://github.com/levipereira/yolov9-qat.git
188 | cd /yolov9-qat
189 | ./patch_yolov9.sh /yolov9
190 | ```
191 |
192 | 2. Install dependencies
193 |
194 | - **This release upgrades TensorRT to 10.9**
195 | - `./install_dependencies.sh`
196 |
197 | ```bash
198 | cd /yolov9-qat
199 | ./install_dependencies.sh
200 | ```
201 |
202 |
203 | 3. Download dataset and pretrained model
204 | ```bash
205 | $ cd /yolov9
206 | $ bash scripts/get_coco.sh
207 | $ wget https://github.com/WongKinYiu/yolov9/releases/download/v0.1/yolov9-c-converted.pt
208 | ```
209 |
210 |
211 | ## Usage
212 |
213 | ## Quantize Model
214 |
215 | To quantize a YOLOv9 model, run:
216 |
217 | ```bash
218 | python3 qat.py quantize --weights yolov9-c-converted.pt --name yolov9_qat --exist-ok
219 |
220 | python qat.py quantize --weights <weights.pt> --data <data.yaml> --hyp <hyp.yaml> ...
221 | ```
222 | ## Quantize Command Arguments
223 |
224 | ### Description
225 | This command is used to perform PTQ/QAT finetuning.
226 |
227 | ### Arguments
228 |
229 | - `--weights`: Path to the model weights (.pt). Default: ROOT/runs/models_original/yolov9-c.pt.
230 | - `--data`: Path to the dataset configuration file (data.yaml). Default: ROOT/data/coco.yaml.
231 | - `--hyp`: Path to the hyperparameters file (hyp.yaml). Default: ROOT/data/hyps/hyp.scratch-high.yaml.
232 | - `--device`: Device to use for training/evaluation (e.g., "cuda:0"). Default: "cuda:0".
233 | - `--batch-size`: Total batch size for training/evaluation. Default: 10.
234 | - `--imgsz`, `--img`, `--img-size`: Train/val image size (pixels). Default: 640.
235 | - `--project`: Directory to save the training/evaluation outputs. Default: ROOT/runs/qat.
236 | - `--name`: Name of the training/evaluation experiment. Default: 'exp'.
237 | - `--exist-ok`: Flag to indicate if existing project/name should be overwritten.
238 | - `--iters`: Iterations per epoch. Default: 200.
239 | - `--seed`: Global training seed. Default: 57.
240 | - `--supervision-stride`: Supervision stride. Default: 1.
241 | - `--no-eval-origin`: Disable eval for origin model.
242 | - `--no-eval-ptq`: Disable eval for ptq model.
243 |
244 |
245 | ## Sensitive Layer Analysis
246 | ```bash
247 | python qat.py sensitive --weights yolov9-c.pt --data data/coco.yaml --hyp hyp.scratch-high.yaml ...
248 | ```
249 |
250 | ## Sensitive Command Arguments
251 |
252 | ### Description
253 | This command is used for sensitive layer analysis, which measures how much quantizing each layer degrades accuracy so that the most sensitive layers can be identified and, if needed, left in higher precision.
254 |
255 | ### Arguments
256 |
257 | - `--weights`: Path to the model weights (.pt). Default: ROOT/runs/models_original/yolov9-c.pt.
258 | - `--device`: Device to use for training/evaluation (e.g., "cuda:0"). Default: "cuda:0".
259 | - `--data`: Path to the dataset configuration file (data.yaml). Default: data/coco.yaml.
260 | - `--batch-size`: Total batch size for training/evaluation. Default: 10.
261 | - `--imgsz`, `--img`, `--img-size`: Train/val image size (pixels). Default: 640.
262 | - `--hyp`: Path to the hyperparameters file (hyp.yaml). Default: data/hyps/hyp.scratch-high.yaml.
263 | - `--project`: Directory to save the training/evaluation outputs. Default: ROOT/runs/qat_sentive.
264 | - `--name`: Name of the training/evaluation experiment. Default: 'exp'.
265 | - `--exist-ok`: Flag to indicate if existing project/name should be overwritten.
266 | - `--num-image`: Number of images to evaluate. Default: None.
267 |
268 |
269 | ## Evaluate QAT Model
270 |
271 | ### Evaluate using Pytorch
272 | ```bash
273 | python3 qat.py eval --weights runs/qat/yolov9_qat/weights/qat_best_yolov9-c-converted.pt --name eval_qat_yolov9
274 | ```
275 | ## Evaluation Command Arguments
276 |
277 | ### Description
278 | This command is used to perform evaluation on QAT Models.
279 |
280 | ### Arguments
281 |
282 | - `--weights`: Path to the model weights (.pt). Default: ROOT/runs/models_original/yolov9-c.pt.
283 | - `--data`: Path to the dataset configuration file (data.yaml). Default: data/coco.yaml.
284 | - `--batch-size`: Total batch size for evaluation. Default: 10.
285 | - `--imgsz`, `--img`, `--img-size`: Validation image size (pixels). Default: 640.
286 | - `--device`: Device to use for evaluation (e.g., "cuda:0"). Default: "cuda:0".
287 | - `--conf-thres`: Confidence threshold for evaluation. Default: 0.001.
288 | - `--iou-thres`: NMS threshold for evaluation. Default: 0.7.
289 | - `--project`: Directory to save the evaluation outputs. Default: ROOT/runs/qat_eval.
290 | - `--name`: Name of the evaluation experiment. Default: 'exp'.
291 | - `--exist-ok`: Flag to indicate if existing project/name should be overwritten.
292 |
293 | ### Evaluate using TensorRT
294 |
295 | ```bash
296 | ./scripts/val_trt.sh
297 |
298 | ./scripts/val_trt.sh runs/qat/yolov9_qat/weights/qat_best_yolov9-c-converted.pt data/coco.yaml 640
299 | ```
300 |
301 | ## Generate TensorRT Profiling and SVG image
302 |
303 |
304 | TensorRT Explorer can be installed by executing `./install_dependencies.sh --trex`.
This installation is necessary to enable the generation of the graph SVG, allowing visualization of the profiling data for a TensorRT engine.
305 |
306 | ```bash
307 | ./scripts/val_trt.sh runs/qat/yolov9_qat/weights/qat_best_yolov9-c-converted.pt data/coco.yaml 640 --generate-graph
308 | ```
309 |
310 | # Export ONNX
311 | The goal of exporting to ONNX is to deploy to TensorRT, not to ONNX Runtime, so we only export the fake-quantized model in a form TensorRT will accept. Each fake-quantization node is exported as a pair of QuantizeLinear/DequantizeLinear ONNX ops; TensorRT then takes the generated ONNX graph and executes it in INT8 in the most optimized way it can.
312 |
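Below is a minimal sketch of how this works, assuming `model` is an already loaded, reparameterized, QAT-finetuned YOLOv9 module (the output path and static input shape are placeholders; the actual `export_qat.py` additionally handles dynamic axes, metadata, and removal of redundant Q/DQ pairs):

```python
import torch
from pytorch_quantization import nn as quant_nn

# Emit fake-quantization nodes as QuantizeLinear/DequantizeLinear ONNX ops
quant_nn.TensorQuantizer.use_fb_fake_quant = True

model.eval()
# Dummy input on the same device as the model (placeholder shape 1x3x640x640)
dummy = torch.zeros(1, 3, 640, 640, device=next(model.parameters()).device)

with torch.no_grad():
    torch.onnx.export(
        model, dummy, "qat_model.onnx",  # hypothetical output path
        opset_version=13,
        do_constant_folding=True,
        input_names=["images"],
        output_names=["output0"],
    )
```

In practice, use the `export_qat.py` commands below, which wrap these steps and also strip redundant Q/DQ nodes from the exported graph.
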
313 | ## Export ONNX Model without End2End
314 | ```bash
315 | python3 export_qat.py --weights runs/qat/yolov9_qat/weights/qat_best_yolov9-c.pt --include onnx --dynamic --simplify --inplace
316 | ```
317 |
318 | ## Export ONNX Model End2End
319 | ```bash
320 | python3 export_qat.py --weights runs/qat/yolov9_qat/weights/qat_best_yolov9-c.pt --include onnx_end2end
321 | ```
322 |
323 |
324 | ## Deployment with TensorRT
325 | ```bash
326 | /usr/src/tensorrt/bin/trtexec \
327 | --onnx=runs/qat/yolov9_qat/weights/qat_best_yolov9-c-converted.onnx \
328 | --int8 --fp16 \
329 | --useCudaGraph \
330 | --minShapes=images:1x3x640x640 \
331 | --optShapes=images:4x3x640x640 \
332 | --maxShapes=images:8x3x640x640 \
333 | --saveEngine=runs/qat/yolov9_qat/weights/qat_best_yolov9-c-converted.engine
334 | ```
335 |
336 | # Benchmark
337 | Note: To test FP16 models (such as Origin), remove the `--int8` flag.
338 | ```bash
339 | # Set variable batch_size and model_path_no_ext
340 | export batch_size=4
341 | export filepath_no_ext=runs/qat/yolov9_qat/weights/qat_best_yolov9-c-converted
342 | trtexec \
343 | --onnx=${filepath_no_ext}.onnx \
344 | --fp16 \
345 | --int8 \
346 | --saveEngine=${filepath_no_ext}.engine \
347 | --timingCacheFile=${filepath_no_ext}.engine.timing.cache \
348 | --warmUp=500 \
349 | --duration=10 \
350 | --useCudaGraph \
351 | --useSpinWait \
352 | --noDataTransfers \
353 | --minShapes=images:1x3x640x640 \
354 | --optShapes=images:${batch_size}x3x640x640 \
355 | --maxShapes=images:${batch_size}x3x640x640
356 | ```
357 |
358 | ### Device
359 | ```bash
360 | === Device Information ===
361 | Available Devices:
362 | Device 0: "NVIDIA GeForce RTX 4090"
363 | Selected Device: NVIDIA GeForce RTX 4090
364 | Selected Device ID: 0
365 | Compute Capability: 8.9
366 | SMs: 128
367 | Device Global Memory: 24207 MiB
368 | Shared Memory per SM: 100 KiB
369 | Memory Bus Width: 384 bits (ECC disabled)
370 | Application Compute Clock Rate: 2.58 GHz
371 | Application Memory Clock Rate: 10.501 GHz
372 | ```
373 |
374 | ## Output Details
375 | - `Latency`: the [min, max, mean, median, 99th percentile] of the engine latency measurements, timed without profiling layers.
376 | - `Throughput`: measured in queries (inferences) per second (QPS). Total Inferences (IPS) in the tables above is throughput multiplied by the batch size (e.g., 486 qps × batch 4 ≈ 1944 IPS).
377 |
378 | ## YOLOv9-C QAT (SiLU)
379 | ## Batch Size 1
380 | ```bash
381 | Throughput: 1026.71 qps
382 | Latency: min = 0.969727 ms, max = 0.975098 ms, mean = 0.972263 ms, median = 0.972656 ms, percentile(90%) = 0.973145 ms, percentile(95%) = 0.973633 ms, percentile(99%) = 0.974121 ms
383 | Enqueue Time: min = 0.00195312 ms, max = 0.0195312 ms, mean = 0.00228119 ms, median = 0.00219727 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00390625 ms
384 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
385 | GPU Compute Time: min = 0.969727 ms, max = 0.975098 ms, mean = 0.972263 ms, median = 0.972656 ms, percentile(90%) = 0.973145 ms, percentile(95%) = 0.973633 ms, percentile(99%) = 0.974121 ms
386 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
387 | Total Host Walltime: 10.0019 s
388 | Total GPU Compute Time: 9.98417 s
389 | ```
390 |
391 | ## BatchSize 4
392 | ```bash
393 | === Performance summary ===
394 | Throughput: 485.73 qps
395 | Latency: min = 2.05176 ms, max = 2.06152 ms, mean = 2.05712 ms, median = 2.05713 ms, percentile(90%) = 2.05908 ms, percentile(95%) = 2.05957 ms, percentile(99%) = 2.06055 ms
396 | Enqueue Time: min = 0.00195312 ms, max = 0.00708008 ms, mean = 0.00230195 ms, median = 0.00219727 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00415039 ms
397 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
398 | GPU Compute Time: min = 2.05176 ms, max = 2.06152 ms, mean = 2.05712 ms, median = 2.05713 ms, percentile(90%) = 2.05908 ms, percentile(95%) = 2.05957 ms, percentile(99%) = 2.06055 ms
399 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
400 | Total Host Walltime: 10.0035 s
401 | Total GPU Compute Time: 9.99553 s
402 | ```
403 |
404 |
405 | ## BatchSize 8
406 | ```bash
407 | === Performance summary ===
408 | Throughput: 271.107 qps
409 | Latency: min = 3.6792 ms, max = 3.69775 ms, mean = 3.68694 ms, median = 3.68652 ms, percentile(90%) = 3.69043 ms, percentile(95%) = 3.69141 ms, percentile(99%) = 3.69336 ms
410 | Enqueue Time: min = 0.00195312 ms, max = 0.0090332 ms, mean = 0.0023588 ms, median = 0.00231934 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00476074 ms
411 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
412 | GPU Compute Time: min = 3.6792 ms, max = 3.69775 ms, mean = 3.68694 ms, median = 3.68652 ms, percentile(90%) = 3.69043 ms, percentile(95%) = 3.69141 ms, percentile(99%) = 3.69336 ms
413 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
414 | Total Host Walltime: 10.0071 s
415 | Total GPU Compute Time: 10.0027 s
416 | ```
417 | ## BatchSize 12
418 | ```bash
419 | === Performance summary ===
420 | Throughput: 188.812 qps
421 | Latency: min = 5.25 ms, max = 5.37097 ms, mean = 5.2946 ms, median = 5.28906 ms, percentile(90%) = 5.32129 ms, percentile(95%) = 5.32593 ms, percentile(99%) = 5.36475 ms
422 | Enqueue Time: min = 0.00195312 ms, max = 0.0898438 ms, mean = 0.00248513 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00463867 ms
423 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
424 | GPU Compute Time: min = 5.25 ms, max = 5.37097 ms, mean = 5.2946 ms, median = 5.28906 ms, percentile(90%) = 5.32129 ms, percentile(95%) = 5.32593 ms, percentile(99%) = 5.36475 ms
425 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
426 | Total Host Walltime: 10.01 s
427 | Total GPU Compute Time: 10.0068 s
428 | ```
429 |
430 |
431 |
432 | ## YOLOv9-C QAT (ReLU)
433 | ## Batch Size 1
434 | ```bash
435 | === Performance summary ===
436 | Throughput: 1149.49 qps
437 | Latency: min = 0.866211 ms, max = 0.871094 ms, mean = 0.868257 ms, median = 0.868164 ms, percentile(90%) = 0.869385 ms, percentile(95%) = 0.869629 ms, percentile(99%) = 0.870117 ms
438 | Enqueue Time: min = 0.00195312 ms, max = 0.0180664 ms, mean = 0.00224214 ms, median = 0.00219727 ms, percentile(90%) = 0.00268555 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00390625 ms
439 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
440 | GPU Compute Time: min = 0.866211 ms, max = 0.871094 ms, mean = 0.868257 ms, median = 0.868164 ms, percentile(90%) = 0.869385 ms, percentile(95%) = 0.869629 ms, percentile(99%) = 0.870117 ms
441 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
442 | Total Host Walltime: 10.0018 s
443 | Total GPU Compute Time: 9.98235 s
444 | ```
445 |
446 | ## BatchSize 4
447 | ```bash
448 | === Performance summary ===
449 | Throughput: 561.857 qps
450 | Latency: min = 1.77344 ms, max = 1.78418 ms, mean = 1.77814 ms, median = 1.77832 ms, percentile(90%) = 1.77979 ms, percentile(95%) = 1.78076 ms, percentile(99%) = 1.78174 ms
451 | Enqueue Time: min = 0.00195312 ms, max = 0.0205078 ms, mean = 0.00233018 ms, median = 0.0022583 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00439453 ms
452 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
453 | GPU Compute Time: min = 1.77344 ms, max = 1.78418 ms, mean = 1.77814 ms, median = 1.77832 ms, percentile(90%) = 1.77979 ms, percentile(95%) = 1.78076 ms, percentile(99%) = 1.78174 ms
454 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
455 | Total Host Walltime: 10.0043 s
456 | Total GPU Compute Time: 9.99494 s
457 | ```
458 |
459 |
460 | ## BatchSize 8
461 | ```bash
462 | === Performance summary ===
463 | Throughput: 326.86 qps
464 | Latency: min = 3.04126 ms, max = 3.06934 ms, mean = 3.05773 ms, median = 3.05859 ms, percentile(90%) = 3.06152 ms, percentile(95%) = 3.0625 ms, percentile(99%) = 3.06396 ms
465 | Enqueue Time: min = 0.00195312 ms, max = 0.0209961 ms, mean = 0.00235826 ms, median = 0.00231934 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00463867 ms
466 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
467 | GPU Compute Time: min = 3.04126 ms, max = 3.06934 ms, mean = 3.05773 ms, median = 3.05859 ms, percentile(90%) = 3.06152 ms, percentile(95%) = 3.0625 ms, percentile(99%) = 3.06396 ms
468 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
469 | Total Host Walltime: 10.0043 s
470 | Total GPU Compute Time: 9.99877 s
471 | ```
472 | ## BatchSize 12
473 | ```bash
474 | === Performance summary ===
475 | Throughput: 216.441 qps
476 | Latency: min = 4.60742 ms, max = 4.63184 ms, mean = 4.61852 ms, median = 4.61816 ms, percentile(90%) = 4.62305 ms, percentile(95%) = 4.62439 ms, percentile(99%) = 4.62744 ms
477 | Enqueue Time: min = 0.00195312 ms, max = 0.0131836 ms, mean = 0.00250633 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00341797 ms, percentile(99%) = 0.00531006 ms
478 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
479 | GPU Compute Time: min = 4.60742 ms, max = 4.63184 ms, mean = 4.61852 ms, median = 4.61816 ms, percentile(90%) = 4.62305 ms, percentile(95%) = 4.62439 ms, percentile(99%) = 4.62744 ms
480 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
481 | Total Host Walltime: 10.0074 s
482 | Total GPU Compute Time: 10.0037 s
483 | ```
484 |
485 | ## YOLOv9-C FP16
486 | ## Batch Size 1
487 | ```bash
488 | === Performance summary ===
489 | Throughput: 802.984 qps
490 | Latency: min = 1.23901 ms, max = 1.25439 ms, mean = 1.24376 ms, median = 1.24316 ms, percentile(90%) = 1.24805 ms, percentile(95%) = 1.24902 ms, percentile(99%) = 1.24951 ms
491 | Enqueue Time: min = 0.00195312 ms, max = 0.00756836 ms, mean = 0.00240711 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00390625 ms
492 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
493 | GPU Compute Time: min = 1.23901 ms, max = 1.25439 ms, mean = 1.24376 ms, median = 1.24316 ms, percentile(90%) = 1.24805 ms, percentile(95%) = 1.24902 ms, percentile(99%) = 1.24951 ms
494 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
495 | Total Host Walltime: 10.0027 s
496 | Total GPU Compute Time: 9.98985 s
497 | ```
498 |
499 | ## BatchSize 4
500 | ```bash
501 | === Performance summary ===
502 | Throughput: 300.281 qps
503 | Latency: min = 3.30341 ms, max = 3.38025 ms, mean = 3.32861 ms, median = 3.3291 ms, percentile(90%) = 3.33594 ms, percentile(95%) = 3.34229 ms, percentile(99%) = 3.37 ms
504 | Enqueue Time: min = 0.00195312 ms, max = 0.00830078 ms, mean = 0.00244718 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00390625 ms
505 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
506 | GPU Compute Time: min = 3.30341 ms, max = 3.38025 ms, mean = 3.32861 ms, median = 3.3291 ms, percentile(90%) = 3.33594 ms, percentile(95%) = 3.34229 ms, percentile(99%) = 3.37 ms
507 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
508 | Total Host Walltime: 10.0073 s
509 | Total GPU Compute Time: 10.0025 s
510 | ```
511 |
512 |
513 | ## BatchSize 8
514 | ```bash
515 | === Performance summary ===
516 | Throughput: 153.031 qps
517 | Latency: min = 6.47882 ms, max = 6.64679 ms, mean = 6.53299 ms, median = 6.5332 ms, percentile(90%) = 6.55029 ms, percentile(95%) = 6.55762 ms, percentile(99%) = 6.59766 ms
518 | Enqueue Time: min = 0.00195312 ms, max = 0.0117188 ms, mean = 0.00248772 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00390625 ms
519 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
520 | GPU Compute Time: min = 6.47882 ms, max = 6.64679 ms, mean = 6.53299 ms, median = 6.5332 ms, percentile(90%) = 6.55029 ms, percentile(95%) = 6.55762 ms, percentile(99%) = 6.59766 ms
521 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
522 | Total Host Walltime: 10.011 s
523 | Total GPU Compute Time: 10.0085 s
524 | ```
525 |
526 | ## BatchSize 12
527 | ```bash
528 | === Performance summary ===
529 | Throughput: 99.3162 qps
530 | Latency: min = 10.0372 ms, max = 10.0947 ms, mean = 10.0672 ms, median = 10.0674 ms, percentile(90%) = 10.0781 ms, percentile(95%) = 10.0811 ms, percentile(99%) = 10.0859 ms
531 | Enqueue Time: min = 0.00195312 ms, max = 0.0078125 ms, mean = 0.00248219 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00292969 ms, percentile(99%) = 0.00390625 ms
532 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
533 | GPU Compute Time: min = 10.0372 ms, max = 10.0947 ms, mean = 10.0672 ms, median = 10.0674 ms, percentile(90%) = 10.0781 ms, percentile(95%) = 10.0811 ms, percentile(99%) = 10.0859 ms
534 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
535 | Total Host Walltime: 10.0286 s
536 | Total GPU Compute Time: 10.0269 s
537 | ```
538 |
539 |
540 | # Segmentation
541 |
542 | ## FP16
543 | ### Batch Size 8
544 |
545 | ```bash
546 | === Performance summary ===
547 | Throughput: 124.055 qps
548 | Latency: min = 8.00354 ms, max = 8.18585 ms, mean = 8.05924 ms, median = 8.05072 ms, percentile(90%) = 8.11499 ms, percentile(95%) = 8.1438 ms, percentile(99%) = 8.17456 ms
549 | Enqueue Time: min = 0.00219727 ms, max = 0.0200653 ms, mean = 0.00271174 ms, median = 0.00256348 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00317383 ms, percentile(99%) = 0.00466919 ms
550 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
551 | GPU Compute Time: min = 8.00354 ms, max = 8.18585 ms, mean = 8.05924 ms, median = 8.05072 ms, percentile(90%) = 8.11499 ms, percentile(95%) = 8.1438 ms, percentile(99%) = 8.17456 ms
552 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
553 | Total Host Walltime: 3.01478 s
554 | Total GPU Compute Time: 3.01415 s
555 | ```
556 |
557 | ## INT8 / FP16
558 | ### Batch Size 8
559 | ```bash
560 | === Performance summary ===
561 | Throughput: 223.63 qps
562 | Latency: min = 4.45544 ms, max = 4.71553 ms, mean = 4.47007 ms, median = 4.46777 ms, percentile(90%) = 4.47284 ms, percentile(95%) = 4.47388 ms, percentile(99%) = 4.47693 ms
563 | Enqueue Time: min = 0.00219727 ms, max = 0.00854492 ms, mean = 0.00258152 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00305176 ms, percentile(99%) = 0.00439453 ms
564 | H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
565 | GPU Compute Time: min = 4.45544 ms, max = 4.71553 ms, mean = 4.47007 ms, median = 4.46777 ms, percentile(90%) = 4.47284 ms, percentile(95%) = 4.47388 ms, percentile(99%) = 4.47693 ms
566 | D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
567 | Total Host Walltime: 3.00944 s
568 | Total GPU Compute Time: 3.00836 s
569 | ```
--------------------------------------------------------------------------------
/draw-engine.py:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: MIT
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a
6 | # copy of this software and associated documentation files (the "Software"),
7 | # to deal in the Software without restriction, including without limitation
8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 | # and/or sell copies of the Software, and to permit persons to whom the
10 | # Software is furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in
13 | # all copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
21 | # DEALINGS IN THE SOFTWARE.
22 | ################################################################################
23 |
24 |
25 | """
 26 | This script generates an SVG diagram from the input engine graph JSON file.
27 | Note:
28 | THIS SCRIPT DEPENDS ON LIB: https://github.com/NVIDIA/TensorRT/tree/main/tools/experimental/trt-engine-explorer
 29 | This script requires Graphviz, which can be installed with:
30 | $ /yolov9-qat/install_dependencies.sh --trex
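Example invocation (hypothetical file names; the layer JSON is expected to end with '.layer.json'):
$ python3 draw-engine.py --layer yolov9.layer.json --profile yolov9.profile.json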
31 | """
32 |
33 | import graphviz
34 | from trex import *
35 | from trex import graphing
36 | import argparse
37 | import shutil
38 |
39 |
40 | def draw_engine(engine_json_fname: str, engine_profile_fname: str):
41 | engine_name=engine_json_fname.replace('.layer.json', '')
42 | graphviz_is_installed = shutil.which("dot") is not None
43 | if not graphviz_is_installed:
44 | print("graphviz is required but it is not installed.\n")
45 | print("To install on Ubuntu:")
46 | print("$ /yolov9-qat/install_dependencies.sh --trex")
47 | exit()
48 |
49 | plan = EnginePlan(engine_json_fname, engine_profile_fname)
50 | formatter = graphing.layer_type_formatter
51 | display_regions = True
52 | expand_layer_details = False
53 |
54 | graph = graphing.to_dot(plan, formatter,
55 | display_regions=display_regions,
56 | expand_layer_details=expand_layer_details)
57 | graphing.render_dot(graph, engine_name, 'svg')
58 |
59 | graphing.render_dot(graph, engine_name, 'png')
60 |
61 |
62 | if __name__ == "__main__":
63 | parser = argparse.ArgumentParser()
64 | parser.add_argument('--layer', help="name of engine JSON file to draw" , required=True)
65 | parser.add_argument('--profile', help="name of profile JSON file to draw", required=True)
66 | args = parser.parse_args()
67 | draw_engine(engine_json_fname=args.layer,engine_profile_fname=args.profile)
68 |
--------------------------------------------------------------------------------
/export_qat.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import contextlib
3 | import json
4 | import os
5 | import platform
6 | import re
7 | import subprocess
8 | import sys
9 | import time
10 | import warnings
11 | from pathlib import Path
12 |
13 | import pandas as pd
14 | import torch
15 | from torch.utils.mobile_optimizer import optimize_for_mobile
16 | import onnx_graphsurgeon as gs
17 |
18 | import models.quantize as quantize
19 |
20 | from pytorch_quantization import nn as quant_nn
21 |
22 | FILE = Path(__file__).resolve()
23 | ROOT = FILE.parents[0] # YOLO root directory
24 | if str(ROOT) not in sys.path:
25 | sys.path.append(str(ROOT)) # add ROOT to PATH
26 | if platform.system() != 'Windows':
27 | ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
28 |
29 | from models.experimental_trt import End2End_TRT
30 | from models.experimental import attempt_load
31 | from models.yolo import ClassificationModel, Detect, DDetect, DualDetect, DualDDetect, DetectionModel, SegmentationModel
32 | from utils.dataloaders import LoadImages
33 | from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
34 | check_yaml, colorstr, file_size, get_default_args, print_args, url2file, yaml_save)
35 | from utils.torch_utils import select_device, smart_inference_mode
36 | from models.quantize import remove_redundant_qdq_model
37 |
38 | MACOS = platform.system() == 'Darwin' # macOS environment
39 |
40 |
41 |
42 | def export_formats():
43 | # YOLO export formats
44 | x = [
45 | ['PyTorch', '-', '.pt', True, True],
46 | ['TorchScript', 'torchscript', '.torchscript', True, True],
47 | ['ONNX', 'onnx', '.onnx', True, True],
48 | ['ONNX END2END', 'onnx_end2end', '_end2end.onnx', True, True],
49 | ['OpenVINO', 'openvino', '_openvino_model', True, False],
50 | ['TensorRT', 'engine', '.engine', False, True],
51 | ['CoreML', 'coreml', '.mlmodel', True, False],
52 | ['TensorFlow SavedModel', 'saved_model', '_saved_model', True, True],
53 | ['TensorFlow GraphDef', 'pb', '.pb', True, True],
54 | ['TensorFlow Lite', 'tflite', '.tflite', True, False],
55 | ['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', False, False],
56 | ['TensorFlow.js', 'tfjs', '_web_model', False, False],
57 | ['PaddlePaddle', 'paddle', '_paddle_model', True, True],]
58 | return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
59 |
60 |
61 | def try_export(inner_func):
 62 | # YOLO export decorator, i.e. @try_export
63 | inner_args = get_default_args(inner_func)
64 |
65 | def outer_func(*args, **kwargs):
66 | prefix = inner_args['prefix']
67 | try:
68 | with Profile() as dt:
69 | f, model = inner_func(*args, **kwargs)
70 | LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
71 | return f, model
72 | except Exception as e:
73 | LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
74 | return None, None
75 |
76 | return outer_func
77 |
78 |
79 | @try_export
80 | def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
81 | # YOLO TorchScript model export
82 | LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
83 | f = file.with_suffix('.torchscript')
84 |
85 | ts = torch.jit.trace(model, im, strict=False)
86 | d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
87 | extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
88 | if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
89 | optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
90 | else:
91 | ts.save(str(f), _extra_files=extra_files)
92 | return f, None
93 |
94 |
95 | @try_export
96 | def export_onnx(model, im, file, opset, dynamic, simplify, prefix=colorstr('ONNX:')):
97 | # YOLO ONNX export
98 | check_requirements('onnx')
99 | import onnx
100 |
101 | is_model_qat=False
102 | for i in range(0, len(model.model)):
103 | layer = model.model[i]
104 | if quantize.have_quantizer(layer):
105 | is_model_qat=True
106 | break
107 |
108 | LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
109 | f = file.with_suffix('.onnx')
110 |
111 | output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0']
112 | if dynamic:
113 | dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640)
114 | if isinstance(model, SegmentationModel):
115 | dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
116 | dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160)
117 | elif isinstance(model, DetectionModel):
118 | dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
119 |
120 | if is_model_qat:
121 | warnings.filterwarnings("ignore")
122 | LOGGER.info(f'{prefix} Model QAT Detected ...')
123 | quant_nn.TensorQuantizer.use_fb_fake_quant = True
124 | model.eval()
125 | quantize.initialize()
126 | quantize.replace_custom_module_forward(model)
127 | with torch.no_grad():
128 | torch.onnx.export(
129 | model,
130 | im,
131 | f,
132 | verbose=False,
133 | opset_version=13,
134 | do_constant_folding=True,
135 | input_names=['images'],
136 | output_names=output_names,
137 | dynamic_axes=dynamic)
138 |
139 | else:
140 | torch.onnx.export(
141 | model.cpu() if dynamic else model, # --dynamic only compatible with cpu
142 | im.cpu() if dynamic else im,
143 | f,
144 | verbose=False,
145 | opset_version=opset,
146 | do_constant_folding=True,
147 | input_names=['images'],
148 | output_names=output_names,
149 | dynamic_axes=dynamic or None)
150 |
151 | # Checks
152 | model_onnx = onnx.load(f) # load onnx model
153 | onnx.checker.check_model(model_onnx) # check onnx model
154 |
155 | # Metadata
156 | d = {'stride': int(max(model.stride)), 'names': model.names}
157 | for k, v in d.items():
158 | meta = model_onnx.metadata_props.add()
159 | meta.key, meta.value = k, str(v)
160 |
161 | # Simplify
162 | if simplify:
163 | try:
164 | cuda = torch.cuda.is_available()
165 | check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
166 | import onnxsim
167 |
168 | LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
169 | model_onnx, check = onnxsim.simplify(model_onnx)
170 | assert check, 'assert check failed'
171 | onnx.save(model_onnx, f)
172 | except Exception as e:
173 | LOGGER.info(f'{prefix} simplifier failure: {e}')
174 |
175 | LOGGER.info(f'{prefix} Removing redundant Q/DQ layer with onnx_graphsurgeon {gs.__version__}...')
176 | remove_redundant_qdq_model(model_onnx, f)
177 | model_onnx = onnx.load(f)
178 | return f, model_onnx
179 |
180 |
181 |
182 | @try_export
183 | def export_onnx_end2end(model, im, file, class_agnostic, simplify, topk_all, iou_thres, conf_thres, device, labels, mask_resolution, pooler_scale, sampling_ratio, prefix=colorstr('ONNX END2END:')):
184 | LOGGER.info(f'{prefix} Model type: {type(model)}')
185 | LOGGER.info(f'{prefix} Is DetectionModel: {isinstance(model, DetectionModel)}')
186 | LOGGER.info(f'{prefix} Is SegmentationModel: {isinstance(model, SegmentationModel)}')
187 |
188 | has_detection_capabilities = hasattr(model, 'model') and hasattr(model, 'names') and hasattr(model, 'stride')
189 |
190 | if not has_detection_capabilities:
191 | raise RuntimeError("Model not supported. Only Detection Models can be exported with End2End functionality.")
192 |
193 | LOGGER.info(f'{prefix} Model accepted for export.')
194 |
195 | is_det_model=True
196 | if isinstance(model, SegmentationModel):
197 | is_det_model=False
198 |
199 | env_is_det_model = os.getenv("MODEL_DET")
200 | if env_is_det_model == "0":
201 | is_det_model = False
202 |
203 | # YOLO ONNX export
204 | check_requirements('onnx')
205 | import onnx
206 |
207 | is_model_qat=False
208 | for i in range(0, len(model.model)):
209 | layer = model.model[i]
210 | if quantize.have_quantizer(layer):
211 | is_model_qat=True
212 | break
213 |
214 | LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
215 | f = os.path.splitext(file)[0] + "-end2end.onnx"
216 | batch_size = 'batch'
217 | d = {
218 | 'stride': int(max(model.stride)),
219 | 'names': model.names,
220 | 'model type' : 'Detection' if is_det_model else 'Segmentation',
221 | 'TRT Compatibility': '8.6 or above' if class_agnostic else '8.5 or above',
222 | 'TRT Plugins': 'EfficientNMS_TRT' if is_det_model else 'EfficientNMSX_TRT, ROIAlign'
223 | }
224 |
225 |
226 | dynamic_axes = {'images': {0 : 'batch', 2: 'height', 3:'width'}, } # variable length axes
227 |
228 | output_axes = {
229 | 'num_dets': {0: 'batch'},
230 | 'det_boxes': {0: 'batch'},
231 | 'det_scores': {0: 'batch'},
232 | 'det_classes': {0: 'batch'},
233 | }
234 | if is_det_model:
235 | output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
236 | shapes = [ batch_size, 1,
237 | batch_size, topk_all, 4,
238 | batch_size, topk_all,
239 | batch_size, topk_all]
240 |
241 | else:
242 | output_axes['det_masks'] = {0: 'batch'}
243 | output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes', 'det_masks']
244 | shapes = [ batch_size, 1,
245 | batch_size, topk_all, 4,
246 | batch_size, topk_all,
247 | batch_size, topk_all,
248 | batch_size, topk_all, mask_resolution * mask_resolution]
249 |
250 | dynamic_axes.update(output_axes)
251 |     model = End2End_TRT(model, class_agnostic, topk_all, iou_thres, conf_thres, mask_resolution, pooler_scale, sampling_ratio, None, device, labels, is_det_model)
252 |
253 |
254 | if is_model_qat:
255 | warnings.filterwarnings("ignore")
256 | LOGGER.info(f'{prefix} Model QAT Detected ...')
257 | quant_nn.TensorQuantizer.use_fb_fake_quant = True
258 | model.eval()
259 | quantize.initialize()
260 |
261 | with torch.no_grad():
262 | torch.onnx.export(model,
263 | im,
264 | f,
265 | verbose=False,
266 | export_params=True, # store the trained parameter weights inside the model file
267 | opset_version=16,
268 | do_constant_folding=True, # whether to execute constant folding for optimization
269 | input_names=['images'],
270 | output_names=output_names,
271 | dynamic_axes=dynamic_axes)
272 | quant_nn.TensorQuantizer.use_fb_fake_quant = False
273 | else:
274 | torch.onnx.export(model,
275 | im,
276 | f,
277 | verbose=False,
278 | export_params=True, # store the trained parameter weights inside the model file
279 | opset_version=16,
280 | do_constant_folding=True, # whether to execute constant folding for optimization
281 | input_names=['images'],
282 | output_names=output_names,
283 | dynamic_axes=dynamic_axes)
284 |
285 | # Checks
286 | model_onnx = onnx.load(f) # load onnx model
287 | onnx.checker.check_model(model_onnx) # check onnx model
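    # Embed export metadata (stride, class names, model type, TRT plugin requirements)
    # into the ONNX model as metadata_props entries.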
288 | for k, v in d.items():
289 | meta = model_onnx.metadata_props.add()
290 | meta.key, meta.value = k, str(v)
291 |
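    # Overwrite each output dimension with the corresponding entry from `shapes`
    # ('batch', top-k, 4, ...) so the graph reports the NMS output layout.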
292 | for i in model_onnx.graph.output:
293 | for j in i.type.tensor_type.shape.dim:
294 | j.dim_param = str(shapes.pop(0))
295 |
296 | if simplify:
297 | try:
298 | import onnxsim
299 |
300 |             LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
301 |             model_onnx, check = onnxsim.simplify(model_onnx)
302 |             assert check, 'Simplified ONNX model could not be validated'
303 |         except Exception as e:
304 |             LOGGER.info(f'{prefix} simplifier failure: {e}')
305 |
306 |     # print(onnx.helper.printable_graph(model_onnx.graph))  # print a human readable model
307 |     onnx.save(model_onnx, f)
308 |     LOGGER.info(f'{prefix} ONNX export success, saved as {f}')
309 |
310 |     LOGGER.info(f'{prefix} Removing redundant Q/DQ layers with onnx_graphsurgeon {gs.__version__}...')
311 | remove_redundant_qdq_model(model_onnx, f)
312 | model_onnx = onnx.load(f)
313 |
314 | return f, model_onnx
315 |
316 |
317 | @try_export
318 | def export_openvino(file, metadata, half, prefix=colorstr('OpenVINO:')):
319 | # YOLO OpenVINO export
320 | check_requirements('openvino-dev') # requires openvino-dev: https://pypi.org/project/openvino-dev/
321 | import openvino.inference_engine as ie
322 |
323 | LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
324 | f = str(file).replace('.pt', f'_openvino_model{os.sep}')
325 |
326 | #cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
327 | #cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} {"--compress_to_fp16" if half else ""}"
328 | half_arg = "--compress_to_fp16" if half else ""
329 | cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} {half_arg}"
330 | subprocess.run(cmd.split(), check=True, env=os.environ) # export
331 | yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml
332 | return f, None
333 |
334 |
335 | @try_export
336 | def export_paddle(model, im, file, metadata, prefix=colorstr('PaddlePaddle:')):
337 | # YOLO Paddle export
338 | check_requirements(('paddlepaddle', 'x2paddle'))
339 | import x2paddle
340 | from x2paddle.convert import pytorch2paddle
341 |
342 | LOGGER.info(f'\n{prefix} starting export with X2Paddle {x2paddle.__version__}...')
343 | f = str(file).replace('.pt', f'_paddle_model{os.sep}')
344 |
345 | pytorch2paddle(module=model, save_dir=f, jit_type='trace', input_examples=[im]) # export
346 | yaml_save(Path(f) / file.with_suffix('.yaml').name, metadata) # add metadata.yaml
347 | return f, None
348 |
349 |
350 | @try_export
351 | def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
352 | # YOLO CoreML export
353 | check_requirements('coremltools')
354 | import coremltools as ct
355 |
356 | LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
357 | f = file.with_suffix('.mlmodel')
358 |
359 | ts = torch.jit.trace(model, im, strict=False) # TorchScript model
360 | ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
361 | bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
362 | if bits < 32:
363 | if MACOS: # quantization only supported on macOS
364 | with warnings.catch_warnings():
365 | warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning
366 | ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
367 | else:
368 | print(f'{prefix} quantization only supported on macOS, skipping...')
369 | ct_model.save(f)
370 | return f, ct_model
371 |
372 |
373 | @try_export
374 | def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
375 | # YOLO TensorRT export https://developer.nvidia.com/tensorrt
376 | assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
377 | try:
378 | import tensorrt as trt
379 | except Exception:
380 | if platform.system() == 'Linux':
381 | check_requirements('nvidia-tensorrt', cmds='-U --index-url https://pypi.ngc.nvidia.com')
382 | import tensorrt as trt
383 |
384 | if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
385 | grid = model.model[-1].anchor_grid
386 | model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
387 | export_onnx(model, im, file, 12, dynamic, simplify) # opset 12
388 | model.model[-1].anchor_grid = grid
389 | else: # TensorRT >= 8
390 | check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0
391 | export_onnx(model, im, file, 12, dynamic, simplify) # opset 12
392 | onnx = file.with_suffix('.onnx')
393 |
394 | LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
395 | assert onnx.exists(), f'failed to export ONNX file: {onnx}'
396 | f = file.with_suffix('.engine') # TensorRT engine file
397 | logger = trt.Logger(trt.Logger.INFO)
398 | if verbose:
399 | logger.min_severity = trt.Logger.Severity.VERBOSE
400 |
401 | builder = trt.Builder(logger)
402 | config = builder.create_builder_config()
403 | config.max_workspace_size = workspace * 1 << 30
404 | # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
405 |
406 | flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
407 | network = builder.create_network(flag)
408 | parser = trt.OnnxParser(network, logger)
409 | if not parser.parse_from_file(str(onnx)):
410 | raise RuntimeError(f'failed to load ONNX file: {onnx}')
411 |
412 | inputs = [network.get_input(i) for i in range(network.num_inputs)]
413 | outputs = [network.get_output(i) for i in range(network.num_outputs)]
414 | for inp in inputs:
415 | LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
416 | for out in outputs:
417 | LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
418 |
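    # Dynamic export: register an optimization profile with min/opt/max input shapes
    # so TensorRT builds kernels that cover the expected batch-size range.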
419 | if dynamic:
420 | if im.shape[0] <= 1:
421 | LOGGER.warning(f"{prefix} WARNING ⚠️ --dynamic model requires maximum --batch-size argument")
422 | profile = builder.create_optimization_profile()
423 | for inp in inputs:
424 | profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
425 | config.add_optimization_profile(profile)
426 |
427 | LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
428 | if builder.platform_has_fast_fp16 and half:
429 | config.set_flag(trt.BuilderFlag.FP16)
430 | with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
431 | t.write(engine.serialize())
432 | return f, None
433 |
434 |
435 | @try_export
436 | def export_saved_model(model,
437 | im,
438 | file,
439 | dynamic,
440 | tf_nms=False,
441 | agnostic_nms=False,
442 | topk_per_class=100,
443 | topk_all=100,
444 | iou_thres=0.45,
445 | conf_thres=0.25,
446 | keras=False,
447 | prefix=colorstr('TensorFlow SavedModel:')):
448 | # YOLO TensorFlow SavedModel export
449 | try:
450 | import tensorflow as tf
451 | except Exception:
452 | check_requirements(f"tensorflow{'' if torch.cuda.is_available() else '-macos' if MACOS else '-cpu'}")
453 | import tensorflow as tf
454 | from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
455 |
456 | from models.tf import TFModel
457 |
458 | LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
459 | f = str(file).replace('.pt', '_saved_model')
460 | batch_size, ch, *imgsz = list(im.shape) # BCHW
461 |
462 | tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
463 | im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow
464 | _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
465 | inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
466 | outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
467 | keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
468 | keras_model.trainable = False
469 | keras_model.summary()
470 | if keras:
471 | keras_model.save(f, save_format='tf')
472 | else:
473 | spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
474 | m = tf.function(lambda x: keras_model(x)) # full model
475 | m = m.get_concrete_function(spec)
476 | frozen_func = convert_variables_to_constants_v2(m)
477 | tfm = tf.Module()
478 | tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x), [spec])
479 | tfm.__call__(im)
480 | tf.saved_model.save(tfm,
481 | f,
482 | options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if check_version(
483 | tf.__version__, '2.6') else tf.saved_model.SaveOptions())
484 | return f, keras_model
485 |
486 |
487 | @try_export
488 | def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
489 | # YOLO TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
490 | import tensorflow as tf
491 | from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
492 |
493 | LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
494 | f = file.with_suffix('.pb')
495 |
496 | m = tf.function(lambda x: keras_model(x)) # full model
497 | m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
498 | frozen_func = convert_variables_to_constants_v2(m)
499 | frozen_func.graph.as_graph_def()
500 | tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
501 | return f, None
502 |
503 |
504 | @try_export
505 | def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
506 | # YOLOv5 TensorFlow Lite export
507 | import tensorflow as tf
508 |
509 | LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
510 | batch_size, ch, *imgsz = list(im.shape) # BCHW
511 | f = str(file).replace('.pt', '-fp16.tflite')
512 |
513 | converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
514 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
515 | converter.target_spec.supported_types = [tf.float16]
516 | converter.optimizations = [tf.lite.Optimize.DEFAULT]
517 | if int8:
518 | from models.tf import representative_dataset_gen
519 | dataset = LoadImages(check_dataset(check_yaml(data))['train'], img_size=imgsz, auto=False)
520 | converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
521 | converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
522 | converter.target_spec.supported_types = []
523 | converter.inference_input_type = tf.uint8 # or tf.int8
524 | converter.inference_output_type = tf.uint8 # or tf.int8
525 | converter.experimental_new_quantizer = True
526 | f = str(file).replace('.pt', '-int8.tflite')
527 | if nms or agnostic_nms:
528 | converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
529 |
530 | tflite_model = converter.convert()
531 | open(f, "wb").write(tflite_model)
532 | return f, None
533 |
534 |
535 | @try_export
536 | def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
537 | # YOLO Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
538 | cmd = 'edgetpu_compiler --version'
539 | help_url = 'https://coral.ai/docs/edgetpu/compiler/'
540 | assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
541 | if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
542 | LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
543 | sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
544 | for c in (
545 | 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
546 | 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
547 | 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
548 | subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
549 | ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
550 |
551 | LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
552 | f = str(file).replace('.pt', '-int8_edgetpu.tflite') # Edge TPU model
553 | f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model
554 |
555 | cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}"
556 | subprocess.run(cmd.split(), check=True)
557 | return f, None
558 |
559 |
560 | @try_export
561 | def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
562 | # YOLO TensorFlow.js export
563 | check_requirements('tensorflowjs')
564 | import tensorflowjs as tfjs
565 |
566 | LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
567 | f = str(file).replace('.pt', '_web_model') # js dir
568 | f_pb = file.with_suffix('.pb') # *.pb path
569 | f_json = f'{f}/model.json' # *.json path
570 |
571 | cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
572 | f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
573 | subprocess.run(cmd.split())
574 |
575 | json = Path(f_json).read_text()
576 | with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
577 | subst = re.sub(
578 | r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
579 | r'"Identity.?.?": {"name": "Identity.?.?"}, '
580 | r'"Identity.?.?": {"name": "Identity.?.?"}, '
581 | r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
582 | r'"Identity_1": {"name": "Identity_1"}, '
583 | r'"Identity_2": {"name": "Identity_2"}, '
584 | r'"Identity_3": {"name": "Identity_3"}}}', json)
585 | j.write(subst)
586 | return f, None
587 |
588 |
589 | def add_tflite_metadata(file, metadata, num_outputs):
590 | # Add metadata to *.tflite models per https://www.tensorflow.org/lite/models/convert/metadata
591 | with contextlib.suppress(ImportError):
592 | # check_requirements('tflite_support')
593 | from tflite_support import flatbuffers
594 | from tflite_support import metadata as _metadata
595 | from tflite_support import metadata_schema_py_generated as _metadata_fb
596 |
597 | tmp_file = Path('/tmp/meta.txt')
598 | with open(tmp_file, 'w') as meta_f:
599 | meta_f.write(str(metadata))
600 |
601 | model_meta = _metadata_fb.ModelMetadataT()
602 | label_file = _metadata_fb.AssociatedFileT()
603 | label_file.name = tmp_file.name
604 | model_meta.associatedFiles = [label_file]
605 |
606 | subgraph = _metadata_fb.SubGraphMetadataT()
607 | subgraph.inputTensorMetadata = [_metadata_fb.TensorMetadataT()]
608 | subgraph.outputTensorMetadata = [_metadata_fb.TensorMetadataT()] * num_outputs
609 | model_meta.subgraphMetadata = [subgraph]
610 |
611 | b = flatbuffers.Builder(0)
612 | b.Finish(model_meta.Pack(b), _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
613 | metadata_buf = b.Output()
614 |
615 | populator = _metadata.MetadataPopulator.with_model_file(file)
616 | populator.load_metadata_buffer(metadata_buf)
617 | populator.load_associated_files([str(tmp_file)])
618 | populator.populate()
619 | tmp_file.unlink()
620 |
621 |
622 | @smart_inference_mode()
623 | def run(
624 | data=ROOT / 'data/coco.yaml', # 'dataset.yaml path'
625 | weights=ROOT / 'yolo.pt', # weights path
626 | imgsz=(640, 640), # image (height, width)
627 | batch_size=1, # batch size
628 | device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu
629 | include=('torchscript', 'onnx'), # include formats
630 | class_agnostic=False,
631 | half=False, # FP16 half-precision export
632 | inplace=False, # set YOLO Detect() inplace=True
633 | keras=False, # use Keras
634 | optimize=False, # TorchScript: optimize for mobile
635 | int8=False, # CoreML/TF INT8 quantization
636 | dynamic=False, # ONNX/TF/TensorRT: dynamic axes
637 | simplify=False, # ONNX: simplify model
638 | opset=12, # ONNX: opset version
639 | verbose=False, # TensorRT: verbose log
640 | workspace=4, # TensorRT: workspace size (GB)
641 | nms=False, # TF: add NMS to model
642 | agnostic_nms=False, # TF: add agnostic NMS to model
643 | topk_per_class=100, # TF.js NMS: topk per class to keep
644 | topk_all=100, # TF.js NMS: topk for all classes to keep
645 | iou_thres=0.45, # TF.js NMS: IoU threshold
646 | conf_thres=0.25, # TF.js NMS: confidence threshold
647 | mask_resolution=56,
648 | pooler_scale=0.25,
649 | sampling_ratio=0,
650 | ):
651 | t = time.time()
652 | include = [x.lower() for x in include] # to lowercase
653 | fmts = tuple(export_formats()['Argument'][1:]) # --include arguments
654 | flags = [x in include for x in fmts]
655 | assert sum(flags) == len(include), f'ERROR: Invalid --include {include}, valid --include arguments are {fmts}'
656 | jit, onnx, onnx_end2end, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle = flags # export booleans
657 | file = Path(url2file(weights) if str(weights).startswith(('http:/', 'https:/')) else weights) # PyTorch weights
658 |
659 | # Load PyTorch model
660 | device = select_device(device)
661 | if half:
662 | assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'
663 | assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
664 | model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model
665 |
666 | # Checks
667 | imgsz *= 2 if len(imgsz) == 1 else 1 # expand
668 | if optimize:
669 | assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
670 |
671 | # Input
672 | gs = int(max(model.stride)) # grid size (max stride)
673 | imgsz = [check_img_size(x, gs) for x in imgsz] # verify img_size are gs-multiples
674 | im = torch.zeros(batch_size, 3, *imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
675 |
676 | # Update model
677 | model.eval()
678 | for k, m in model.named_modules():
679 | if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect)):
680 | m.inplace = inplace
681 | m.dynamic = dynamic
682 | m.export = True
683 |
684 | for _ in range(2):
685 | y = model(im) # dry runs
686 | if half and not coreml:
687 | im, model = im.half(), model.half() # to FP16
688 | shape = tuple((y[0] if isinstance(y, (tuple, list)) else y).shape) # model output shape
689 | metadata = {'stride': int(max(model.stride)), 'names': model.names} # model metadata
690 | LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
691 |
692 | # Exports
693 | f = [''] * len(fmts) # exported filenames
694 | warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
695 | if jit: # TorchScript
696 | f[0], _ = export_torchscript(model, im, file, optimize)
697 | if engine: # TensorRT required before ONNX
698 | f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
699 | if onnx or xml: # OpenVINO requires ONNX
700 | f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
701 | if onnx_end2end:
702 | labels = model.names
703 | f[2], _ = export_onnx_end2end(model, im, file, class_agnostic, simplify, topk_all, iou_thres, conf_thres, device, len(labels), mask_resolution, pooler_scale, sampling_ratio )
704 | if xml: # OpenVINO
705 | f[3], _ = export_openvino(file, metadata, half)
706 | if coreml: # CoreML
707 | f[4], _ = export_coreml(model, im, file, int8, half)
708 | if any((saved_model, pb, tflite, edgetpu, tfjs)): # TensorFlow formats
709 | assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
710 | assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
711 | f[5], s_model = export_saved_model(model.cpu(),
712 | im,
713 | file,
714 | dynamic,
715 | tf_nms=nms or agnostic_nms or tfjs,
716 | agnostic_nms=agnostic_nms or tfjs,
717 | topk_per_class=topk_per_class,
718 | topk_all=topk_all,
719 | iou_thres=iou_thres,
720 | conf_thres=conf_thres,
721 | keras=keras)
722 | if pb or tfjs: # pb prerequisite to tfjs
723 | f[6], _ = export_pb(s_model, file)
724 | if tflite or edgetpu:
725 | f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
726 | if edgetpu:
727 | f[8], _ = export_edgetpu(file)
728 | add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
729 | if tfjs:
730 | f[9], _ = export_tfjs(file)
731 | if paddle: # PaddlePaddle
732 | f[10], _ = export_paddle(model, im, file, metadata)
733 |
734 | # Finish
735 | f = [str(x) for x in f if x] # filter out '' and None
736 | if any(f):
737 | cls, det, seg = (isinstance(model, x) for x in (ClassificationModel, DetectionModel, SegmentationModel)) # type
738 | dir = Path('segment' if seg else 'classify' if cls else '')
739 | h = '--half' if half else '' # --half FP16 inference arg
740 | s = "# WARNING ⚠️ ClassificationModel not yet supported for PyTorch Hub AutoShape inference" if cls else \
741 | "# WARNING ⚠️ SegmentationModel not yet supported for PyTorch Hub AutoShape inference" if seg else ''
742 | if onnx_end2end:
743 | LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
744 | f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
745 | f"\nVisualize: https://netron.app")
746 | else:
747 | LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
748 | f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
749 | f"\nDetect: python {dir / ('detect.py' if det else 'predict.py')} --weights {f[-1]} {h}"
750 | f"\nValidate: python {dir / 'val.py'} --weights {f[-1]} {h}"
751 | f"\nPyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'custom', '{f[-1]}') {s}"
752 | f"\nVisualize: https://netron.app")
753 | return f # return list of exported files/dirs
754 |
755 |
756 | def parse_opt():
757 | parser = argparse.ArgumentParser()
758 | parser.add_argument('--data', type=str, default=ROOT / 'data/coco.yaml', help='dataset.yaml path')
759 | parser.add_argument('--weights', nargs='+', type=str, default=ROOT / 'yolo.pt', help='model.pt path(s)')
760 | parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640, 640], help='image (h, w)')
761 | parser.add_argument('--batch-size', type=int, default=1, help='batch size')
762 | parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
763 | parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
764 | parser.add_argument('--inplace', action='store_true', help='set YOLO Detect() inplace=True')
765 | parser.add_argument('--keras', action='store_true', help='TF: use Keras')
766 | parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
767 | parser.add_argument('--int8', action='store_true', help='CoreML/TF INT8 quantization')
768 | parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
769 | parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
770 | parser.add_argument('--opset', type=int, default=12, help='ONNX: opset version')
771 | parser.add_argument('--verbose', action='store_true', help='TensorRT: verbose log')
772 | parser.add_argument('--workspace', type=int, default=4, help='TensorRT: workspace size (GB)')
773 | parser.add_argument('--nms', action='store_true', help='TF: add NMS to model')
774 | parser.add_argument('--agnostic-nms', action='store_true', help='TF: add agnostic NMS to model')
775 | parser.add_argument('--topk-per-class', type=int, default=100, help='TF.js NMS: topk per class to keep')
776 | parser.add_argument('--topk-all', type=int, default=100, help='ONNX END2END/TF.js NMS: topk for all classes to keep')
777 | parser.add_argument('--iou-thres', type=float, default=0.45, help='ONNX END2END/TF.js NMS: IoU threshold')
778 | parser.add_argument('--conf-thres', type=float, default=0.25, help='ONNX END2END/TF.js NMS: confidence threshold')
779 | parser.add_argument('--class-agnostic', action='store_true', help='Agnostic NMS (single class)')
780 |     parser.add_argument('--mask-resolution', type=int, default=160, help='ONNX END2END (segmentation): output resolution of the pooled masks')
781 |     parser.add_argument('--pooler-scale', type=float, default=0.25, help='ONNX END2END (segmentation): spatial scale used to map ROI coordinates onto the mask prototype feature map')
782 | parser.add_argument('--sampling-ratio', type=int, default=0, help='Number of sampling points in the interpolation. Allowed values are non-negative integers.')
783 | parser.add_argument(
784 | '--include',
785 | nargs='+',
786 | default=['torchscript'],
787 | help='torchscript, onnx, onnx_end2end, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle')
788 | opt = parser.parse_args()
789 |
790 | if 'onnx_end2end' in opt.include:
791 | opt.simplify = True
792 | opt.dynamic = True
793 | opt.inplace = True
794 | opt.half = False
795 |
796 | print_args(vars(opt))
797 | return opt
798 |
799 |
800 | def main(opt):
801 | for opt.weights in (opt.weights if isinstance(opt.weights, list) else [opt.weights]):
802 | run(**vars(opt))
803 |
804 |
805 | if __name__ == "__main__":
806 | opt = parse_opt()
807 | main(opt)
808 |
--------------------------------------------------------------------------------
/install_dependencies.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Function to display usage message
4 | usage() {
5 | echo "Usage: $0 [--no-trex]" 1>&2
6 | exit 1
7 | }
8 |
9 | # Set default flags
10 | install_trex=true # TREx installation enabled by default
11 | install_base=true # Base dependencies always installed by default
12 |
13 | # Shared paths
14 | TENSORRT_REPO_PATH="/opt/nvidia/TensorRT"
15 | DOWNLOADS_PATH="/yolov9-qat/downloads"
16 |
17 | # Check command line options
18 | while [[ $# -gt 0 ]]; do
19 | case "$1" in
20 | --no-trex)
21 | install_trex=false
22 | ;;
23 | *)
24 | usage
25 | ;;
26 | esac
27 | shift
28 | done
29 |
30 | # Function to install system dependencies
31 | install_system_dependencies() {
32 | echo "Installing system dependencies..."
33 | apt-get update || return 1
34 | apt-get install -y zip htop screen libgl1-mesa-glx libfreetype6-dev || return 1
35 | return 0
36 | }
37 |
38 | # Function to upgrade TensorRT
39 | upgrade_tensorrt() {
40 | echo "Upgrading TensorRT..."
41 | local os="ubuntu2204"
42 | local trt_version="10.9.0"
43 | local cuda="cuda-12.8"
44 | local tensorrt_package="nv-tensorrt-local-repo-${os}-${trt_version}-${cuda}_1.0-1_amd64.deb"
45 | local download_path="${DOWNLOADS_PATH}/${tensorrt_package}"
46 |
47 | # Create downloads directory if it doesn't exist
48 | mkdir -p "$DOWNLOADS_PATH" || return 1
49 |
50 | # Check if the package already exists
51 | if [ ! -f "$download_path" ]; then
52 | echo "Downloading TensorRT package..."
53 | wget "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${trt_version}/local_repo/${tensorrt_package}" -O "$download_path" || return 1
54 | else
55 | echo "TensorRT package already exists at $download_path. Reusing existing file."
56 | fi
57 |
58 | # Install the package
59 | dpkg -i "$download_path" || return 1
60 | cp /var/nv-tensorrt-local-repo-${os}-${trt_version}-${cuda}/*keyring.gpg /usr/share/keyrings/ || return 1
61 | apt-get update || return 1
62 | apt-get install -y tensorrt || return 1
63 | apt-get purge "nv-tensorrt-local-repo*" -y || return 1
64 |
65 | # Keep the downloaded file for potential reuse
66 | echo "TensorRT package kept at $download_path for future use"
67 | return 0
68 | }
69 |
70 | # Function to install Python packages
71 | install_python_packages() {
72 | echo "Installing Python packages..."
73 | pip install --upgrade pip || return 1
74 | pip install --upgrade tensorrt==10.9.0.34 || return 1
75 |
76 | pip install seaborn \
77 | thop \
78 | "markdown-it-py>=2.2.0" \
79 | "onnx-simplifier>=0.4.35" \
80 | "onnxsim>=0.4.35" \
81 | "onnxruntime>=1.16.3" \
82 | "ujson>=5.9.0" \
83 | "pycocotools>=2.0.7" \
84 | "pycuda>=2025.1" || return 1
85 |
86 | pip install --upgrade onnx_graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com || return 1
87 | pip install pillow==9.5.0 --no-cache-dir --force-reinstall || return 1
88 | return 0
89 | }
90 |
91 | # Function to clone TensorRT repository once
92 | clone_tensorrt_repo() {
93 | echo "Cloning NVIDIA TensorRT repository..."
94 |
95 | if [ ! -d "$TENSORRT_REPO_PATH" ]; then
96 | # Create directory and clone repository
97 | mkdir -p "$(dirname "$TENSORRT_REPO_PATH")" || return 1
98 | git clone https://github.com/NVIDIA/TensorRT.git "$TENSORRT_REPO_PATH" || return 1
99 | cd "$TENSORRT_REPO_PATH" || return 1
100 | git checkout release/10.9 || return 1
101 | echo "TensorRT repository cloned successfully to $TENSORRT_REPO_PATH"
102 | else
103 | echo "TensorRT repository already exists at $TENSORRT_REPO_PATH"
104 | fi
105 |
106 | return 0
107 | }
108 |
109 | # Function to install PyTorch Quantization
110 | install_pytorch_quantization() {
111 | echo "Installing PyTorch Quantization..."
112 |
113 | # Navigate to PyTorch Quantization directory in TensorRT repo
114 | cd "$TENSORRT_REPO_PATH/tools/pytorch-quantization" || return 1
115 |
116 | # Install requirements and setup
117 | pip install -r requirements.txt || return 1
118 | python setup.py install || return 1
119 |
120 | echo "PyTorch Quantization installed successfully"
121 | return 0
122 | }
123 |
124 | # Function to install TREx
125 | install_trex_environment() {
126 | echo "Installing NVIDIA TREx environment..."
127 | # Check if TREx is not already installed
128 | if [ ! -d "/opt/nvidia_trex/env_trex" ]; then
129 | apt-get install -y graphviz || return 1
130 | pip install virtualenv "widgetsnbextension>=4.0.9" || return 1
131 |
132 | mkdir -p /opt/nvidia_trex || return 1
133 | cd /opt/nvidia_trex/ || return 1
134 | python3 -m virtualenv env_trex || return 1
135 | source env_trex/bin/activate || return 1
136 | pip install "Werkzeug>=2.2.2" "graphviz>=0.20.1" || return 1
137 |
138 | # Navigate to TREx directory in TensorRT repo
139 | cd "$TENSORRT_REPO_PATH/tools/experimental/trt-engine-explorer" || return 1
140 |
141 | source /opt/nvidia_trex/env_trex/bin/activate || return 1
142 | pip install -e . || return 1
143 | pip install jupyter_nbextensions_configurator notebook==6.4.12 ipywidgets || return 1
144 | jupyter nbextension enable widgetsnbextension --user --py || return 1
145 | deactivate || return 1
146 | else
147 | echo "NVIDIA TREx virtual environment already exists. Skipping installation."
148 | fi
149 | return 0
150 | }
151 |
152 | # Function to cleanup
153 | cleanup() {
154 | echo "Cleaning up..."
155 | apt-get clean
156 | rm -rf /var/lib/apt/lists/*
157 | }
158 |
159 | # Main installation process
160 | main() {
161 | # Install base dependencies (always)
162 | if $install_base; then
163 | install_system_dependencies || { echo "Failed to install system dependencies"; exit 1; }
164 | upgrade_tensorrt || { echo "Failed to upgrade TensorRT"; exit 1; }
165 | install_python_packages || { echo "Failed to install Python packages"; exit 1; }
166 | fi
167 |
168 | # Clone TensorRT repository once
169 | clone_tensorrt_repo || { echo "Failed to clone TensorRT repository"; exit 1; }
170 |
171 | # Install TREx by default unless --no-trex flag is provided
172 | if $install_trex; then
173 | install_trex_environment || { echo "Failed to install TREx environment"; exit 1; }
174 | fi
175 |
176 | # Always install PyTorch Quantization
177 | install_pytorch_quantization || { echo "Failed to install PyTorch Quantization"; exit 1; }
178 |
179 | # Final cleanup
180 | cleanup
181 |
182 | echo "Installation completed successfully."
183 | return 0
184 | }
185 |
186 | # Execute main function
187 | main
188 |
189 |
190 |
--------------------------------------------------------------------------------
/models/experimental_trt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class TRT_EfficientNMS_85(torch.autograd.Function):
5 | '''TensorRT NMS operation'''
6 | @staticmethod
7 | def forward(
8 | ctx,
9 | boxes,
10 | scores,
11 | background_class=-1,
12 | box_coding=1,
13 | iou_threshold=0.45,
14 | max_output_boxes=100,
15 | plugin_version="1",
16 | score_activation=0,
17 | score_threshold=0.25,
18 | ):
19 |
20 | batch_size, num_boxes, num_classes = scores.shape
21 | num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
22 | det_boxes = torch.randn(batch_size, max_output_boxes, 4)
23 | det_scores = torch.randn(batch_size, max_output_boxes)
24 | det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
25 | return num_det, det_boxes, det_scores, det_classes
26 |
27 | @staticmethod
28 | def symbolic(g,
29 | boxes,
30 | scores,
31 | background_class=-1,
32 | box_coding=1,
33 | iou_threshold=0.45,
34 | max_output_boxes=100,
35 | plugin_version="1",
36 | score_activation=0,
37 | score_threshold=0.25):
38 | out = g.op("TRT::EfficientNMS_TRT",
39 | boxes,
40 | scores,
41 | background_class_i=background_class,
42 | box_coding_i=box_coding,
43 | iou_threshold_f=iou_threshold,
44 | max_output_boxes_i=max_output_boxes,
45 | plugin_version_s=plugin_version,
46 | score_activation_i=score_activation,
47 | score_threshold_f=score_threshold,
48 | outputs=4)
49 | nums, boxes, scores, classes = out
50 | return nums, boxes, scores, classes
51 |
52 | class TRT_EfficientNMS(torch.autograd.Function):
53 | '''TensorRT NMS operation'''
54 | @staticmethod
55 | def forward(
56 | ctx,
57 | boxes,
58 | scores,
59 | background_class=-1,
60 | box_coding=1,
61 | iou_threshold=0.45,
62 | max_output_boxes=100,
63 | plugin_version="1",
64 | score_activation=0,
65 | score_threshold=0.25,
66 | class_agnostic=0,
67 | ):
68 |
69 | batch_size, num_boxes, num_classes = scores.shape
70 | num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
71 | det_boxes = torch.randn(batch_size, max_output_boxes, 4)
72 | det_scores = torch.randn(batch_size, max_output_boxes)
73 | det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
74 | return num_det, det_boxes, det_scores, det_classes
75 |
76 | @staticmethod
77 | def symbolic(g,
78 | boxes,
79 | scores,
80 | background_class=-1,
81 | box_coding=1,
82 | iou_threshold=0.45,
83 | max_output_boxes=100,
84 | plugin_version="1",
85 | score_activation=0,
86 | score_threshold=0.25,
87 | class_agnostic=0):
88 | out = g.op("TRT::EfficientNMS_TRT",
89 | boxes,
90 | scores,
91 | background_class_i=background_class,
92 | box_coding_i=box_coding,
93 | iou_threshold_f=iou_threshold,
94 | max_output_boxes_i=max_output_boxes,
95 | plugin_version_s=plugin_version,
96 | score_activation_i=score_activation,
97 | class_agnostic_i=class_agnostic,
98 | score_threshold_f=score_threshold,
99 | outputs=4)
100 | nums, boxes, scores, classes = out
101 | return nums, boxes, scores, classes
102 |
103 | class TRT_EfficientNMSX_85(torch.autograd.Function):
104 | '''TensorRT NMS operation'''
105 | @staticmethod
106 | def forward(
107 | ctx,
108 | boxes,
109 | scores,
110 | background_class=-1,
111 | box_coding=1,
112 | iou_threshold=0.45,
113 | max_output_boxes=100,
114 | plugin_version="1",
115 | score_activation=0,
116 | score_threshold=0.25
117 | ):
118 |
119 | batch_size, num_boxes, num_classes = scores.shape
120 | num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
121 | det_boxes = torch.randn(batch_size, max_output_boxes, 4)
122 | det_scores = torch.randn(batch_size, max_output_boxes)
123 | det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
124 | det_indices = torch.randint(0,num_boxes,(batch_size, max_output_boxes), dtype=torch.int32)
125 | return num_det, det_boxes, det_scores, det_classes, det_indices
126 |
127 | @staticmethod
128 | def symbolic(g,
129 | boxes,
130 | scores,
131 | background_class=-1,
132 | box_coding=1,
133 | iou_threshold=0.45,
134 | max_output_boxes=100,
135 | plugin_version="1",
136 | score_activation=0,
137 | score_threshold=0.25):
138 | out = g.op("TRT::EfficientNMSX_TRT",
139 | boxes,
140 | scores,
141 | background_class_i=background_class,
142 | box_coding_i=box_coding,
143 | iou_threshold_f=iou_threshold,
144 | max_output_boxes_i=max_output_boxes,
145 | plugin_version_s=plugin_version,
146 | score_activation_i=score_activation,
147 | score_threshold_f=score_threshold,
148 | outputs=5)
149 | nums, boxes, scores, classes, det_indices = out
150 | return nums, boxes, scores, classes, det_indices
151 |
152 | class TRT_EfficientNMSX(torch.autograd.Function):
153 | '''TensorRT NMS operation'''
154 | @staticmethod
155 | def forward(
156 | ctx,
157 | boxes,
158 | scores,
159 | background_class=-1,
160 | box_coding=1,
161 | iou_threshold=0.45,
162 | max_output_boxes=100,
163 | plugin_version="1",
164 | score_activation=0,
165 | score_threshold=0.25,
166 | class_agnostic=0,
167 | ):
168 |
169 | batch_size, num_boxes, num_classes = scores.shape
170 | num_det = torch.randint(0, max_output_boxes, (batch_size, 1), dtype=torch.int32)
171 | det_boxes = torch.randn(batch_size, max_output_boxes, 4)
172 | det_scores = torch.randn(batch_size, max_output_boxes)
173 | det_classes = torch.randint(0, num_classes, (batch_size, max_output_boxes), dtype=torch.int32)
174 | det_indices = torch.randint(0,num_boxes,(batch_size, max_output_boxes), dtype=torch.int32)
175 | return num_det, det_boxes, det_scores, det_classes, det_indices
176 |
177 | @staticmethod
178 | def symbolic(g,
179 | boxes,
180 | scores,
181 | background_class=-1,
182 | box_coding=1,
183 | iou_threshold=0.45,
184 | max_output_boxes=100,
185 | plugin_version="1",
186 | score_activation=0,
187 | score_threshold=0.25,
188 | class_agnostic=0):
189 | out = g.op("TRT::EfficientNMSX_TRT",
190 | boxes,
191 | scores,
192 | background_class_i=background_class,
193 | box_coding_i=box_coding,
194 | iou_threshold_f=iou_threshold,
195 | max_output_boxes_i=max_output_boxes,
196 | plugin_version_s=plugin_version,
197 | score_activation_i=score_activation,
198 | class_agnostic_i=class_agnostic,
199 | score_threshold_f=score_threshold,
200 | outputs=5)
201 | nums, boxes, scores, classes, det_indices = out
202 | return nums, boxes, scores, classes, det_indices
203 |
204 | class TRT_ROIAlign(torch.autograd.Function):
205 | @staticmethod
206 | def forward(
207 | ctx,
208 | X,
209 | rois,
210 | batch_indices,
211 | coordinate_transformation_mode= 1,
212 | mode=1, # 1- avg pooling / 0 - max pooling
213 | output_height=160,
214 | output_width=160,
215 | sampling_ratio=0,
216 | spatial_scale=0.25,
217 | ):
218 | device = rois.device
219 | dtype = rois.dtype
220 | N, C, H, W = X.shape
221 | num_rois = rois.shape[0]
222 | return torch.randn((num_rois, C, output_height, output_width), device=device, dtype=dtype)
223 |
224 | @staticmethod
225 | def symbolic(
226 | g,
227 | X,
228 | rois,
229 | batch_indices,
230 | coordinate_transformation_mode=1,
231 | mode=1,
232 | output_height=160,
233 | output_width=160,
234 | sampling_ratio=0,
235 | spatial_scale=0.25,
236 | ):
237 | return g.op(
238 | "TRT::ROIAlign_TRT",
239 | X,
240 | rois,
241 | batch_indices,
242 | coordinate_transformation_mode_i=coordinate_transformation_mode,
243 | mode_i=mode,
244 | output_height_i=output_height,
245 | output_width_i=output_width,
246 | sampling_ratio_i=sampling_ratio,
247 | spatial_scale_f=spatial_scale,
248 | )
249 |
250 | class ONNX_EfficientNMS_TRT(nn.Module):
251 | '''onnx module with TensorRT NMS operation.'''
252 | def __init__(self, class_agnostic=False, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None ,device=None, n_classes=80):
253 | super().__init__()
254 | assert max_wh is None
255 | self.device = device if device else torch.device('cpu')
256 | self.class_agnostic = 1 if class_agnostic else 0
257 |         self.background_class = -1
258 |         self.box_coding = 1
259 | self.iou_threshold = iou_thres
260 | self.max_obj = max_obj
261 | self.plugin_version = '1'
262 | self.score_activation = 0
263 | self.score_threshold = score_thres
264 | self.n_classes=n_classes
265 |
266 |
267 | def forward(self, x):
268 | if isinstance(x, list):
269 | x = x[1]
270 | x = x.permute(0, 2, 1)
271 | bboxes_x = x[..., 0:1]
272 | bboxes_y = x[..., 1:2]
273 | bboxes_w = x[..., 2:3]
274 | bboxes_h = x[..., 3:4]
275 | bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1)
276 | bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
277 | obj_conf = x[..., 4:]
278 | scores = obj_conf
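        # The class_agnostic attribute is only understood by newer EfficientNMS_TRT
        # plugin builds (tagged as TensorRT 8.6+ in the export metadata), so fall back
        # to the 8.5-compatible symbolic when per-class NMS is requested.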
279 | if self.class_agnostic == 1:
280 | num_det, det_boxes, det_scores, det_classes = TRT_EfficientNMS.apply(bboxes, scores, self.background_class, self.box_coding,
281 | self.iou_threshold, self.max_obj,
282 | self.plugin_version, self.score_activation,
283 | self.score_threshold, self.class_agnostic)
284 | else:
285 | num_det, det_boxes, det_scores, det_classes = TRT_EfficientNMS_85.apply(bboxes, scores, self.background_class, self.box_coding,
286 | self.iou_threshold, self.max_obj,
287 | self.plugin_version, self.score_activation,
288 | self.score_threshold)
289 | return num_det, det_boxes, det_scores, det_classes
290 |
291 | class ONNX_EfficientNMSX_TRT(nn.Module):
292 | '''onnx module with TensorRT NMS operation.'''
293 |     def __init__(self, class_agnostic=False, max_obj=100, iou_thres=0.45, score_thres=0.25, max_wh=None, device=None, n_classes=80):
294 | super().__init__()
295 | assert max_wh is None
296 | self.device = device if device else torch.device('cpu')
297 | self.class_agnostic = 1 if class_agnostic else 0
298 |         self.background_class = -1
299 |         self.box_coding = 1
300 | self.iou_threshold = iou_thres
301 | self.max_obj = max_obj
302 | self.plugin_version = '1'
303 | self.score_activation = 0
304 | self.score_threshold = score_thres
305 | self.n_classes=n_classes
306 |
307 |
308 | def forward(self, x):
309 | if isinstance(x, list):
310 | x = x[1]
311 | x = x.permute(0, 2, 1)
312 | bboxes_x = x[..., 0:1]
313 | bboxes_y = x[..., 1:2]
314 | bboxes_w = x[..., 2:3]
315 | bboxes_h = x[..., 3:4]
316 | bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1)
317 | bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
318 | obj_conf = x[..., 4:]
319 | scores = obj_conf
320 | if self.class_agnostic == 1:
321 | num_det, det_boxes, det_scores, det_classes, det_indices = TRT_EfficientNMSX.apply(bboxes, scores, self.background_class, self.box_coding,
322 | self.iou_threshold, self.max_obj,
323 | self.plugin_version, self.score_activation,
324 | self.score_threshold, self.class_agnostic)
325 | else:
326 | num_det, det_boxes, det_scores, det_classes, det_indices = TRT_EfficientNMSX_85.apply(bboxes, scores, self.background_class, self.box_coding,
327 | self.iou_threshold, self.max_obj,
328 | self.plugin_version, self.score_activation,
329 | self.score_threshold)
330 | return num_det, det_boxes, det_scores, det_classes, det_indices
331 |
332 |
333 |
334 | class End2End_TRT(nn.Module):
335 | '''export onnx or tensorrt model with NMS operation.'''
336 | def __init__(self, model, class_agnostic=False, max_obj=100, iou_thres=0.45, score_thres=0.25, mask_resolution=56, pooler_scale=0.25, sampling_ratio=0, max_wh=None, device=None, n_classes=80, is_det_model=True):
337 | super().__init__()
338 | device = device if device else torch.device('cpu')
339 | assert isinstance(max_wh,(int)) or max_wh is None
340 | self.model = model.to(device)
341 | self.model.model[-1].end2end = True
342 | if is_det_model:
343 | self.patch_model = ONNX_EfficientNMS_TRT
344 | self.end2end = self.patch_model(class_agnostic, max_obj, iou_thres, score_thres, max_wh, device, n_classes)
345 | else:
346 | self.patch_model = ONNX_End2End_MASK_TRT
347 | self.end2end = self.patch_model(class_agnostic, max_obj, iou_thres, score_thres, mask_resolution, pooler_scale, sampling_ratio, max_wh, device, n_classes)
348 | self.end2end.eval()
349 |
350 | def forward(self, x):
351 | x = self.model(x)
352 | x = self.end2end(x)
353 | return x
354 |
355 |
356 | class ONNX_End2End_MASK_TRT(nn.Module):
357 | """onnx module with ONNX-TensorRT NMS/ROIAlign operation."""
358 | def __init__(
359 | self,
360 | class_agnostic=False,
361 | max_obj=100,
362 | iou_thres=0.45,
363 | score_thres=0.25,
364 | mask_resolution=160,
365 | pooler_scale=0.25,
366 | sampling_ratio=0,
367 | max_wh=None,
368 | device=None,
369 | n_classes=80
370 | ):
371 | super().__init__()
372 | assert isinstance(max_wh,(int)) or max_wh is None
373 | self.device = device if device else torch.device('cpu')
374 | self.class_agnostic = 1 if class_agnostic else 0
375 | self.max_obj = max_obj
376 |         self.background_class = -1
377 |         self.box_coding = 1
378 | self.iou_threshold = iou_thres
379 | self.max_obj = max_obj
380 | self.plugin_version = '1'
381 | self.score_activation = 0
382 | self.score_threshold = score_thres
383 | self.n_classes=n_classes
384 | self.mask_resolution = mask_resolution
385 | self.pooler_scale = pooler_scale
386 | self.sampling_ratio = sampling_ratio
387 |
388 | def forward(self, x):
389 | if isinstance(x, list): ## remove auxiliary branch
390 | x = x[1]
391 | det=x[0]
392 | proto=x[1]
393 | det = det.permute(0, 2, 1)
394 |
395 | bboxes_x = det[..., 0:1]
396 | bboxes_y = det[..., 1:2]
397 | bboxes_w = det[..., 2:3]
398 | bboxes_h = det[..., 3:4]
399 | bboxes = torch.cat([bboxes_x, bboxes_y, bboxes_w, bboxes_h], dim = -1)
400 | bboxes = bboxes.unsqueeze(2) # [n_batch, n_bboxes, 4] -> [n_batch, n_bboxes, 1, 4]
401 | scores = det[..., 4: 4 + self.n_classes]
402 |
403 | batch_size, nm, proto_h, proto_w = proto.shape
404 | total_object = batch_size * self.max_obj
405 | masks = det[..., 4 + self.n_classes : 4 + self.n_classes + nm]
406 | if self.class_agnostic == 1:
407 | num_det, det_boxes, det_scores, det_classes, det_indices = TRT_EfficientNMSX.apply(bboxes, scores, self.background_class, self.box_coding,
408 | self.iou_threshold, self.max_obj,
409 | self.plugin_version, self.score_activation,
410 | self.score_threshold,self.class_agnostic)
411 | else:
412 | num_det, det_boxes, det_scores, det_classes, det_indices = TRT_EfficientNMSX_85.apply(bboxes, scores, self.background_class, self.box_coding,
413 | self.iou_threshold, self.max_obj,
414 | self.plugin_version, self.score_activation,
415 | self.score_threshold)
416 |
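        # Flatten the (batch, top-k) indices returned by NMS so the mask coefficients
        # of the kept boxes can be gathered with one advanced-indexing lookup.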
417 | batch_indices = torch.ones_like(det_indices) * torch.arange(batch_size, device=self.device, dtype=torch.int32).unsqueeze(1)
418 | batch_indices = batch_indices.view(total_object).to(torch.long)
419 | det_indices = det_indices.view(total_object).to(torch.long)
420 | det_masks = masks[batch_indices, det_indices]
421 |
422 |
423 | pooled_proto = TRT_ROIAlign.apply( proto,
424 | det_boxes.view(total_object, 4),
425 | batch_indices,
426 | 1,
427 | 1,
428 | self.mask_resolution,
429 | self.mask_resolution,
430 | self.sampling_ratio,
431 | self.pooler_scale
432 | )
433 | pooled_proto = pooled_proto.view(
434 | total_object, nm, self.mask_resolution * self.mask_resolution,
435 | )
436 |
437 | det_masks = (
438 | torch.matmul(det_masks.unsqueeze(dim=1), pooled_proto)
439 | .sigmoid()
440 | .view(batch_size, self.max_obj, self.mask_resolution * self.mask_resolution)
441 | )
442 |
443 | return num_det, det_boxes, det_scores, det_classes, det_masks
--------------------------------------------------------------------------------
/models/quantize.py:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: MIT
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a
6 | # copy of this software and associated documentation files (the "Software"),
7 | # to deal in the Software without restriction, including without limitation
8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 | # and/or sell copies of the Software, and to permit persons to whom the
10 | # Software is furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in
13 | # all copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
21 | # DEALINGS IN THE SOFTWARE.
22 | ################################################################################
23 |
24 |
25 | import os
26 | import re
27 | from typing import List, Callable, Union, Dict
28 | from tqdm import tqdm
29 | from copy import deepcopy
30 |
31 | # PyTorch
32 | import torch
33 | import torch.optim as optim
34 | from torch.cuda import amp
35 | import torch.nn.functional as F
36 |
37 | # Pytorch Quantization
38 | from pytorch_quantization import nn as quant_nn
39 | from pytorch_quantization.nn.modules import _utils as quant_nn_utils
40 | from pytorch_quantization import calib
41 | from pytorch_quantization.tensor_quant import QuantDescriptor
42 | from pytorch_quantization import quant_modules
43 | from absl import logging as quant_logging
44 |
45 | import onnx_graphsurgeon as gs
46 | from utils.general import (check_requirements, LOGGER,colorstr)
47 | from models.quantize_rules import find_quantizer_pairs
48 |
49 |
50 | class QuantAdd(torch.nn.Module, quant_nn_utils.QuantMixin):
51 | def __init__(self, quantization):
52 | super().__init__()
53 |
54 | if quantization:
55 | self._input0_quantizer = quant_nn.TensorQuantizer(QuantDescriptor())
56 | self._input1_quantizer = quant_nn.TensorQuantizer(QuantDescriptor())
57 | self.quantization = quantization
58 |
59 | def forward(self, x, y):
60 | if self.quantization:
61 | return self._input0_quantizer(x) + self._input1_quantizer(y)
62 | return x + y
63 |
64 |
65 | class QuantADownAvgChunk(torch.nn.Module):
66 | def __init__(self):
67 | super().__init__()
68 | self._chunk_quantizer = quant_nn.TensorQuantizer(QuantDescriptor())
69 | self._chunk_quantizer._calibrator._torch_hist = True
70 | self.avg_pool2d = torch.nn.AvgPool2d(2, 1, 0, False, True)
71 |
72 | def forward(self, x):
73 | x = self.avg_pool2d(x)
74 | x = self._chunk_quantizer(x)
75 | return x.chunk(2, 1)
76 |
77 | class QuantAConvAvgChunk(torch.nn.Module):
78 | def __init__(self):
79 | super().__init__()
80 | self._chunk_quantizer = quant_nn.TensorQuantizer(QuantDescriptor(num_bits=8, calib_method="histogram"))
81 | self._chunk_quantizer._calibrator._torch_hist = True
82 | self.avg_pool2d = torch.nn.AvgPool2d(2, 1, 0, False, True)
83 |
84 | def forward(self, x):
85 | x = self.avg_pool2d(x)
86 | x = self._chunk_quantizer(x)
87 | return x
88 |
89 | class QuantRepNCSPELAN4Chunk(torch.nn.Module):
90 | def __init__(self, c):
91 | super().__init__()
92 | self._input0_quantizer = quant_nn.TensorQuantizer(QuantDescriptor())
93 | self.c = c
94 | def forward(self, x, chunks, dims):
95 | return torch.split(self._input0_quantizer(x), (self.c, self.c), dims)
96 |
97 | class QuantUpsample(torch.nn.Module):
98 | def __init__(self, size, scale_factor, mode):
99 | super().__init__()
100 | self.size = size
101 | self.scale_factor = scale_factor
102 | self.mode = mode
103 | self._input_quantizer = quant_nn.TensorQuantizer(QuantDescriptor())
104 |
105 | def forward(self, x):
106 | return F.interpolate(self._input_quantizer(x), self.size, self.scale_factor, self.mode)
107 |
108 |
109 | class QuantConcat(torch.nn.Module):
110 | def __init__(self, dim):
111 | super().__init__()
112 | self._input0_quantizer = quant_nn.TensorQuantizer(QuantDescriptor())
113 | self._input1_quantizer = quant_nn.TensorQuantizer(QuantDescriptor())
114 | self.dim = dim
115 |
116 | def forward(self, x, dim):
117 | x_0 = self._input0_quantizer(x[0])
118 | x_1 = self._input1_quantizer(x[1])
119 | return torch.cat((x_0, x_1), self.dim)
120 |
121 |
122 | class disable_quantization:
123 | def __init__(self, model):
124 | self.model = model
125 |
126 | def apply(self, disabled=True):
127 | for name, module in self.model.named_modules():
128 | if isinstance(module, quant_nn.TensorQuantizer):
129 | module._disabled = disabled
130 |
131 | def __enter__(self):
132 | self.apply(True)
133 |
134 | def __exit__(self, *args, **kwargs):
135 | self.apply(False)
136 |
137 |
138 | class enable_quantization:
139 | def __init__(self, model):
140 | self.model = model
141 |
142 | def apply(self, enabled=True):
143 | for name, module in self.model.named_modules():
144 | if isinstance(module, quant_nn.TensorQuantizer):
145 | module._disabled = not enabled
146 |
147 | def __enter__(self):
148 | self.apply(True)
149 | return self
150 |
151 | def __exit__(self, *args, **kwargs):
152 | self.apply(False)
153 |
154 |
155 | def have_quantizer(module):
156 |     for name, submodule in module.named_modules():
157 |         if isinstance(submodule, quant_nn.TensorQuantizer):
158 |             return True
159 |     return False
160 |
161 | # Initialize PyTorch Quantization
162 | def initialize():
163 | quant_modules.initialize( )
164 | quant_desc_input = QuantDescriptor(calib_method="histogram")
165 | quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)
166 | quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)
167 | quant_nn.QuantAvgPool2d.set_default_quant_desc_input(quant_desc_input)
168 | quant_nn.QuantMaxPool2d.set_default_quant_desc_input(quant_desc_input)
169 |
170 | quant_logging.set_verbosity(quant_logging.ERROR)
171 |
172 |
173 | def remove_redundant_qdq_model(onnx_model, f):
174 | check_requirements('onnx')
175 | import onnx
176 |
177 | graph = gs.import_onnx(onnx_model)
178 | nodes = graph.nodes
179 |
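    # YOLO exports SiLU as Conv -> Sigmoid -> Mul. When a Mul output feeds more than
    # one consumer, each branch receives its own Q/DQ pair; the rewiring below routes
    # the extra consumer through the first pair so duplicate QuantizeLinear /
    # DequantizeLinear nodes are dropped from the graph.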
180 | mul_nodes = [node for node in nodes if node.op == "Mul" and node.i(0).op == "Conv" and node.i(1).op == "Sigmoid"]
181 | many_outputs_mul_nodes = []
182 |
183 | for node in mul_nodes:
184 | try:
185 | for i in range(99):
186 | node.o(i)
187 |         except Exception:
188 | if i > 1:
189 | mul_nodename_outnum = {"node": node, "out_num": i}
190 | many_outputs_mul_nodes.append(mul_nodename_outnum)
191 |
192 | for node_dict in many_outputs_mul_nodes:
193 | if node_dict["out_num"] == 2:
194 | if node_dict["node"].o(0).op == "QuantizeLinear" and node_dict["node"].o(1).op == "QuantizeLinear":
195 | if node_dict["node"].o(1).o(0).o(0).op == "Concat":
196 | concat_dq_out_name = node_dict["node"].o(1).o(0).outputs[0].name
197 | for i, concat_input in enumerate(node_dict["node"].o(1).o(0).o(0).inputs):
198 | if concat_input.name == concat_dq_out_name:
199 | node_dict["node"].o(1).o(0).o(0).inputs[i] = node_dict["node"].o(0).o(0).outputs[0]
200 | else:
201 | node_dict["node"].o(1).o(0).o(0).inputs[0] = node_dict["node"].o(0).o(0).outputs[0]
202 |
203 |
204 | # elif node_dict["node"].o(0).op == "QuantizeLinear" and node_dict["node"].o(1).op == "Concat":
205 | # concat_dq_out_name = node_dict["node"].outputs[0].outputs[0].inputs[0].name
206 | # for i, concat_input in enumerate(node_dict["node"].outputs[0].outputs[1].inputs):
207 | # if concat_input.name == concat_dq_out_name:
208 | # #print("elif", concat_input.name, concat_dq_out_name )
209 | # #print("will-be", node_dict["node"].outputs[0].outputs[1].inputs[i], node_dict["node"].outputs[0].outputs[0].o().outputs[0] )
210 | # node_dict["node"].outputs[0].outputs[1].inputs[i] = node_dict["node"].outputs[0].outputs[0].o().outputs[0]
211 |
212 |
213 | # add_nodes = [node for node in nodes if node.op == "Add"]
214 | # many_outputs_add_nodes = []
215 | # for node in add_nodes:
216 | # try:
217 | # for i in range(99):
218 | # node.o(i)
219 | # except:
220 | # if i > 1 and node.o().op == "QuantizeLinear":
221 | # add_nodename_outnum = {"node": node, "out_num": i}
222 | # many_outputs_add_nodes.append(add_nodename_outnum)
223 |
224 |
225 | # for node_dict in many_outputs_add_nodes:
226 | # if node_dict["node"].outputs[0].outputs[0].op == "QuantizeLinear" and node_dict["node"].outputs[0].outputs[1].op == "Concat":
227 | # concat_dq_out_name = node_dict["node"].outputs[0].outputs[0].inputs[0].name
228 | # for i, concat_input in enumerate(node_dict["node"].outputs[0].outputs[1].inputs):
229 | # if concat_input.name == concat_dq_out_name:
230 | # node_dict["node"].outputs[0].outputs[1].inputs[i] = node_dict["node"].outputs[0].outputs[0].o().outputs[0]
231 |
232 | onnx.save(gs.export_onnx(graph), f)
233 |
234 |
235 | def transfer_torch_to_quantization(nninstance : torch.nn.Module, quantmodule):
236 | quant_instance = quantmodule.__new__(quantmodule)
237 | for k, val in vars(nninstance).items():
238 | setattr(quant_instance, k, val)
239 |
240 | def __init__(self):
241 | if self.__class__.__name__ == 'QuantAvgPool2d':
242 | self.__init__(nninstance.kernel_size, nninstance.stride, nninstance.padding, nninstance.ceil_mode, nninstance.count_include_pad)
243 | elif isinstance(self, quant_nn_utils.QuantInputMixin):
244 | quant_desc_input, quant_desc_weight = quant_nn_utils.pop_quant_desc_in_kwargs(self.__class__)
245 | self.init_quantizer(quant_desc_input)
246 |
247 | # Turn on torch_hist to enable higher calibration speeds
248 | if isinstance(self._input_quantizer._calibrator, calib.HistogramCalibrator):
249 | self._input_quantizer._calibrator._torch_hist = True
250 | else:
251 | quant_desc_input, quant_desc_weight = quant_nn_utils.pop_quant_desc_in_kwargs(self.__class__)
252 | self.init_quantizer(quant_desc_input, quant_desc_weight)
253 | # Turn on torch_hist to enable higher calibration speeds
254 | if isinstance(self._input_quantizer._calibrator, calib.HistogramCalibrator):
255 | self._input_quantizer._calibrator._torch_hist = True
256 | self._weight_quantizer._calibrator._torch_hist = True
257 |
258 | __init__(quant_instance)
259 | return quant_instance
260 |
261 |
262 | def quantization_ignore_match(ignore_policy : Union[str, List[str], Callable], path : str) -> bool:
263 |
264 | if ignore_policy is None: return False
265 | if isinstance(ignore_policy, Callable):
266 | return ignore_policy(path)
267 |
268 | if isinstance(ignore_policy, str) or isinstance(ignore_policy, List):
269 |
270 | if isinstance(ignore_policy, str):
271 | ignore_policy = [ignore_policy]
272 |
273 | if path in ignore_policy: return True
274 | for item in ignore_policy:
275 | if re.match(item, path):
276 | return True
277 | return False
278 |
279 |
280 | def set_module(model, submodule_key, module):
281 | tokens = submodule_key.split('.')
282 | sub_tokens = tokens[:-1]
283 | cur_mod = model
284 | for s in sub_tokens:
285 | cur_mod = getattr(cur_mod, s)
286 | setattr(cur_mod, tokens[-1], module)
287 |
288 |
289 | def replace_to_quantization_module(model : torch.nn.Module, ignore_policy : Union[str, List[str], Callable] = None, prefixx=colorstr('QAT:')):
290 |
291 | module_dict = {}
292 | for entry in quant_modules._DEFAULT_QUANT_MAP:
293 | module = getattr(entry.orig_mod, entry.mod_name)
294 | module_dict[id(module)] = entry.replace_mod
295 |
296 | def recursive_and_replace_module(module, prefix=""):
297 | for name in module._modules:
298 | submodule = module._modules[name]
299 | path = name if prefix == "" else prefix + "." + name
300 | recursive_and_replace_module(submodule, path)
301 |
302 | submodule_id = id(type(submodule))
303 | if submodule_id in module_dict:
304 | ignored = quantization_ignore_match(ignore_policy, path)
305 | if ignored:
306 |                         LOGGER.info(f'{prefixx} Quantization: {path} is ignored.')
307 | continue
308 |
309 | module._modules[name] = transfer_torch_to_quantization(submodule, module_dict[submodule_id])
310 |
311 | recursive_and_replace_module(model)
312 |
313 |
314 | def get_attr_with_path(m, path):
315 | def sub_attr(m, names):
316 | name = names[0]
317 | value = getattr(m, name)
318 | if len(names) == 1:
319 | return value
320 | return sub_attr(value, names[1:])
321 | return sub_attr(m, path.split("."))
322 |
323 | def repncspelan4_qaunt_forward(self, x):
324 | if hasattr(self, "repncspelan4chunkop"):
325 | y = list(self.repncspelan4chunkop(self.cv1(x), 2, 1))
326 | y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
327 | return self.cv4(torch.cat(y, 1))
328 | else:
329 | y = list(self.cv1(x).split((self.c, self.c), 1))
330 | y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
331 | return self.cv4(torch.cat(y, 1))
332 |
333 | def repbottleneck_quant_forward(self, x):
334 | if hasattr(self, "addop"):
335 | return self.addop(x, self.cv2(self.cv1(x))) if self.add else self.cv2(self.cv1(x))
336 | return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
337 |
338 | def upsample_quant_forward(self, x):
339 | if hasattr(self, "upsampleop"):
340 | return self.upsampleop(x)
341 |     return F.interpolate(x, self.size, self.scale_factor, self.mode)
342 |
343 | def concat_quant_forward(self, x):
344 | if hasattr(self, "concatop"):
345 | return self.concatop(x, self.d)
346 | return torch.cat(x, self.d)
347 |
348 | def adown_quant_forward(self, x):
349 | if hasattr(self, "adownchunkop"):
350 | x1, x2 = self.adownchunkop(x)
351 | x1 = self.cv1(x1)
352 | x2 = torch.nn.functional.max_pool2d(x2, 3, 2, 1)
353 | x2 = self.cv2(x2)
354 | return torch.cat((x1, x2), 1)
355 |
356 | def aconv_quant_forward(self, x):
357 | if hasattr(self, "aconvchunkop"):
358 | x = self.aconvchunkop(x)
359 | return self.cv1(x)
360 |
361 | def apply_custom_rules_to_quantizer(model : torch.nn.Module, export_onnx : Callable):
362 | export_onnx(model, "quantization-custom-rules-temp.onnx")
363 | pairs = find_quantizer_pairs("quantization-custom-rules-temp.onnx")
364 | for major, sub in pairs:
365 | print(f"Rules: {sub} match to {major}")
366 | get_attr_with_path(model, sub)._input_quantizer = get_attr_with_path(model, major)._input_quantizer
367 | os.remove("quantization-custom-rules-temp.onnx")
368 |
369 | for name, module in model.named_modules():
370 | if module.__class__.__name__ == "RepNBottleneck":
371 | if module.add:
372 | print(f"Rules: {name}.add match to {name}.cv1")
373 | major = module.cv1.conv._input_quantizer
374 | module.addop._input0_quantizer = major
375 | module.addop._input1_quantizer = major
376 |
377 | if isinstance(module, torch.nn.MaxPool2d):
378 | quant_conv_desc_input = QuantDescriptor(num_bits=8, calib_method='histogram')
379 | quant_maxpool2d = quant_nn.QuantMaxPool2d(module.kernel_size,
380 | module.stride,
381 | module.padding,
382 | module.dilation,
383 | module.ceil_mode,
384 | quant_desc_input = quant_conv_desc_input)
385 | set_module(model, name, quant_maxpool2d)
386 |
387 | if module.__class__.__name__ == 'ADown':
388 | module.cv1.conv._input_quantizer = module.adownchunkop._chunk_quantizer
389 | if module.__class__.__name__ == 'AConv':
390 | module.cv1.conv._input_quantizer = module.aconvchunkop._chunk_quantizer
391 |
392 | def replace_custom_module_forward(model):
393 | for name, module in model.named_modules():
394 | # if module.__class__.__name__ == "RepNCSPELAN4":
395 | # if not hasattr(module, "repncspelan4chunkop"):
396 | # print(f"Add RepNCSPELAN4QuantChunk to {name}")
397 | # module.repncspelan4chunkop = QuantRepNCSPELAN4Chunk(module.c)
398 | # module.__class__.forward = repncspelan4_qaunt_forward
399 |
400 | if module.__class__.__name__ == "ADown":
401 | if not hasattr(module, "adownchunkop"):
402 | print(f"Add ADownQuantChunk to {name}")
403 | module.adownchunkop = QuantADownAvgChunk()
404 | module.__class__.forward = adown_quant_forward
405 |
406 | if module.__class__.__name__ == "AConv":
407 | if not hasattr(module, "aconvchunkop"):
408 | print(f"Add AConvQuantChunk to {name}")
409 | module.aconvchunkop = QuantAConvAvgChunk()
410 | module.__class__.forward = aconv_quant_forward
411 |
412 | if module.__class__.__name__ == "RepNBottleneck":
413 | if module.add:
414 | if not hasattr(module, "addop"):
415 | print(f"Add QuantAdd to {name}")
416 | module.addop = QuantAdd(module.add)
417 | module.__class__.forward = repbottleneck_quant_forward
418 |
419 | if module.__class__.__name__ == "Concat":
420 | if not hasattr(module, "concatop"):
421 | print(f"Add QuantConcat to {name}")
422 | module.concatop = QuantConcat(module.d)
423 | module.__class__.forward = concat_quant_forward
424 |
425 | if module.__class__.__name__ == "Upsample":
426 | if not hasattr(module, "upsampleop"):
427 | print(f"Add QuantUpsample to {name}")
428 | module.upsampleop = QuantUpsample(module.size, module.scale_factor, module.mode)
429 | module.__class__.forward = upsample_quant_forward
430 |
431 | def calibrate_model(model : torch.nn.Module, dataloader, device, num_batch=25):
432 |
433 | def compute_amax(model, **kwargs):
434 | for name, module in model.named_modules():
435 | if isinstance(module, quant_nn.TensorQuantizer):
436 | if module._calibrator is not None:
437 | if isinstance(module._calibrator, calib.MaxCalibrator):
438 | module.load_calib_amax()
439 | else:
440 | module.load_calib_amax(**kwargs)
441 |
442 | module._amax = module._amax.to(device)
443 |
444 | def collect_stats(model, data_loader, device, num_batch=200):
445 | """Feed data to the network and collect statistics"""
446 | # Enable calibrators
447 | model.eval()
448 | for name, module in model.named_modules():
449 |
450 | if isinstance(module, quant_nn.TensorQuantizer):
451 | if module._calibrator is not None:
452 | module.disable_quant()
453 | module.enable_calib()
454 | else:
455 | module.disable()
456 |
457 | # Feed data to the network for collecting stats
458 | with torch.no_grad():
459 | for i, datas in tqdm(enumerate(data_loader), total=num_batch, desc="Collect stats for calibrating"):
460 | imgs = datas[0].to(device, non_blocking=True).float() / 255.0
461 | model(imgs)
462 |
463 | if i >= num_batch:
464 | break
465 |
466 | # Disable calibrators
467 | for name, module in model.named_modules():
468 | if isinstance(module, quant_nn.TensorQuantizer):
469 | if module._calibrator is not None:
470 | module.enable_quant()
471 | module.disable_calib()
472 | else:
473 | module.enable()
474 |
475 | with torch.no_grad():
476 | collect_stats(model, dataloader, device, num_batch=num_batch)
477 |     #compute_amax(model, method="percentile", percentile=99.99, strict=True) # strict=False avoids an exception when some quantizers are never used
478 | compute_amax(model, method="mse")
479 |
480 |
481 |
482 | def finetune(
483 | model : torch.nn.Module, train_dataloader, per_epoch_callback : Callable = None, preprocess : Callable = None,
484 | nepochs=10, early_exit_batchs_per_epoch=1000, lrschedule : Dict = None, fp16=True, learningrate=1e-5,
485 | supervision_policy : Callable = None, prefix=colorstr('QAT:')
486 | ):
487 | origin_model = deepcopy(model).eval()
488 | disable_quantization(origin_model).apply()
489 |
490 | model.train()
491 | model.requires_grad_(True)
492 |
493 | scaler = amp.GradScaler(enabled=fp16)
494 | optimizer = optim.Adam(model.parameters(), learningrate)
495 | quant_lossfn = torch.nn.MSELoss()
496 | device = next(model.parameters()).device
497 |
498 |
499 | if lrschedule is None:
500 | lrschedule = {
501 | 0: 1e-6,
502 | 6: 1e-5,
503 | 7: 1e-6
504 | }
505 |
506 |
507 | def make_layer_forward_hook(l):
508 | def forward_hook(m, input, output):
509 | l.append(output)
510 | return forward_hook
511 |
512 | supervision_module_pairs = []
513 | for ((mname, ml), (oriname, ori)) in zip(model.named_modules(), origin_model.named_modules()):
514 | if isinstance(ml, quant_nn.TensorQuantizer): continue
515 |
516 | if supervision_policy:
517 | if not supervision_policy(mname, ml):
518 | continue
519 |
520 | supervision_module_pairs.append([ml, ori])
521 |
522 |
523 | for iepoch in range(nepochs):
524 |
525 | if iepoch in lrschedule:
526 | learningrate = lrschedule[iepoch]
527 | for g in optimizer.param_groups:
528 | g["lr"] = learningrate
529 |
530 | model_outputs = []
531 | origin_outputs = []
532 | remove_handle = []
533 |
534 |
535 |
536 | for ml, ori in supervision_module_pairs:
537 | remove_handle.append(ml.register_forward_hook(make_layer_forward_hook(model_outputs)))
538 | remove_handle.append(ori.register_forward_hook(make_layer_forward_hook(origin_outputs)))
539 |
540 | model.train()
541 | pbar = tqdm(train_dataloader, desc="QAT", total=early_exit_batchs_per_epoch)
542 | for ibatch, imgs in enumerate(pbar):
543 |
544 | if ibatch >= early_exit_batchs_per_epoch:
545 | break
546 |
547 | if preprocess:
548 | imgs = preprocess(imgs)
549 |
550 |
551 | imgs = imgs.to(device)
552 | with amp.autocast(enabled=fp16):
553 | model(imgs)
554 |
555 | with torch.no_grad():
556 | origin_model(imgs)
557 |
558 | quant_loss = 0
559 | for mo, fo in zip(model_outputs, origin_outputs):
560 | for m, f in zip(mo, fo):
561 | quant_loss += quant_lossfn(m, f)
562 |
563 | model_outputs.clear()
564 | origin_outputs.clear()
565 |
566 | if fp16:
567 | scaler.scale(quant_loss).backward()
568 | scaler.step(optimizer)
569 | scaler.update()
570 | else:
571 | quant_loss.backward()
572 | optimizer.step()
573 | optimizer.zero_grad()
574 | pbar.set_description(f"QAT Finetuning {iepoch + 1} / {nepochs}, Loss: {quant_loss.detach().item():.5f}, LR: {learningrate:g}")
575 |
576 |         # Hooks must be removed before onnx export or torch.save
577 | for rm in remove_handle:
578 | rm.remove()
579 |
580 | if per_epoch_callback:
581 | if per_epoch_callback(model, iepoch, learningrate):
582 | break
583 |
584 |
585 | def export_onnx(model, input, file, *args, **kwargs):
586 | quant_nn.TensorQuantizer.use_fb_fake_quant = True
587 |
588 | model.eval()
589 | with torch.no_grad():
590 | torch.onnx.export(model, input, file, *args, **kwargs)
591 |
592 | quant_nn.TensorQuantizer.use_fb_fake_quant = False
593 |
--------------------------------------------------------------------------------
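The helpers in models/quantize.py above are meant to be chained in a fixed order, which is exactly what run_quantize() in qat.py does. A minimal sketch of that flow, assuming the caller supplies an already fused YOLOv9 model, a calibration dataloader and a device (the function name, dummy input shape and output file name below are illustrative, not part of the repository API):

    import torch
    import models.quantize as quantize

    def quantize_and_export(model, train_loader, device, onnx_file="model_qat.onnx"):
        # Same ordering as run_quantize() in qat.py: patch defaults, wrap custom ops,
        # swap layers for their quant_nn equivalents, share input quantizers across
        # branches, run PTQ calibration, then export a Q/DQ ONNX graph.
        quantize.initialize()
        quantize.replace_custom_module_forward(model)
        quantize.replace_to_quantization_module(model)
        dummy = torch.zeros(1, 3, 640, 640, device=device)

        def export(m, f):
            quantize.export_onnx(m, dummy, f, opset_version=13,
                                 input_names=["images"], output_names=["output0"])

        quantize.apply_custom_rules_to_quantizer(model, export)
        quantize.calibrate_model(model, train_loader, device, num_batch=25)
        export(model, onnx_file)
        return model
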
/models/quantize_rules.py:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: MIT
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a
6 | # copy of this software and associated documentation files (the "Software"),
7 | # to deal in the Software without restriction, including without limitation
8 | # the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 | # and/or sell copies of the Software, and to permit persons to whom the
10 | # Software is furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in
13 | # all copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 | # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 | # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
21 | # DEALINGS IN THE SOFTWARE.
22 | ################################################################################
23 |
24 | import onnx
25 |
26 | def find_with_input_node(model, name):
27 | for node in model.graph.node:
28 | if len(node.input) > 0 and name in node.input:
29 | return node
30 |
31 | def find_all_with_input_node(model, name):
32 | all = []
33 | for node in model.graph.node:
34 | if len(node.input) > 0 and name in node.input:
35 | all.append(node)
36 | return all
37 |
38 | def find_with_output_node(model, name):
39 | for node in model.graph.node:
40 | if len(node.output) > 0 and name in node.output:
41 | return node
42 |
43 | """
44 | def find_with_no_change_parent_node(model, node):
45 | parent = find_with_output_node(model, node.input[0])
46 | if parent is not None:
47 | print("Parent:", parent.op_type)
48 | if parent.op_type in ["Concat", "MaxPool", "AveragePool", "Slice"]:
49 | return find_with_no_change_parent_node(model, parent)
50 | return parent
51 | """
52 |
53 | def find_quantizelinear_conv(model, qnode):
54 | dq = find_with_input_node(model, qnode.output[0])
55 | conv = find_with_input_node(model, dq.output[0])
56 | return conv
57 |
58 |
59 | def find_quantize_conv_name(model, weight_qname):
60 | dq = find_with_output_node(model, weight_qname)
61 | q = find_with_output_node(model, dq.input[0])
62 | return ".".join(q.input[0].split(".")[:-1])
63 |
64 | def find_quantizer_pairs(onnx_file):
65 |
66 | model = onnx.load(onnx_file)
67 | match_pairs = []
68 | for node in model.graph.node:
69 | if node.op_type == "Concat":
70 | qnodes = find_all_with_input_node(model, node.output[0])
71 | major = None
72 | for qnode in qnodes:
73 | if qnode.op_type != "QuantizeLinear":
74 | continue
75 | conv = find_quantizelinear_conv(model, qnode)
76 | if major is None:
77 | major = find_quantize_conv_name(model, conv.input[1])
78 | else:
79 | match_pairs.append([major, find_quantize_conv_name(model, conv.input[1])])
80 |
81 | for subnode in model.graph.node:
82 | if len(subnode.input) > 0 and subnode.op_type == "QuantizeLinear" and subnode.input[0] in node.input:
83 | subconv = find_quantizelinear_conv(model, subnode)
84 | match_pairs.append([major, find_quantize_conv_name(model, subconv.input[1])])
85 |
86 |
87 | elif node.op_type == "MaxPool":
88 | qnode = find_with_input_node(model, node.output[0])
89 | if not (qnode and qnode.op_type == "QuantizeLinear"):
90 | continue
91 | major = find_quantizelinear_conv(model, qnode)
92 | major = find_quantize_conv_name(model, major.input[1])
93 | same_input_nodes = find_all_with_input_node(model, node.input[0])
94 |
95 | for same_input_node in same_input_nodes:
96 | if same_input_node.op_type == "QuantizeLinear":
97 | subconv = find_quantizelinear_conv(model, same_input_node)
98 | match_pairs.append([major, find_quantize_conv_name(model, subconv.input[1])])
99 |
100 | return match_pairs
101 |
--------------------------------------------------------------------------------
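find_quantizer_pairs() above is what apply_custom_rules_to_quantizer() in models/quantize.py relies on: it scans a Q/DQ ONNX export and reports which convolutions feed the same Concat or MaxPool node and should therefore share a single input quantizer. A small usage sketch (the ONNX file name is a placeholder for any fake-quant export produced by models/quantize.py):

    from models.quantize_rules import find_quantizer_pairs

    pairs = find_quantizer_pairs("model_qat.onnx")
    for major, sub in pairs:
        print(f"{sub} should reuse the input quantizer of {major}")
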
/patch_yolov9.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Check if the correct number of arguments were provided
4 | if [ "$#" -ne 1 ]; then
5 |     echo "Usage: $0 <yolov9_directory>"
6 | exit 1
7 | fi
8 |
9 | # YOLOv9 directory
10 | yolov9_dir="$1"
11 |
12 | # Check if the YOLOv9 directory exists
13 | if [ ! -d "$yolov9_dir" ]; then
14 | echo "Error: Directory '$yolov9_dir' not found."
15 | exit 1
16 | fi
17 |
18 | if [ ! -f "$yolov9_dir/models/experimental.py" ]; then
19 | echo "Error: '$yolov9_dir' does not appear to contain a valid YOLOv9 repository."
20 | echo "Please make sure '$yolov9_dir' is the root directory of a YOLOv9 repository."
21 | exit 1
22 | fi
23 |
24 | # Copy files to the YOLOv9 directory
25 | cp val_trt.py "$yolov9_dir/val_trt.py" && echo "val_trt.py patched successfully."
26 | cp qat.py "$yolov9_dir/qat.py" && echo "qat.py patched successfully."
27 | cp export_qat.py "$yolov9_dir/export_qat.py" && echo "export_qat.py patched successfully."
28 | cp models/quantize_rules.py "$yolov9_dir/models/quantize_rules.py" && echo "quantize_rules.py patched successfully."
29 | cp models/quantize.py "$yolov9_dir/models/quantize.py" && echo "quantize.py patched successfully."
30 | cp scripts/generate_trt_engine.sh "$yolov9_dir/scripts/generate_trt_engine.sh" && echo "generate_trt_engine.sh patched successfully."
31 | cp scripts/val_trt.sh "$yolov9_dir/scripts/val_trt.sh" && echo "val_trt.sh patched successfully."
32 | cp draw-engine.py "$yolov9_dir/draw-engine.py" && echo "draw-engine.py patched successfully."
33 |
34 | cp models/experimental_trt.py "$yolov9_dir/models/experimental_trt.py" && echo "experimental_trt.py patched successfully."
35 | cp segment/qat_seg.py "$yolov9_dir/segment/qat_seg.py" && echo "qat_seg.py patched successfully."
36 |
37 | echo "Patch applied successfully to YOLOv9 directory: $yolov9_dir"
38 |
--------------------------------------------------------------------------------
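patch_yolov9.sh above only copies the QAT-specific files into an existing YOLOv9 checkout; it expects the repository root as its single argument. A typical invocation (the path is illustrative):

    bash patch_yolov9.sh ../yolov9
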
/qat.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | import yaml
5 | import argparse
6 | import json
7 | from copy import deepcopy
8 | from pathlib import Path
9 | import warnings
10 |
11 | # PyTorch
12 | import torch
13 | import torch.nn as nn
14 |
15 | import val as validate
16 | from models.yolo import Model
17 | from models.common import Conv
18 | from utils.dataloaders import create_dataloader
19 | from utils.downloads import attempt_download
20 |
21 | from models.yolo import Detect, DDetect, DualDetect, DualDDetect, DetectionModel, SegmentationModel
22 | import models.quantize as quantize
23 |
24 | from utils.general import (LOGGER, check_dataset, check_requirements, check_img_size, colorstr, init_seeds,increment_path,file_size)
25 | from utils.torch_utils import (torch_distributed_zero_first)
26 |
27 | warnings.filterwarnings("ignore")
28 |
29 | FILE = Path(__file__).resolve()
30 | ROOT = FILE.parents[0] # YOLO root directory
31 | if str(ROOT) not in sys.path:
32 | sys.path.append(str(ROOT)) # add ROOT to PATH
33 | ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
34 |
35 | LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
36 | RANK = int(os.getenv('RANK', -1))
37 | WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
38 | GIT_INFO = None
39 |
40 |
41 | class ReportTool:
42 | def __init__(self, file):
43 | self.file = file
44 | if os.path.exists(self.file):
45 | open(self.file, 'w').close()
46 | self.data = []
47 |
48 | def load_data(self):
49 | try:
50 | return json.load(open(self.file, "r"))
51 | except FileNotFoundError:
52 | return []
53 |
54 | def append(self, item):
55 | self.data.append(item)
56 | self.save_data()
57 |
58 | def update(self, item):
59 | for i, data_item in enumerate(self.data):
60 | if data_item[0] == item[0]:
61 | self.data[i] = item
62 | break
63 | else:
64 |                 # If not found, append it as a new item
65 | self.append(item)
66 | self.save_data()
67 |
68 | def save_data(self):
69 | json.dump(self.data, open(self.file, "w"), indent=4)
70 |
71 |
72 | def load_model(weights, device) -> Model:
73 | with torch_distributed_zero_first(LOCAL_RANK):
74 | attempt_download(weights)
75 | model = torch.load(weights, map_location=device)["model"]
76 | for m in model.modules():
77 | if type(m) is nn.Upsample:
78 | m.recompute_scale_factor = None # torch 1.11.0 compatibility
79 | elif type(m) is Conv:
80 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
81 | model.float()
82 | model.eval()
83 | with torch.no_grad():
84 | model.fuse()
85 | return model
86 |
87 |
88 |
89 | def create_train_dataloader(train_path, imgsz, batch_size, single_cls, stride, hyp_path):
90 | with open(hyp_path) as f:
91 | hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps
92 | loader = create_dataloader(
93 | train_path,
94 | imgsz=imgsz,
95 | batch_size=batch_size,
96 | single_cls=single_cls,
97 | augment=True, hyp=hyp, rect=False, cache=False, stride=stride, pad=0.0, image_weights=False)[0]
98 | return loader
99 |
100 |
101 |
102 | def create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride, keep_images=None):
103 | loader = create_dataloader(
104 | test_path,
105 | imgsz=imgsz,
106 | batch_size=batch_size,
107 | single_cls=single_cls,
108 | augment=False, hyp=None, rect=True, cache=False,stride=stride,pad=0.5, image_weights=False)[0]
109 |
110 | def subclass_len(self):
111 | if keep_images is not None:
112 | return keep_images
113 | return len(self.img_files)
114 |
115 | loader.dataset.__len__ = subclass_len
116 | return loader
117 |
118 | def evaluate_dataset(model_eval, val_loader, imgsz, data_dict, single_cls, save_dir, is_coco, conf_thres=0.001 , iou_thres=0.7 ):
119 | return validate.run(data_dict,
120 | model=model_eval,
121 | imgsz=imgsz,
122 | single_cls=single_cls,
123 | half=True,
124 | task='val',
125 | verbose=True,
126 | conf_thres=conf_thres,
127 | iou_thres=iou_thres,
128 | save_dir=save_dir,
129 | save_json=is_coco,
130 | dataloader=val_loader,
131 | )[0][:4]
132 |
133 |
134 | def export_onnx(model, file, im, opset=12, dynamic=False, prefix=colorstr('QAT ONNX:')):
135 | check_requirements('onnx')
136 | import onnx
137 |
138 | file = Path(file)
139 | LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
140 |
141 | f = file.with_suffix('.onnx')
142 | output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0']
143 | model.eval()
144 | for k, m in model.named_modules():
145 | # print(m)
146 | if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect)):
147 | m.inplace = True
148 | m.dynamic = dynamic
149 | m.export = True
150 | dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640)
151 | if isinstance(model, SegmentationModel):
152 | dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
153 | dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160)
154 | elif isinstance(model, DetectionModel):
155 | dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
156 |
157 | quantize.export_onnx(model, im, file, opset_version=13,
158 | input_names=["images"], output_names=output_names,
159 | dynamic_axes=dynamic or None
160 | )
161 |
162 | for k, m in model.named_modules():
163 | if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect)):
164 | m.inplace = True
165 | m.dynamic = False
166 | m.export = False
167 |
168 |
169 | def run_quantize(weights, data, imgsz, batch_size, hyp, device, save_dir, supervision_stride, iters, no_eval_origin, no_eval_ptq, prefix=colorstr('QAT:')):
170 |
171 | if not Path(weights).exists():
172 | LOGGER.info(f'{prefix} Weight file not found "{weights}" ❌')
173 | exit(1)
174 |
175 | quantize.initialize()
176 |
177 | with torch_distributed_zero_first(LOCAL_RANK):
178 | data_dict = check_dataset(data)
179 |
180 | w = save_dir / 'weights' # weights dir
181 | w.mkdir(parents=True, exist_ok=True) # make dir
182 |
183 | is_coco = isinstance(data_dict.get('val'), str) and data_dict['val'].endswith(f'val2017.txt') # COCO dataset
184 |
185 | nc = int(data_dict['nc']) # number of classes
186 | single_cls = False if nc > 1 else True
187 | names = data_dict['names'] # class names
188 | assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data_dict) # check
189 | train_path = data_dict['train']
190 | test_path = data_dict['val']
191 |
192 | result_eval_origin=None
193 | result_eval_ptq=None
194 | result_eval_qat_best=None
195 |
196 | device = torch.device(device)
197 | model = load_model(weights, device)
198 |
199 | if not isinstance(model, DetectionModel):
200 | model_name=model.__class__.__name__
201 | LOGGER.info(f'{prefix} {model_name} model is not supported. Only DetectionModel is supported. ❌')
202 | exit(1)
203 |
204 | stride = max(int(model.stride.max()), 32) # grid size (max stride)
205 | imgsz = check_img_size(imgsz, s=stride) # check image size
206 |
207 | # conf onnx export
208 | exp_imgsz=[imgsz,imgsz]
209 | gs = int(max(model.stride)) # grid size (max stride)
210 | exp_imgsz = [check_img_size(x, gs) for x in exp_imgsz] # verify img_size are gs-multiples
211 | im = torch.zeros(batch_size, 3, *exp_imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
212 |
213 |
214 | train_dataloader = create_train_dataloader(train_path, imgsz, batch_size, single_cls, stride, hyp)
215 | val_dataloader = create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride)
216 |
217 |     ### This rule is disabled - it would allow the user to disable QAT for specific layers ###
218 |     # It remains in the code for compatibility and possible future implementation.
219 | """
220 | ignore_layer=-1
221 | if ignore_layer > -1:
222 | ignore_policy=f"model\.{ignore_layer}\.cv\d+\.\d+\.\d+(\.conv)?"
223 | else:
224 | ignore_policy=f"model\.9999999999\.cv\d+\.\d+\.\d+(\.conv)?"
225 | """
226 | ### End #######
227 |
228 | quantize.replace_custom_module_forward(model)
229 |     quantize.replace_to_quantization_module(model, ignore_policy="disabled") ## disabled because it is not implemented
230 | quantize.apply_custom_rules_to_quantizer(model, lambda model, file: export_onnx(model, file, im))
231 | quantize.calibrate_model(model, train_dataloader, device)
232 |
233 | report_file = os.path.join(save_dir, "report.json")
234 | report = ReportTool(report_file)
235 |
236 | if no_eval_origin:
237 | LOGGER.info(f'\n{prefix} Evaluating Origin...')
238 | model_eval = deepcopy(model).eval()
239 | with quantize.disable_quantization(model_eval):
240 | result_eval_origin = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco )
241 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in result_eval_origin)
242 | LOGGER.info(f'\n{prefix} Eval Origin - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}')
243 | report.append(["Origin", str(weights), eval_map, eval_map50,eval_mp, eval_mr ])
244 |
245 | if no_eval_ptq:
246 |
247 | LOGGER.info(f'\n{prefix} Evaluating PTQ...')
248 | model_eval = deepcopy(model).eval()
249 |
250 | result_eval_ptq = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco )
251 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in result_eval_ptq)
252 | LOGGER.info(f'\n{prefix} Eval PTQ - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}')
253 | ptq_weights = w / f'ptq_ap_{eval_map}_{os.path.basename(weights)}'
254 | torch.save({"model": model_eval},f'{ptq_weights}')
255 | LOGGER.info(f'\n{prefix} PTQ, weights saved as {ptq_weights} ({file_size(ptq_weights):.1f} MB)')
256 | report.append(["PTQ", str(ptq_weights), eval_map, eval_map50,eval_mp, eval_mr ])
257 |
258 | best_map = 0
259 |
260 | def per_epoch(model, epoch, lr):
261 | nonlocal best_map , result_eval_qat_best
262 |
263 | epoch +=1
264 | model_eval = deepcopy(model).eval()
265 | with torch.no_grad():
266 | eval_result = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco )
267 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result)
268 | qat_weights = w / f'qat_ep_{epoch}_ap_{eval_map}_{os.path.basename(weights)}'
269 | torch.save({"model": model_eval},f'{qat_weights}')
270 | LOGGER.info(f'\n{prefix} Epoch-{epoch}, weights saved as {qat_weights} ({file_size(qat_weights):.1f} MB)')
271 | report.append([f"QAT-{epoch}", str(qat_weights), eval_map, eval_map50,eval_mp, eval_mr ])
272 |
273 | if eval_map > best_map:
274 | best_map = eval_map
275 | result_eval_qat_best=eval_result
276 | qat_weights = w / f'qat_best_{os.path.basename(weights)}'
277 | torch.save({"model": model_eval}, f'{qat_weights}')
278 | LOGGER.info(f'{prefix} QAT Best, weights saved as {qat_weights} ({file_size(qat_weights):.1f} MB)')
279 | report.update(["QAT-Best", str(qat_weights), eval_map, eval_map50,eval_mp, eval_mr ])
280 |
281 | eval_results = [result_eval_origin, result_eval_ptq, result_eval_qat_best]
282 |
283 | LOGGER.info(f'\n\nEval Model | {"AP":<8} | {"AP50":<8} | {"Precision":<10} | {"Recall":<8}')
284 | LOGGER.info('-' * 55)
285 | for idx, eval_r in enumerate(eval_results):
286 | if eval_r is not None:
287 | eval_mp, eval_mr, eval_map50, eval_map = tuple(round(x, 3) for x in eval_r)
288 | if idx == 0:
289 | LOGGER.info(f'Origin | {eval_map:<8} | {eval_map50:<8} | {eval_mp:<10} | {eval_mr:<8}')
290 | if idx == 1:
291 | LOGGER.info(f'PTQ | {eval_map:<8} | {eval_map50:<8} | {eval_mp:<10} | {eval_mr:<8}')
292 | if idx == 2:
293 | LOGGER.info(f'QAT - Best | {eval_map:<8} | {eval_map50:<8} | {eval_mp:<10} | {eval_mr:<8}\n')
294 |
295 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result)
296 | LOGGER.info(f'\n{prefix} Eval - Epoch {epoch} | AP: {eval_map} | AP50: {eval_map50} | Precision: {eval_mp} | Recall: {eval_mr}\n')
297 |
298 | def preprocess(datas):
299 | return datas[0].to(device).float() / 255.0
300 |
301 | def supervision_policy():
302 | supervision_list = []
303 | for item in model.model:
304 | supervision_list.append(id(item))
305 |
306 | keep_idx = list(range(0, len(model.model) - 1, supervision_stride))
307 | keep_idx.append(len(model.model) - 2)
308 | def impl(name, module):
309 | if id(module) not in supervision_list: return False
310 | idx = supervision_list.index(id(module))
311 | if idx in keep_idx:
312 |                 print(f"Supervision: {name} will compute loss with the origin model during QAT training")
313 |             else:
314 |                 print(f"Supervision: {name} will not compute loss with the origin model during QAT training (it is unsupervised here, which does not mean it stops learning)")
315 | return idx in keep_idx
316 | return impl
317 |
318 | quantize.finetune(
319 | model, train_dataloader, per_epoch, early_exit_batchs_per_epoch=iters,
320 | preprocess=preprocess, supervision_policy=supervision_policy())
321 |
322 | def run_sensitive_analysis(weights, device, data, imgsz, batch_size, hyp, save_dir, num_image, prefix=colorstr('QAT ANALYSIS:')):
323 |
324 | if not Path(weights).exists():
325 | LOGGER.info(f'{prefix} Weight file not found "{weights}" ❌')
326 | exit(1)
327 |
328 | save_dir = Path(save_dir)
329 | # Create the directory if it doesn't exist
330 | save_dir.mkdir(parents=True, exist_ok=opt.exist_ok)
331 |
332 | with torch_distributed_zero_first(LOCAL_RANK):
333 | data_dict = check_dataset(data)
334 |
335 | is_coco=False
336 |
337 | nc = int(data_dict['nc']) # number of classes
338 | single_cls = False if nc > 1 else True
339 | names = data_dict['names'] # class names
340 | assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data_dict) # check
341 | train_path = data_dict['train']
342 | test_path = data_dict['val']
343 |
344 | device = torch.device(device)
345 | model = load_model(weights, device)
346 |
347 | if not isinstance(model, DetectionModel) or isinstance(model, SegmentationModel):
348 |         LOGGER.info(f'{prefix} Model not supported. Only DetectionModel is supported. ❌')
349 | exit(1)
350 |
351 | is_model_qat=False
352 | for i in range(0, len(model.model)):
353 | layer = model.model[i]
354 | if quantize.have_quantizer(layer):
355 | is_model_qat=True
356 | break
357 |
358 | if is_model_qat:
359 |         LOGGER.info(f'{prefix} This model is already quantized. Only non-quantized models are allowed. ❌')
360 | exit(1)
361 |
362 | stride = max(int(model.stride.max()), 32) # grid size (max stride)
363 | imgsz = check_img_size(imgsz, s=stride) # check image size
364 |
365 | exp_imgsz=[imgsz,imgsz]
366 | gs = int(max(model.stride)) # grid size (max stride)
367 | exp_imgsz = [check_img_size(x, gs) for x in exp_imgsz] # verify img_size are gs-multiples
368 | im = torch.zeros(batch_size, 3, *exp_imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
369 |
370 |
371 | train_dataloader = create_train_dataloader(train_path, imgsz, batch_size, single_cls, stride, hyp)
372 | val_dataloader = create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride)
373 | quantize.initialize()
374 | quantize.replace_custom_module_forward(model)
375 |     quantize.replace_to_quantization_module(model, ignore_policy="disabled") ## disabled because it is not implemented
376 | quantize.calibrate_model(model, train_dataloader, device)
377 |
378 | report_file=os.path.join(save_dir , "summary-sensitive-analysis.json")
379 | report = ReportTool(report_file)
380 |
381 | model_eval = deepcopy(model).eval()
382 | LOGGER.info(f'\n{prefix} Evaluating PTQ...')
383 |
384 | eval_result = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco )
385 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result)
386 |
387 | LOGGER.info(f'\n{prefix} Eval PTQ - QAT enabled on All Layers - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}')
388 | report.append([eval_map, "PTQ"])
389 | LOGGER.info(f'{prefix} Sensitive analysis by each layer. Layers Detected: {len(model.model)}')
390 |
391 | for i in range(0, len(model.model)):
392 | layer = model.model[i]
393 | if quantize.have_quantizer(layer):
394 | LOGGER.info(f'{prefix} QAT disabled on Layer model.{i}')
395 | quantize.disable_quantization(layer).apply()
396 | model_eval = deepcopy(model).eval()
397 | eval_result = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco )
398 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result)
399 | LOGGER.info(f'\n{prefix} Eval PTQ - QAT disabled on Layer model.{i} - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}\n')
400 | report.append([eval_map, f"model.{i}"])
401 | quantize.enable_quantization(layer).apply()
402 | else:
403 | LOGGER.info(f'{prefix} Ignored Layer model.{i} because it is {type(layer)}')
404 |
405 | report = sorted(report.data, key=lambda x:x[0], reverse=True)
406 | print("Sensitive summary:")
407 | for n, (ap, name) in enumerate(report[:10]):
408 | print(f"Top{n}: Using fp16 {name}, ap = {ap:.5f}")
409 |
410 |
411 | def run_eval(weights, device, data, imgsz, batch_size, save_dir, conf_thres, iou_thres, prefix=colorstr('QAT TEST:')):
412 |
413 | if not Path(weights).exists():
414 | LOGGER.info(f'{prefix} Weight file not found "{weights}" ❌')
415 | exit(1)
416 |
417 | quantize.initialize()
418 |
419 | save_dir = Path(save_dir)
420 | # Create the directory if it doesn't exist
421 | save_dir.mkdir(parents=True, exist_ok=opt.exist_ok)
422 |
423 | with torch_distributed_zero_first(LOCAL_RANK):
424 | data_dict = check_dataset(data)
425 |
426 |
427 | device = torch.device(device)
428 | model = load_model(weights, device)
429 |
430 | if not isinstance(model, DetectionModel):
431 | model_name=model.__class__.__name__
432 | LOGGER.info(f'{prefix} {model_name} model is not supported. Only DetectionModel is supported. ❌')
433 | exit(1)
434 |
435 | is_model_qat=False
436 | for i in range(0, len(model.model)):
437 | layer = model.model[i]
438 | if quantize.have_quantizer(layer):
439 | is_model_qat=True
440 | break
441 |
442 | if not is_model_qat:
443 |         LOGGER.info(f'{prefix} This model is not quantized. ❌')
444 | exit(1)
445 |
446 | is_coco = isinstance(data_dict.get('val'), str) and data_dict['val'].endswith(f'val2017.txt') # COCO dataset
447 |
448 | stride = max(int(model.stride.max()), 32) # grid size (max stride)
449 | imgsz = check_img_size(imgsz, s=stride) # check image size
450 |
451 | nc = int(data_dict['nc']) # number of classes
452 | single_cls = False if nc > 1 else True
453 | names = data_dict['names'] # class names
454 | assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data_dict) # check
455 |
456 | test_path = data_dict['val']
457 | val_dataloader = create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride)
458 |
459 |
460 | LOGGER.info(f'\n{prefix} Evaluating ...')
461 | model_eval = deepcopy(model).eval()
462 |
463 | result_eval = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, conf_thres=conf_thres, iou_thres=iou_thres )
464 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in result_eval)
465 | LOGGER.info(f'\n{prefix} Eval Result - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}')
466 | LOGGER.info(f'\n{prefix} Eval Result, saved on {save_dir}')
467 |
468 |
469 |
470 | if __name__ == "__main__":
471 | parser = argparse.ArgumentParser(prog='qat.py')
472 | subps = parser.add_subparsers(dest="cmd")
473 | qat = subps.add_parser("quantize", help="PTQ/QAT finetune ...")
474 |
475 | qat.add_argument('--weights', type=str, default=ROOT / 'runs/models_original/yolov9-c.pt', help='weights path')
476 | qat.add_argument('--data', type=str, default=ROOT / 'data/coco.yaml', help='dataset.yaml path')
477 | qat.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-high.yaml', help='hyperparameters path')
478 | qat.add_argument("--device", type=str, default="cuda:0", help="device")
479 | qat.add_argument('--batch-size', type=int, default=10, help='total batch size')
480 | qat.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
481 | qat.add_argument('--project', default=ROOT / 'runs/qat', help='save to project/name')
482 | qat.add_argument('--name', default='exp', help='save to project/name')
483 | qat.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
484 | qat.add_argument("--iters", type=int, default=200, help="iters per epoch")
485 | qat.add_argument('--seed', type=int, default=57, help='Global training seed')
486 | qat.add_argument("--supervision-stride", type=int, default=1, help="supervision stride")
487 | qat.add_argument("--no-eval-origin", action="store_false", help="Disable eval for origin model")
488 | qat.add_argument("--no-eval-ptq", action="store_false", help="Disable eval for ptq model")
489 |
490 | sensitive = subps.add_parser("sensitive", help="Sensitive layer analysis")
491 | sensitive.add_argument('--weights', type=str, default=ROOT / 'runs/models_original/yolov9-c.pt', help='Weights path (.pt)')
492 | sensitive.add_argument("--device", type=str, default="cuda:0", help="device")
493 | sensitive.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path')
494 | sensitive.add_argument('--batch-size', type=int, default=10, help='total batch size')
495 | sensitive.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
496 | sensitive.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch-high.yaml', help='hyperparameters path')
497 | sensitive.add_argument('--project', default=ROOT / 'runs/qat_sentive', help='save to project/name')
498 | sensitive.add_argument('--name', default='exp', help='save to project/name')
499 | sensitive.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
500 | sensitive.add_argument("--num-image", type=int, default=None, help="number of image to evaluate")
501 |
502 | testcmd = subps.add_parser("eval", help="Do evaluate")
503 | testcmd.add_argument('--weights', type=str, default=ROOT / 'runs/models_original/yolov9-c.pt', help='Weights path (.pt)')
504 | testcmd.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path')
505 | testcmd.add_argument('--batch-size', type=int, default=10, help='total batch size')
506 | testcmd.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='val image size (pixels)')
507 | testcmd.add_argument("--device", type=str, default="cuda:0", help="device")
508 | testcmd.add_argument("--conf-thres", type=float, default=0.001, help="confidence threshold")
509 | testcmd.add_argument("--iou-thres", type=float, default=0.7, help="nms threshold")
510 | testcmd.add_argument('--project', default=ROOT / 'runs/qat_eval', help='save to project/name')
511 | testcmd.add_argument('--name', default='exp', help='save to project/name')
512 | testcmd.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
513 |
514 | opt = parser.parse_args()
515 | if opt.cmd == "quantize":
516 | print(opt)
517 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
518 | init_seeds(opt.seed + 1 + RANK, deterministic=False)
519 |
520 | run_quantize(
521 | opt.weights, opt.data, opt.imgsz, opt.batch_size,
522 | opt.hyp, opt.device, Path(opt.save_dir),
523 | opt.supervision_stride, opt.iters,
524 | opt.no_eval_origin, opt.no_eval_ptq
525 | )
526 |
527 | elif opt.cmd == "sensitive":
528 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
529 | print(opt)
530 | run_sensitive_analysis(opt.weights, opt.device, opt.data,
531 | opt.imgsz, opt.batch_size, opt.hyp,
532 | opt.save_dir, opt.num_image
533 | )
534 | elif opt.cmd == "eval":
535 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
536 | print(opt)
537 | run_eval(opt.weights, opt.device, opt.data,
538 | opt.imgsz, opt.batch_size, opt.save_dir,
539 | opt.conf_thres, opt.iou_thres
540 | )
541 | else:
542 | parser.print_help()
--------------------------------------------------------------------------------
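qat.py above exposes three subcommands (quantize, sensitive and eval), wired to run_quantize, run_sensitive_analysis and run_eval respectively. Illustrative invocations built from the argparse defaults (weight and dataset paths are placeholders for your own files):

    python3 qat.py quantize --weights runs/models_original/yolov9-c.pt --data data/coco.yaml --hyp data/hyps/hyp.scratch-high.yaml
    python3 qat.py sensitive --weights runs/models_original/yolov9-c.pt --data data/coco.yaml
    python3 qat.py eval --weights runs/qat/exp/weights/qat_best_yolov9-c.pt --data data/coco.yaml
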
/scripts/generate_trt_engine.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Function to display usage message
4 | usage() {
5 |     echo "Usage: $0 <onnx_file> <image_size> [--generate-graph]"
6 | exit 1
7 | }
8 |
9 | # Check if the correct number of arguments are provided
10 | if [ "$#" -lt 2 ] || [ "$#" -gt 3 ]; then
11 | usage
12 | fi
13 |
14 | function get_free_gpu_memory() {
15 | # Get the total memory and used memory from nvidia-smi for GPU 0
16 | local total_memory=$(nvidia-smi --id=0 --query-gpu=memory.total --format=csv,noheader,nounits | awk '{print $1}')
17 | local used_memory=$(nvidia-smi --id=0 --query-gpu=memory.used --format=csv,noheader,nounits | awk '{print $1}')
18 |
19 | # Calculate free memory
20 | local free_memory=$((total_memory - used_memory))
21 | echo "$free_memory"
22 | }
23 |
24 | workspace=$(get_free_gpu_memory)
25 |
26 | # Set default values
27 | generate_graph=false
28 |
29 | # Parse command line arguments
30 | onnx="$1"
31 | image_size="$2"
32 | stride=32
33 | network_size=$((image_size + stride))
34 | shape=1x3x${network_size}x${network_size}
35 |
36 | file_no_ext="${onnx%.*}"
37 |
38 | # Generate engine and graph file paths
39 | trt_engine="$file_no_ext.engine"
40 | graph="$trt_engine.layer.json"
41 | profile="$trt_engine.profile.json"
42 | timing="$trt_engine.timing.json"
43 | timing_cache="$trt_engine.timing.cache"
44 |
45 |
46 | # Check if optional flag --generate-graph is provided
47 | if [ "$3" == "--generate-graph" ]; then
48 | generate_graph=true
49 | fi
50 |
51 | # Run trtexec command to generate engine and graph files
52 | if [ "$generate_graph" = true ]; then
53 | trtexec --onnx="${onnx}" \
54 | --saveEngine="${trt_engine}" \
55 | --fp16 --int8 \
56 | --useCudaGraph \
57 | --separateProfileRun \
58 | --useSpinWait \
59 | --profilingVerbosity=detailed \
60 | --minShapes=images:$shape \
61 | --optShapes=images:$shape \
62 | --maxShapes=images:$shape \
63 | --memPoolSize=workspace:${workspace}MiB \
64 | --dumpLayerInfo \
65 | --exportTimes="${timing}" \
66 | --exportLayerInfo="${graph}" \
67 | --exportProfile="${profile}" \
68 | --timingCacheFile="${timing_cache}"
69 |
70 | # Profiling affects the performance of your kernel!
71 | # Always run and time without profiling.
72 |
73 | else
74 | trtexec --onnx="${onnx}" \
75 | --saveEngine="${trt_engine}" \
76 | --fp16 --int8 \
77 | --useCudaGraph \
78 | --minShapes=images:$shape \
79 | --optShapes=images:$shape \
80 | --maxShapes=images:$shape \
81 | --memPoolSize=workspace:${workspace}MiB \
82 | --timingCacheFile="${timing_cache}"
83 | fi
84 |
85 | # Check if trtexec command was successful
86 | if [ $? -eq 0 ]; then
87 | echo "Engine file generated successfully: ${trt_engine}"
88 | if [ "$generate_graph" = true ]; then
89 | echo "Graph file generated successfully: ${graph}"
90 | echo "Profile file generated successfully: ${profile}"
91 | fi
92 | else
93 | echo "Failed to generate engine file."
94 | exit 1
95 | fi
96 |
97 |
--------------------------------------------------------------------------------
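scripts/generate_trt_engine.sh above takes an ONNX file and the training image size, then calls trtexec with FP16 and INT8 enabled; note that it adds the stride (32) to the given size, so an image size of 640 builds a static 1x3x672x672 engine. Example invocation (the ONNX path is illustrative):

    bash scripts/generate_trt_engine.sh runs/qat/exp/weights/qat_best_yolov9-c.onnx 640 --generate-graph
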
/scripts/val_trt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | usage() {
5 |     echo "Usage: $0 <weights.pt> <data.yaml> <img_size> [--generate-graph]"
6 | exit 1
7 | }
8 |
9 | # Check if all required arguments are provided
10 | if [ "$#" -lt 3 -o "$#" -gt 4 ]; then
11 | usage
12 | exit 1
13 | fi
14 |
15 | weight=$1
16 | data=$2
17 | img_size=$3
18 | generate_graph=false
19 |
20 | file_no_ext="${weight%.*}"
21 |
22 | # Generate engine and graph file paths
23 | onnx_file="$file_no_ext.onnx"
24 | trt_engine="$file_no_ext.engine"
25 | graph="$trt_engine.layer.json"
26 | profile="$trt_engine.profile.json"
27 |
28 |
29 | # Check if optional flag --generate-graph is provided
30 | if [ "$4" == "--generate-graph" ]; then
31 | generate_graph=true
32 | fi
33 |
34 | # Check if weight file exists
35 | if [ ! -f "$weight" ]; then
36 | echo "Error: Weight file '$weight' not found."
37 | exit 1
38 | fi
39 |
40 | # Run the script
41 | python3 export_qat.py --weights "$weight" --include onnx --inplace --dynamic --simplify
42 |
43 | echo -e "\n"
44 |
45 | # Check if ONNX file was successfully generated
46 | if [ ! -f $onnx_file ]; then
47 | echo "Error: ONNX file $onnx_file not generated."
48 | exit 1
49 | fi
50 |
51 | # Run the script
52 | if [ "$generate_graph" = true ]; then
53 | bash scripts/generate_trt_engine.sh $onnx_file $img_size --generate-graph
54 | else
55 | bash scripts/generate_trt_engine.sh $onnx_file $img_size
56 | fi
57 |
58 | if [ $? -ne 0 ]; then
59 | exit 1
60 | fi
61 |
62 | if [ "$generate_graph" = true ]; then
63 | # Check if Graph file exists
64 | if [ ! -f "$graph" ]; then
65 | echo "Error: Graph file $graph not found."
66 | exit 1
67 | fi
68 |
69 |     # Check if Profile file exists
70 |     if [ ! -f "$profile" ]; then
71 |         echo "Error: Profile file $profile not found."
72 | exit 1
73 | fi
74 |
75 | # Run the script
76 | /bin/bash -c "source /opt/nvidia_trex/env_trex/bin/activate && python3 draw-engine.py --layer $graph --profile $profile"
77 | fi
78 |
79 | # Run the script
80 | python3 val_trt.py --engine-file $trt_engine --data $data
81 |
82 |
83 |
--------------------------------------------------------------------------------
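scripts/val_trt.sh above chains the previous pieces: it exports the QAT checkpoint to ONNX with export_qat.py, builds a TensorRT engine through generate_trt_engine.sh, optionally renders it with draw-engine.py from the /opt/nvidia_trex virtualenv, and finally validates the engine with val_trt.py. Example invocation (the weights path is illustrative):

    bash scripts/val_trt.sh runs/qat/exp/weights/qat_best_yolov9-c.pt data/coco.yaml 640 --generate-graph
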
/segment/qat_seg.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | import yaml
5 | import argparse
6 | import json
7 | from copy import deepcopy
8 | from pathlib import Path
9 | import warnings
10 |
11 | # PyTorch
12 | import torch
13 | import torch.nn as nn
14 |
15 | import val as validate
16 | from models.yolo import Model
17 | from models.common import Conv
18 | from utils.segment.dataloaders import create_dataloader
19 | from utils.downloads import attempt_download
20 |
21 | from models.yolo import Detect, DDetect, DualDetect, DualDDetect, DetectionModel, SegmentationModel
22 | import models.quantize as quantize
23 |
24 | from utils.general import (LOGGER, check_dataset, check_requirements, check_img_size, colorstr, init_seeds,increment_path,file_size)
25 | from utils.torch_utils import (torch_distributed_zero_first)
26 |
27 | warnings.filterwarnings("ignore")
28 |
29 | FILE = Path(__file__).resolve()
30 | ROOT = FILE.parents[0] # YOLO root directory
31 | if str(ROOT) not in sys.path:
32 | sys.path.append(str(ROOT)) # add ROOT to PATH
33 | ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
34 |
35 | LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
36 | RANK = int(os.getenv('RANK', -1))
37 | WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
38 | GIT_INFO = None
39 |
40 |
41 | class ReportTool:
42 | def __init__(self, file):
43 | self.file = file
44 | if os.path.exists(self.file):
45 | open(self.file, 'w').close()
46 | self.data = []
47 |
48 | def load_data(self):
49 | try:
50 | return json.load(open(self.file, "r"))
51 | except FileNotFoundError:
52 | return []
53 |
54 | def append(self, item):
55 | self.data.append(item)
56 | self.save_data()
57 |
58 | def update(self, item):
59 | for i, data_item in enumerate(self.data):
60 | if data_item[0] == item[0]:
61 | self.data[i] = item
62 | break
63 | else:
64 |                 # If not found, append it as a new item
65 | self.append(item)
66 | self.save_data()
67 |
68 | def save_data(self):
69 | json.dump(self.data, open(self.file, "w"), indent=4)
70 |
71 |
72 | def load_model(weights, device) -> Model:
73 | with torch_distributed_zero_first(LOCAL_RANK):
74 | attempt_download(weights)
75 | model = torch.load(weights, map_location=device)["model"]
76 | for m in model.modules():
77 | if type(m) is nn.Upsample:
78 | m.recompute_scale_factor = None # torch 1.11.0 compatibility
79 | elif type(m) is Conv:
80 | m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
81 | model.float()
82 | model.eval()
83 | with torch.no_grad():
84 | model.fuse()
85 | return model
86 |
87 |
88 |
89 | def create_train_dataloader(train_path, imgsz, batch_size, single_cls, stride, hyp_path, mask_ratio, overlap):
90 | with open(hyp_path) as f:
91 | hyp = yaml.load(f, Loader=yaml.SafeLoader) # load hyps
92 | loader = create_dataloader(
93 | train_path,
94 | imgsz=imgsz,
95 | batch_size=batch_size,
96 | single_cls=single_cls,
97 | augment=True,
98 | hyp=hyp,
99 | rect=False,
100 | cache=False,
101 | stride=stride,
102 | pad=0.0,
103 | image_weights=False,
104 | shuffle=True,
105 | mask_downsample_ratio=mask_ratio,
106 | overlap_mask=overlap)[0]
107 | return loader
108 |
109 |
110 |
111 |
112 | def create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride, keep_images, mask_ratio,overlap ):
113 | loader = create_dataloader(
114 | test_path,
115 | imgsz=imgsz,
116 | batch_size=batch_size,
117 | single_cls=single_cls,
118 | augment=False,
119 | hyp=None,
120 | rect=True,
121 | cache=False,
122 | stride=stride,pad=0.5,
123 | image_weights=False,
124 | mask_downsample_ratio=mask_ratio,
125 | overlap_mask=overlap)[0]
126 |
127 | def subclass_len(self):
128 | if keep_images is not None:
129 | return keep_images
130 | return len(self.img_files)
131 |
132 | loader.dataset.__len__ = subclass_len
133 | return loader
134 |
135 | def evaluate_dataset(model_eval, val_loader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap, conf_thres=0.001 , iou_thres=0.65 ):
136 | return validate.run(data_dict,
137 | model=model_eval,
138 | imgsz=imgsz,
139 | single_cls=single_cls,
140 | half=True,
141 | task='val',
142 | verbose=True,
143 | conf_thres=conf_thres,
144 | iou_thres=iou_thres,
145 | save_dir=save_dir,
146 | save_json=is_coco,
147 | dataloader=val_loader,
148 | mask_downsample_ratio=mask_ratio,
149 | overlap=overlap
150 | )[0][:4]
151 |
152 |
153 | def export_onnx(model, file, im, opset=12, dynamic=False, prefix=colorstr('QAT ONNX:')):
154 | check_requirements('onnx')
155 | import onnx
156 |
157 | file = Path(file)
158 | LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
159 |
160 | f = file.with_suffix('.onnx')
161 | output_names = ['output0', 'output1'] if isinstance(model, SegmentationModel) else ['output0']
162 | model.eval()
163 | for k, m in model.named_modules():
164 | # print(m)
165 | if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect)):
166 | m.inplace = True
167 | m.dynamic = dynamic
168 | m.export = True
169 | dynamic = {'images': {0: 'batch', 2: 'height', 3: 'width'}} # shape(1,3,640,640)
170 | if isinstance(model, SegmentationModel):
171 | dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
172 | dynamic['output1'] = {0: 'batch', 2: 'mask_height', 3: 'mask_width'} # shape(1,32,160,160)
173 | elif isinstance(model, DetectionModel):
174 | dynamic['output0'] = {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
175 |
176 | quantize.export_onnx(model, im, file, opset_version=13,
177 | input_names=["images"], output_names=output_names,
178 | dynamic_axes=dynamic or None
179 | )
180 |
181 | for k, m in model.named_modules():
182 | if isinstance(m, (Detect, DDetect, DualDetect, DualDDetect)):
183 | m.inplace = True
184 | m.dynamic = False
185 | m.export = False
186 |
187 |
188 | def run_quantize(weights, data, imgsz, batch_size, hyp, device, save_dir, supervision_stride, iters, no_eval_origin, no_eval_ptq, mask_ratio, overlap, prefix=colorstr('QAT:')):
189 |
190 | if not Path(weights).exists():
191 | LOGGER.info(f'{prefix} Weight file not found "{weights}" ❌')
192 | exit(1)
193 |
194 | quantize.initialize()
195 |
196 | with torch_distributed_zero_first(LOCAL_RANK):
197 | data_dict = check_dataset(data)
198 |
199 | w = save_dir / 'weights' # weights dir
200 | w.mkdir(parents=True, exist_ok=True) # make dir
201 |
202 | is_coco = isinstance(data_dict.get('val'), str) and data_dict['val'].endswith(f'val2017.txt') # COCO dataset
203 |
204 |
205 | nc = int(data_dict['nc']) # number of classes
206 | single_cls = False if nc > 1 else True
207 | names = data_dict['names'] # class names
208 | assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data_dict) # check
209 | train_path = data_dict['train']
210 | test_path = data_dict['val']
211 |
212 | result_eval_origin=None
213 | result_eval_ptq=None
214 | result_eval_qat_best=None
215 |
216 | device = torch.device(device)
217 | model = load_model(weights, device)
218 |
219 | if not isinstance(model, SegmentationModel):
220 | model_name=model.__class__.__name__
221 |         LOGGER.info(f'{prefix} {model_name} model is not supported. Only SegmentationModel is supported. ❌')
222 | exit(1)
223 |
224 | stride = max(int(model.stride.max()), 32) # grid size (max stride)
225 | imgsz = check_img_size(imgsz, s=stride) # check image size
226 |
227 | # conf onnx export
228 | exp_imgsz=[imgsz,imgsz]
229 | gs = int(max(model.stride)) # grid size (max stride)
230 | exp_imgsz = [check_img_size(x, gs) for x in exp_imgsz] # verify img_size are gs-multiples
231 | im = torch.zeros(batch_size, 3, *exp_imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
232 |
233 |
234 | train_dataloader = create_train_dataloader(train_path, imgsz, batch_size, single_cls, stride, hyp, mask_ratio, overlap)
235 | val_dataloader = create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride, None, mask_ratio, overlap)
236 |
237 |     ### This rule is disabled - it allowed users to disable QAT on specific layers ###
238 |     # The rule remains in the code for compatibility and possible future implementation.
239 | """
240 | ignore_layer=-1
241 | if ignore_layer > -1:
242 | ignore_policy=f"model\.{ignore_layer}\.cv\d+\.\d+\.\d+(\.conv)?"
243 | else:
244 | ignore_policy=f"model\.9999999999\.cv\d+\.\d+\.\d+(\.conv)?"
245 | """
246 | ### End #######
247 |
248 | quantize.replace_custom_module_forward(model)
249 |     quantize.replace_to_quantization_module(model, ignore_policy="disabled") ## ignore_policy is disabled because per-layer ignore was not implemented
250 | quantize.apply_custom_rules_to_quantizer(model, lambda model, file: export_onnx(model, file, im))
251 | quantize.calibrate_model(model, train_dataloader, device)
252 |
253 | report_file = os.path.join(save_dir, "report.json")
254 | report = ReportTool(report_file)
255 |
256 |     if no_eval_origin:  # True unless --no-eval-origin was passed (store_false flag)
257 | LOGGER.info(f'\n{prefix} Evaluating Origin...')
258 | model_eval = deepcopy(model).eval()
259 | with quantize.disable_quantization(model_eval):
260 | result_eval_origin = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap )
261 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in result_eval_origin)
262 | LOGGER.info(f'\n{prefix} Eval Origin - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}')
263 | report.append(["Origin", str(weights), eval_map, eval_map50,eval_mp, eval_mr ])
264 |
265 |     if no_eval_ptq:  # True unless --no-eval-ptq was passed (store_false flag)
266 |
267 | LOGGER.info(f'\n{prefix} Evaluating PTQ...')
268 | model_eval = deepcopy(model).eval()
269 |
270 | result_eval_ptq = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap )
271 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in result_eval_ptq)
272 | LOGGER.info(f'\n{prefix} Eval PTQ - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}')
273 | ptq_weights = w / f'ptq_ap_{eval_map}_{os.path.basename(weights)}'
274 | torch.save({"model": model_eval},f'{ptq_weights}')
275 | LOGGER.info(f'\n{prefix} PTQ, weights saved as {ptq_weights} ({file_size(ptq_weights):.1f} MB)')
276 | report.append(["PTQ", str(ptq_weights), eval_map, eval_map50,eval_mp, eval_mr ])
277 |
278 | best_map = 0
279 |
280 | def per_epoch(model, epoch, lr):
281 | nonlocal best_map , result_eval_qat_best
282 |
283 | epoch +=1
284 | model_eval = deepcopy(model).eval()
285 | with torch.no_grad():
286 | eval_result = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap )
287 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result)
288 | qat_weights = w / f'qat_ep_{epoch}_ap_{eval_map}_{os.path.basename(weights)}'
289 | torch.save({"model": model_eval},f'{qat_weights}')
290 | LOGGER.info(f'\n{prefix} Epoch-{epoch}, weights saved as {qat_weights} ({file_size(qat_weights):.1f} MB)')
291 | report.append([f"QAT-{epoch}", str(qat_weights), eval_map, eval_map50,eval_mp, eval_mr ])
292 |
293 | if eval_map > best_map:
294 | best_map = eval_map
295 | result_eval_qat_best=eval_result
296 | qat_weights = w / f'qat_best_{os.path.basename(weights)}'
297 | torch.save({"model": model_eval}, f'{qat_weights}')
298 | LOGGER.info(f'{prefix} QAT Best, weights saved as {qat_weights} ({file_size(qat_weights):.1f} MB)')
299 | report.update(["QAT-Best", str(qat_weights), eval_map, eval_map50,eval_mp, eval_mr ])
300 |
301 | eval_results = [result_eval_origin, result_eval_ptq, result_eval_qat_best]
302 |
303 | LOGGER.info(f'\n\nEval Model | {"AP":<8} | {"AP50":<8} | {"Precision":<10} | {"Recall":<8}')
304 | LOGGER.info('-' * 55)
305 | for idx, eval_r in enumerate(eval_results):
306 | if eval_r is not None:
307 | eval_mp, eval_mr, eval_map50, eval_map = tuple(round(x, 3) for x in eval_r)
308 | if idx == 0:
309 | LOGGER.info(f'Origin | {eval_map:<8} | {eval_map50:<8} | {eval_mp:<10} | {eval_mr:<8}')
310 | if idx == 1:
311 | LOGGER.info(f'PTQ | {eval_map:<8} | {eval_map50:<8} | {eval_mp:<10} | {eval_mr:<8}')
312 | if idx == 2:
313 | LOGGER.info(f'QAT - Best | {eval_map:<8} | {eval_map50:<8} | {eval_mp:<10} | {eval_mr:<8}\n')
314 |
315 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result)
316 | LOGGER.info(f'\n{prefix} Eval - Epoch {epoch} | AP: {eval_map} | AP50: {eval_map50} | Precision: {eval_mp} | Recall: {eval_mr}\n')
317 |
318 | def preprocess(datas):
319 | return datas[0].to(device).float() / 255.0
320 |
321 | def supervision_policy():
322 | supervision_list = []
323 | for item in model.model:
324 | supervision_list.append(id(item))
325 |
326 | keep_idx = list(range(0, len(model.model) - 1, supervision_stride))
327 | keep_idx.append(len(model.model) - 2)
328 | def impl(name, module):
329 | if id(module) not in supervision_list: return False
330 | idx = supervision_list.index(id(module))
331 | if idx in keep_idx:
332 | print(f"Supervision: {name} will compute loss with origin model during QAT training")
333 | else:
334 | print(f"Supervision: {name} no compute loss during QAT training, that is unsupervised only and doesn't mean don't learn")
335 | return idx in keep_idx
336 | return impl
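    # For illustration, assuming a model with 23 top-level modules: keep_idx is
    # range(0, 22, supervision_stride) plus index 21. With the default --supervision-stride 1,
    # all layers except the last one (indices 0-21) compute a loss against the origin model;
    # with --supervision-stride 4 only layers 0, 4, 8, 12, 16, 20 and the second-to-last
    # layer (21) do, and the remaining layers train without that per-layer supervision.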
337 |
338 | quantize.finetune(
339 | model, train_dataloader, per_epoch, early_exit_batchs_per_epoch=iters,
340 | preprocess=preprocess, supervision_policy=supervision_policy())
341 |
342 | def run_sensitive_analysis(weights, device, data, imgsz, batch_size, hyp, save_dir, num_image, mask_ratio=4, overlap=True, prefix=colorstr('QAT ANALYSIS:')):  # mask_ratio/overlap defaults mirror the quantize command
343 |
344 | if not Path(weights).exists():
345 | LOGGER.info(f'{prefix} Weight file not found "{weights}" ❌')
346 | exit(1)
347 |
348 | save_dir = Path(save_dir)
349 | # Create the directory if it doesn't exist
350 | save_dir.mkdir(parents=True, exist_ok=opt.exist_ok)
351 |
352 | with torch_distributed_zero_first(LOCAL_RANK):
353 | data_dict = check_dataset(data)
354 |
355 | is_coco=False
356 |
357 | nc = int(data_dict['nc']) # number of classes
358 | single_cls = False if nc > 1 else True
359 | names = data_dict['names'] # class names
360 | assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data_dict) # check
361 | train_path = data_dict['train']
362 | test_path = data_dict['val']
363 |
364 | device = torch.device(device)
365 | model = load_model(weights, device)
366 |
367 |     if not isinstance(model, SegmentationModel):
368 |         LOGGER.info(f'{prefix} Model not supported. Only SegmentationModel is supported. ❌')
369 | exit(1)
370 |
371 | is_model_qat=False
372 | for i in range(0, len(model.model)):
373 | layer = model.model[i]
374 | if quantize.have_quantizer(layer):
375 | is_model_qat=True
376 | break
377 |
378 | if is_model_qat:
379 |         LOGGER.info(f'{prefix} This model is already quantized. Only non-quantized models are allowed. ❌')
380 | exit(1)
381 |
382 | stride = max(int(model.stride.max()), 32) # grid size (max stride)
383 | imgsz = check_img_size(imgsz, s=stride) # check image size
384 |
385 | exp_imgsz=[imgsz,imgsz]
386 | gs = int(max(model.stride)) # grid size (max stride)
387 | exp_imgsz = [check_img_size(x, gs) for x in exp_imgsz] # verify img_size are gs-multiples
388 | im = torch.zeros(batch_size, 3, *exp_imgsz).to(device) # image size(1,3,320,192) BCHW iDetection
389 |
390 |
391 |     train_dataloader = create_train_dataloader(train_path, imgsz, batch_size, single_cls, stride, hyp, mask_ratio, overlap)
392 |     val_dataloader = create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride, num_image, mask_ratio, overlap)  # num_image caps the evaluated images
393 | quantize.initialize()
394 | quantize.replace_custom_module_forward(model)
395 |     quantize.replace_to_quantization_module(model, ignore_policy="disabled") ## ignore_policy is disabled because per-layer ignore was not implemented
396 | quantize.calibrate_model(model, train_dataloader, device)
397 |
398 | report_file=os.path.join(save_dir , "summary-sensitive-analysis.json")
399 | report = ReportTool(report_file)
400 |
401 | model_eval = deepcopy(model).eval()
402 | LOGGER.info(f'\n{prefix} Evaluating PTQ...')
403 |
404 | eval_result = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap )
405 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result)
406 |
407 | LOGGER.info(f'\n{prefix} Eval PTQ - QAT enabled on All Layers - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}')
408 | report.append([eval_map, "PTQ"])
409 | LOGGER.info(f'{prefix} Sensitive analysis by each layer. Layers Detected: {len(model.model)}')
410 |
411 | for i in range(0, len(model.model)):
412 | layer = model.model[i]
413 | if quantize.have_quantizer(layer):
414 | LOGGER.info(f'{prefix} QAT disabled on Layer model.{i}')
415 | quantize.disable_quantization(layer).apply()
416 | model_eval = deepcopy(model).eval()
417 | eval_result = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap )
418 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in eval_result)
419 | LOGGER.info(f'\n{prefix} Eval PTQ - QAT disabled on Layer model.{i} - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}\n')
420 | report.append([eval_map, f"model.{i}"])
421 | quantize.enable_quantization(layer).apply()
422 | else:
423 | LOGGER.info(f'{prefix} Ignored Layer model.{i} because it is {type(layer)}')
424 |
425 | report = sorted(report.data, key=lambda x:x[0], reverse=True)
426 | print("Sensitive summary:")
427 | for n, (ap, name) in enumerate(report[:10]):
428 | print(f"Top{n}: Using fp16 {name}, ap = {ap:.5f}")
429 |
430 |
431 | def run_eval(weights, device, data, imgsz, batch_size, save_dir, conf_thres, iou_thres, mask_ratio=4, overlap=True, prefix=colorstr('QAT TEST:')):  # mask_ratio/overlap defaults mirror the quantize command
432 |
433 | if not Path(weights).exists():
434 | LOGGER.info(f'{prefix} Weight file not found "{weights}" ❌')
435 | exit(1)
436 |
437 | quantize.initialize()
438 |
439 | save_dir = Path(save_dir)
440 | # Create the directory if it doesn't exist
441 | save_dir.mkdir(parents=True, exist_ok=opt.exist_ok)
442 |
443 | with torch_distributed_zero_first(LOCAL_RANK):
444 | data_dict = check_dataset(data)
445 |
446 |
447 | device = torch.device(device)
448 | model = load_model(weights, device)
449 |
450 |     if not isinstance(model, SegmentationModel):
451 |         model_name=model.__class__.__name__
452 |         LOGGER.info(f'{prefix} {model_name} model is not supported. Only SegmentationModel is supported. ❌')
453 | exit(1)
454 |
455 | is_model_qat=False
456 | for i in range(0, len(model.model)):
457 | layer = model.model[i]
458 | if quantize.have_quantizer(layer):
459 | is_model_qat=True
460 | break
461 |
462 | if not is_model_qat:
463 |         LOGGER.info(f'{prefix} This model is not quantized. ❌')
464 | exit(1)
465 |
466 | is_coco = isinstance(data_dict.get('val'), str) and data_dict['val'].endswith(f'val2017.txt') # COCO dataset
467 |
468 | stride = max(int(model.stride.max()), 32) # grid size (max stride)
469 | imgsz = check_img_size(imgsz, s=stride) # check image size
470 |
471 | nc = int(data_dict['nc']) # number of classes
472 | single_cls = False if nc > 1 else True
473 | names = data_dict['names'] # class names
474 | assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, data_dict) # check
475 |
476 | test_path = data_dict['val']
477 |     val_dataloader = create_val_dataloader(test_path, imgsz, batch_size, single_cls, stride, None, mask_ratio, overlap)
478 |
479 |
480 | LOGGER.info(f'\n{prefix} Evaluating ...')
481 | model_eval = deepcopy(model).eval()
482 |
483 | result_eval = evaluate_dataset(model_eval, val_dataloader, imgsz, data_dict, single_cls, save_dir, is_coco, mask_ratio, overlap ,conf_thres=conf_thres, iou_thres=iou_thres )
484 | eval_mp, eval_mr, eval_map50, eval_map= tuple(round(x, 4) for x in result_eval)
485 | LOGGER.info(f'\n{prefix} Eval Result - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}')
486 | LOGGER.info(f'\n{prefix} Eval Result, saved on {save_dir}')
487 |
488 |
489 |
490 | if __name__ == "__main__":
491 | parser = argparse.ArgumentParser(prog='qat.py')
492 | subps = parser.add_subparsers(dest="cmd")
493 | qat = subps.add_parser("quantize", help="PTQ/QAT finetune ...")
494 |
495 | qat.add_argument('--weights', type=str, default=ROOT / 'runs/models_original/yolov9-c.pt', help='weights path')
496 | qat.add_argument('--data', type=str, default=ROOT / 'data/coco.yaml', help='dataset.yaml path')
497 | qat.add_argument('--hyp', type=str, default=ROOT / 'data/hyps/hyp.scratch-high.yaml', help='hyperparameters path')
498 | qat.add_argument("--device", type=str, default="cuda:0", help="device")
499 | qat.add_argument('--batch-size', type=int, default=10, help='total batch size')
500 | qat.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
501 | qat.add_argument('--project', default=ROOT / 'runs/qat', help='save to project/name')
502 | qat.add_argument('--name', default='exp', help='save to project/name')
503 | qat.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
504 | qat.add_argument("--iters", type=int, default=200, help="iters per epoch")
505 | qat.add_argument('--seed', type=int, default=57, help='Global training seed')
506 | qat.add_argument("--supervision-stride", type=int, default=1, help="supervision stride")
507 | qat.add_argument("--no-eval-origin", action="store_false", help="Disable eval for origin model")
508 | qat.add_argument("--no-eval-ptq", action="store_false", help="Disable eval for ptq model")
509 |     qat.add_argument('--mask-ratio', type=int, default=4, help='Downsample the ground-truth masks to save memory')
510 |     qat.add_argument('--no-overlap', action='store_true', help='Do not use overlapping masks (overlapping masks train faster at slightly lower mAP)')
511 |
512 | sensitive = subps.add_parser("sensitive", help="Sensitive layer analysis")
513 | sensitive.add_argument('--weights', type=str, default=ROOT / 'runs/models_original/yolov9-c.pt', help='Weights path (.pt)')
514 | sensitive.add_argument("--device", type=str, default="cuda:0", help="device")
515 | sensitive.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path')
516 | sensitive.add_argument('--batch-size', type=int, default=10, help='total batch size')
517 | sensitive.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='train, val image size (pixels)')
518 | sensitive.add_argument('--hyp', type=str, default='data/hyps/hyp.scratch-high.yaml', help='hyperparameters path')
519 |     sensitive.add_argument('--project', default=ROOT / 'runs/qat_sensitive', help='save to project/name')
520 | sensitive.add_argument('--name', default='exp', help='save to project/name')
521 | sensitive.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
522 | sensitive.add_argument("--num-image", type=int, default=None, help="number of image to evaluate")
523 |
524 | testcmd = subps.add_parser("eval", help="Do evaluate")
525 | testcmd.add_argument('--weights', type=str, default=ROOT / 'runs/models_original/yolov9-c.pt', help='Weights path (.pt)')
526 | testcmd.add_argument('--data', type=str, default='data/coco.yaml', help='data.yaml path')
527 | testcmd.add_argument('--batch-size', type=int, default=10, help='total batch size')
528 | testcmd.add_argument('--imgsz', '--img', '--img-size', type=int, default=640, help='val image size (pixels)')
529 | testcmd.add_argument("--device", type=str, default="cuda:0", help="device")
530 | testcmd.add_argument("--conf-thres", type=float, default=0.001, help="confidence threshold")
531 | testcmd.add_argument("--iou-thres", type=float, default=0.7, help="nms threshold")
532 | testcmd.add_argument('--project', default=ROOT / 'runs/qat_eval', help='save to project/name')
533 | testcmd.add_argument('--name', default='exp', help='save to project/name')
534 | testcmd.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
535 |
536 | opt = parser.parse_args()
537 | if opt.cmd == "quantize":
538 | print(opt)
539 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
540 | init_seeds(opt.seed + 1 + RANK, deterministic=False)
541 |
542 | run_quantize(
543 | opt.weights, opt.data, opt.imgsz, opt.batch_size,
544 | opt.hyp, opt.device, Path(opt.save_dir),
545 | opt.supervision_stride, opt.iters,
546 |         opt.no_eval_origin, opt.no_eval_ptq, opt.mask_ratio, not opt.no_overlap  # overlap=True unless --no-overlap is passed
547 | )
548 |
549 | elif opt.cmd == "sensitive":
550 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
551 | print(opt)
552 | run_sensitive_analysis(opt.weights, opt.device, opt.data,
553 | opt.imgsz, opt.batch_size, opt.hyp,
554 | opt.save_dir, opt.num_image
555 | )
556 | elif opt.cmd == "eval":
557 | opt.save_dir = str(increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok))
558 | print(opt)
559 | run_eval(opt.weights, opt.device, opt.data,
560 | opt.imgsz, opt.batch_size, opt.save_dir,
561 | opt.conf_thres, opt.iou_thres
562 | )
563 | else:
564 | parser.print_help()
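# Example invocations (illustrative only; weight and dataset paths are the defaults defined above,
# adjust them to your own files):
#
#   python qat_seg.py quantize --weights runs/models_original/yolov9-c.pt --data data/coco.yaml \
#       --hyp data/hyps/hyp.scratch-high.yaml --batch-size 10 --imgsz 640
#   python qat_seg.py sensitive --weights runs/models_original/yolov9-c.pt --data data/coco.yaml
#   python qat_seg.py eval --weights runs/qat/exp/weights/qat_best_yolov9-c.pt --data data/coco.yaml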
--------------------------------------------------------------------------------
/val_trt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import sys
5 | from pathlib import Path
6 |
7 | import numpy as np
8 | import torch
9 | from tqdm import tqdm
10 |
11 | FILE = Path(__file__).resolve()
12 | ROOT = FILE.parents[0] # YOLO root directory
13 | if str(ROOT) not in sys.path:
14 | sys.path.append(str(ROOT)) # add ROOT to PATH
15 | ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
16 |
17 | from utils.callbacks import Callbacks
18 | from utils.dataloaders import create_dataloader
19 | from utils.general import (LOGGER, TQDM_BAR_FORMAT, Profile, check_dataset, check_requirements,
20 | check_yaml, coco80_to_coco91_class, colorstr, increment_path, non_max_suppression,
21 | print_args, scale_boxes, xywh2xyxy, xyxy2xywh)
22 | from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
23 | from utils.plots import output_to_target, plot_images, plot_val_study
24 | from utils.torch_utils import smart_inference_mode
25 |
26 |
27 | import pycuda.autoinit
28 | import pycuda.driver as cuda
29 | import tensorrt as trt
30 |
31 | class HostDeviceMem(object):
32 | def __init__(self, host_mem, device_mem):
33 | self.host = host_mem
34 | self.device = device_mem
35 |
36 | def __str__(self):
37 | return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)
38 |
39 | def __repr__(self):
40 | return self.__str__()
41 |
42 | ## https://docs.nvidia.com/deeplearning/tensorrt/migration-guide/index.html
43 |
44 | def do_inference(context, inputs, outputs, stream):
45 | # Transfer input data to the GPU.
46 | [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
47 |
48 | # Run inference.
49 | context.execute_async_v3(stream_handle=stream.handle)
50 |
51 | # Transfer predictions back from the GPU.
52 | [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
53 |
54 | # Synchronize the stream
55 | stream.synchronize()
56 |
57 | # Return only the host outputs.
58 | return [out.host for out in outputs]
59 |
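# Note: execute_async_v3 (the TensorRT >= 8.5 tensor-based execution API) reads the I/O buffer
# addresses that were registered on the execution context beforehand via
# context.set_tensor_address(name, ptr); run() below does this once for every I/O tensor,
# right after creating the context.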
60 |
61 | def allocate_buffers(engine):
62 | '''
63 | Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
64 | '''
65 | inputs = []
66 | outputs = []
67 | bindings = []
68 | stream = cuda.Stream()
69 |
70 | for i in range(engine.num_io_tensors):
71 | tensor_name = engine.get_tensor_name(i)
72 | size = trt.volume(engine.get_tensor_shape(tensor_name))
73 | dtype = trt.nptype(engine.get_tensor_dtype(tensor_name))
74 |
75 | # Allocate host and device buffers
76 |         host_mem = cuda.pagelocked_empty(size, dtype)  # page-locked host buffer (won't be swapped to disk)
77 | device_mem = cuda.mem_alloc(host_mem.nbytes)
78 |
79 | # Append the device buffer address to device bindings.
80 | # When cast to int, it's a linear index into the context's memory (like memory address).
81 | bindings.append(int(device_mem))
82 |
83 | # Append to the appropriate input/output list.
84 | if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
85 | inputs.append(HostDeviceMem(host_mem, device_mem))
86 | else:
87 | outputs.append(HostDeviceMem(host_mem, device_mem))
88 |
89 | return inputs, outputs, bindings, stream
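# Minimal usage sketch for the two helpers above (assumes `engine_path` points to a serialized
# engine with a static input shape and `image_chw` is a preprocessed float32 array matching it;
# this mirrors how run() wires things up further below):
#
#   logger = trt.Logger(trt.Logger.INFO)
#   engine = trt.Runtime(logger).deserialize_cuda_engine(open(engine_path, 'rb').read())
#   inputs, outputs, bindings, stream = allocate_buffers(engine)
#   context = engine.create_execution_context()
#   for i in range(engine.num_io_tensors):
#       context.set_tensor_address(engine.get_tensor_name(i), bindings[i])
#   inputs[0].host = np.ascontiguousarray(image_chw, dtype=np.float32).ravel()
#   raw = do_inference(context, inputs=inputs, outputs=outputs, stream=stream)[0]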
90 |
91 | def save_one_txt(predn, save_conf, shape, file):
92 | # Save one txt result
93 | gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
94 | for *xyxy, conf, cls in predn.tolist():
95 | xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
96 | line = (cls, *xywh, conf) if save_conf else (cls, *xywh) # label format
97 | with open(file, 'a') as f:
98 | f.write(('%g ' * len(line)).rstrip() % line + '\n')
99 |
100 |
101 | def save_one_json(predn, jdict, path, class_map):
102 | # Save one JSON result {"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}
103 | image_id = int(path.stem) if path.stem.isnumeric() else path.stem
104 | box = xyxy2xywh(predn[:, :4]) # xywh
105 | box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
106 | for p, b in zip(predn.tolist(), box.tolist()):
107 | jdict.append({
108 | 'image_id': image_id,
109 | 'category_id': class_map[int(p[5])],
110 | 'bbox': [round(x, 3) for x in b],
111 | 'score': round(p[4], 5)})
112 |
113 |
114 | def process_batch(detections, labels, iouv):
115 | """
116 | Return correct prediction matrix
117 | Arguments:
118 | detections (array[N, 6]), x1, y1, x2, y2, conf, class
119 | labels (array[M, 5]), class, x1, y1, x2, y2
120 | Returns:
121 | correct (array[N, 10]), for 10 IoU levels
122 | """
123 | correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool)
124 | iou = box_iou(labels[:, 1:], detections[:, :4])
125 | correct_class = labels[:, 0:1] == detections[:, 5]
126 | for i in range(len(iouv)):
127 | x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match
128 | if x[0].shape[0]:
129 | matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detect, iou]
130 | if x[0].shape[0] > 1:
131 | matches = matches[matches[:, 2].argsort()[::-1]]
132 | matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
133 | # matches = matches[matches[:, 2].argsort()[::-1]]
134 | matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
135 | correct[matches[:, 1].astype(int), i] = True
136 | return torch.tensor(correct, dtype=torch.bool, device=iouv.device)
137 |
138 |
139 | @smart_inference_mode()
140 | def run(
141 | data,
142 | engine_file=None, # model.pt path(s)
143 | batch_size=1, # batch size
144 | imgsz=640, # inference size (pixels)
145 | conf_thres=0.001, # confidence threshold
146 | iou_thres=0.7, # NMS IoU threshold
147 | max_det=300, # maximum detections per image
148 | task='val', # train, val, test, speed or study
149 | device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
150 | workers=8, # max dataloader workers (per RANK in DDP mode)
151 | single_cls=False, # treat as single-class dataset
152 | augment=False, # augmented inference
153 | verbose=False, # verbose output
154 | save_txt=False, # save results to *.txt
155 | save_hybrid=False, # save label+prediction hybrid results to *.txt
156 | save_conf=False, # save confidences in --save-txt labels
157 | save_json=False, # save a COCO-JSON results file
158 | project=ROOT / 'runs/val', # save to project/name
159 | name='exp', # save to project/name
160 | exist_ok=False, # existing project/name ok, do not increment
161 | half=True, # use FP16 half-precision inference
162 | dnn=False, # use OpenCV DNN for ONNX inference
163 | min_items=0, # Experimental
164 | model=None,
165 | dataloader=None,
166 | save_dir=Path(''),
167 | plots=True,
168 | callbacks=Callbacks(),
169 | compute_loss=None,
170 | prefix=colorstr('VAL-TRT:')
171 | ):
172 | stride=32
173 | logger = trt.Logger(trt.Logger.INFO)
174 | runtime = trt.Runtime(logger)
175 |     engine = runtime.deserialize_cuda_engine(open(engine_file, 'rb').read())
176 | inputs, outputs, bindings, stream = allocate_buffers(engine)
177 |
178 | inputshape = [engine.get_tensor_shape(binding) for binding in engine][0]
179 |     imgsz = inputshape[3] - stride  ## derive imgsz from the engine input shape (the rect val dataloader adds up to one stride of padding)
180 | outputshape = [engine.get_tensor_shape(binding) for binding in engine][1]
181 |
182 | context = engine.create_execution_context()
183 |
184 | for i in range(engine.num_io_tensors):
185 | context.set_tensor_address(engine.get_tensor_name(i), bindings[i])
186 |
187 | # Initialize/load model and set device
188 | training = False
189 | device = torch.device("cpu")
190 | gs = 32
191 |
192 | # Directories
193 | save_dir = increment_path(Path(project) / name, exist_ok=exist_ok) # increment run
194 | (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True) # make dir
195 | data = check_dataset(data) # check
196 | # Configure
197 | is_coco = isinstance(data.get('val'), str) and data['val'].endswith(f'val2017.txt') # COCO dataset
198 | nc = 1 if single_cls else int(data['nc']) # number of classes
199 | iouv = torch.linspace(0.5, 0.95, 10, device=device) # iou vector for mAP@0.5:0.95
200 | niou = iouv.numel()
201 |
202 | # Dataloader
203 | if not training:
204 | task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
205 | dataloader = create_dataloader(data[task],
206 | imgsz,
207 | batch_size,
208 | stride,
209 | single_cls,
210 | pad=0.5,
211 | rect=True,
212 | workers=workers,
213 | min_items=0,
214 | prefix=colorstr(f'{task}: '))[0]
215 |
216 | seen = 0
217 | confusion_matrix = ConfusionMatrix(nc=nc)
218 | #names = model.names if hasattr(model, 'names') else model.module.names # get class names
219 | names=data['names']
220 |
221 | if isinstance(names, (list, tuple)): # old format
222 | names = dict(enumerate(names))
223 | class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
224 | s = ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'P', 'R', 'mAP50', 'mAP50-95')
225 | tp, fp, p, r, f1, mp, mr, map50, ap50, map = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
226 | dt = Profile(), Profile(), Profile() # profiling times
227 | loss = torch.zeros(3, device=device)
228 | jdict, stats, ap, ap_class = [], [], [], []
229 | callbacks.run('on_val_start')
230 | pbar = tqdm(dataloader, desc=s, bar_format=TQDM_BAR_FORMAT) # progress bar
231 |
232 | for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
233 | callbacks.run('on_val_batch_start')
234 | with dt[0]:
235 | im = (im / 255.0).float() # 0 - 255 to 0.0 - 1.0
236 |                 # the dataloader is set up with rect=True, pad=0.5, so pad the image into the engine's fixed input shape
237 | input_image = torch.full((inputshape[1], inputshape[2], inputshape[3]), 114 / 255.0, dtype=torch.float32)
238 | im_cp=im.squeeze(0)
239 | input_image[:, :im_cp.size(1), :im_cp.size(2)] = im_cp
240 | targets = targets.to(device)
241 | nb, _, height, width = im.shape # batch size, channels, height, width
242 | # Run model
243 | with dt[1]:
244 | inputs[0].host = input_image.data.numpy()
245 | trt_outputs = do_inference(context, inputs=inputs, outputs=outputs, stream = stream)
246 | preds = torch.Tensor(trt_outputs[0].reshape(outputshape))
247 |
248 |
249 | # Run NMS
250 | targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels
251 | lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
252 |
253 | with dt[2]:
254 | preds = non_max_suppression(preds,
255 | conf_thres,
256 | iou_thres,
257 | labels=lb,
258 | multi_label=True,
259 | agnostic=single_cls,
260 | max_det=max_det)
261 | # Metrics
262 | for si, pred in enumerate(preds):
263 | labels = targets[targets[:, 0] == si, 1:]
264 | nl, npr = labels.shape[0], pred.shape[0] # number of labels, predictions
265 | path, shape = Path(paths[si]), shapes[si][0]
266 | correct = torch.zeros(npr, niou, dtype=torch.bool, device=device) # init
267 | seen += 1
268 |
269 | if npr == 0:
270 | if nl:
271 | stats.append((correct, *torch.zeros((2, 0), device=device), labels[:, 0]))
272 | if plots:
273 | confusion_matrix.process_batch(detections=None, labels=labels[:, 0])
274 | continue
275 |
276 | # Predictions
277 | if single_cls:
278 | pred[:, 5] = 0
279 | predn = pred.clone()
280 | scale_boxes(im[si].shape[1:], predn[:, :4], shape, shapes[si][1]) # native-space pred
281 |
282 | # Evaluate
283 | if nl:
284 | tbox = xywh2xyxy(labels[:, 1:5]) # target boxes
285 | scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1]) # native-space labels
286 | labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels
287 | correct = process_batch(predn, labelsn, iouv)
288 | if plots:
289 | confusion_matrix.process_batch(predn, labelsn)
290 | stats.append((correct, pred[:, 4], pred[:, 5], labels[:, 0])) # (correct, conf, pcls, tcls)
291 |
292 | # Save/log
293 | if save_txt:
294 | save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / f'{path.stem}.txt')
295 | if save_json:
296 | save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
297 | callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
298 |
299 | # Plot images
300 | if plots and batch_i < 3:
301 | plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) # labels
302 | plot_images(im, output_to_target(preds), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
303 |
304 | callbacks.run('on_val_batch_end', batch_i, im, targets, paths, shapes, preds)
305 |
306 | # Compute metrics
307 | stats = [torch.cat(x, 0).cpu().numpy() for x in zip(*stats)] # to numpy
308 | if len(stats) and stats[0].any():
309 | tp, fp, p, r, f1, ap, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
310 | ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
311 | mp, mr, map50, map = p.mean(), r.mean(), ap50.mean(), ap.mean()
312 | nt = np.bincount(stats[3].astype(int), minlength=nc) # number of targets per class
313 |
314 | # Print results
315 | pf = '%22s' + '%11i' * 2 + '%11.3g' * 4 # print format
316 | LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
317 | if nt.sum() == 0:
318 | LOGGER.warning(f'WARNING ⚠️ no labels found in {task} set, can not compute metrics without labels')
319 |
320 | # Print results per class
321 | if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
322 | for i, c in enumerate(ap_class):
323 | LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
324 |
325 | # Print speeds
326 | t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
327 | if not training:
328 | shape = (batch_size, 3, imgsz, imgsz)
329 | LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
330 |
331 | # Plots
332 | if plots:
333 | confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
334 | callbacks.run('on_val_end', nt, tp, fp, p, r, f1, ap, ap50, ap_class, confusion_matrix)
335 |
336 | # Save JSON
337 | if save_json and len(jdict):
338 | w = Path(engine_file[0] if isinstance(engine_file, list) else engine_file).stem if engine_file is not None else '' # weights
339 | anno_json = str(Path(data.get('path', '../coco')) / 'annotations/instances_val2017.json') # annotations json
340 | pred_json = str(save_dir / f"{w}_predictions.json") # predictions json
341 | LOGGER.info(f'\nEvaluating pycocotools mAP... saving {pred_json}...')
342 | with open(pred_json, 'w') as f:
343 | json.dump(jdict, f)
344 |
345 | try: # https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocoEvalDemo.ipynb
346 | check_requirements('pycocotools')
347 | from pycocotools.coco import COCO
348 | from pycocotools.cocoeval import COCOeval
349 |
350 | anno = COCO(anno_json) # init annotations api
351 | pred = anno.loadRes(pred_json) # init predictions api
352 | eval = COCOeval(anno, pred, 'bbox')
353 | if is_coco:
354 | eval.params.imgIds = [int(Path(x).stem) for x in dataloader.dataset.im_files] # image IDs to evaluate
355 | eval.evaluate()
356 | eval.accumulate()
357 | eval.summarize()
358 | map, map50 = eval.stats[:2] # update results (mAP@0.5:0.95, mAP@0.5)
359 | except Exception as e:
360 | LOGGER.info(f'pycocotools unable to run: {e}')
361 |
362 | # Return results
363 | if not training:
364 | s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
365 | LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}{s}")
366 | maps = np.zeros(nc) + map
367 | for i, c in enumerate(ap_class):
368 | maps[c] = ap[i]
369 |
370 | eval_mp, eval_mr, eval_map50, eval_map = round(mp, 4), round(mr, 4), round(map50, 4), round(map, 4)
371 | LOGGER.info(f'\n{prefix} Eval TRT - AP: {eval_map} AP50: {eval_map50} Precision: {eval_mp} Recall: {eval_mr}')
372 |
373 | return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t
374 |
375 |
376 | def parse_opt():
377 | parser = argparse.ArgumentParser()
378 | parser.add_argument('--data', type=str, default=ROOT / 'data/coco.yaml', help='dataset.yaml path')
379 | parser.add_argument('--engine-file', type=str, default=ROOT / 'yolo.engine', help='model path(s)')
380 | parser.add_argument('--conf-thres', type=float, default=0.001, help='confidence threshold')
381 | parser.add_argument('--iou-thres', type=float, default=0.7, help='NMS IoU threshold')
382 | parser.add_argument('--max-det', type=int, default=300, help='maximum detections per image')
383 | parser.add_argument('--task', default='val', help='train, val, test, speed or study')
384 | parser.add_argument('--device', default='0', help='cuda device, i.e. 0 or 0,1,2,3 ')
385 | parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
386 | parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset')
387 | parser.add_argument('--verbose', action='store_true', help='report mAP by class')
388 | parser.add_argument('--save-txt', action='store_true', help='save results to *.txt')
389 | parser.add_argument('--save-hybrid', action='store_true', help='save label+prediction hybrid results to *.txt')
390 | parser.add_argument('--save-conf', action='store_true', help='save confidences in --save-txt labels')
391 | parser.add_argument('--save-json', action='store_true', help='save a COCO-JSON results file')
392 | parser.add_argument('--project', default=ROOT / 'runs/val_trt', help='save to project/name')
393 | parser.add_argument('--name', default='exp', help='save to project/name')
394 | parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
395 | opt = parser.parse_args()
396 | opt.data = check_yaml(opt.data) # check YAML
397 | opt.save_json |= opt.data.endswith('coco.yaml')
398 | opt.save_txt |= opt.save_hybrid
399 | print_args(vars(opt))
400 | return opt
401 |
402 |
403 | def main(opt):
404 | #check_requirements(exclude=('tensorboard', 'thop'))
405 |
406 | if opt.task in ('train', 'val', 'test'): # run normally
407 | if opt.conf_thres > 0.001: # https://github.com/ultralytics/yolov5/issues/1466
408 | LOGGER.info(f'WARNING ⚠️ confidence threshold {opt.conf_thres} > 0.001 produces invalid results')
409 | if opt.save_hybrid:
410 | LOGGER.info('WARNING ⚠️ --save-hybrid will return high mAP from hybrid labels, not from predictions alone')
411 | run(**vars(opt))
412 |
413 | else:
414 | engine_file = opt.engine_file if isinstance(opt.engine_file, list) else [opt.engine_file]
415 | opt.half = torch.cuda.is_available() and opt.device != 'cpu' # FP16 for fastest results
416 | if opt.task == 'speed': # speed benchmarks
417 | # python val.py --task speed --data coco.yaml --batch 1 --weights yolo.pt...
418 | opt.conf_thres, opt.iou_thres, opt.save_json = 0.25, 0.45, False
419 | for opt.engine_file in engine_file:
420 | run(**vars(opt), plots=False)
421 |
422 | elif opt.task == 'study': # speed vs mAP benchmarks
423 | # python val.py --task study --data coco.yaml --iou 0.7 --weights yolo.pt...
424 | for opt.engine_file in engine_file:
425 | f = f'study_{Path(opt.data).stem}_{Path(opt.engine_file).stem}.txt' # filename to save to
426 | x, y = list(range(256, 1536 + 128, 128)), [] # x axis (image sizes), y axis
427 | for opt.imgsz in x: # img-size
428 | LOGGER.info(f'\nRunning {f} --imgsz {opt.imgsz}...')
429 | r, _, t = run(**vars(opt), plots=False)
430 | y.append(r + t) # results and times
431 | np.savetxt(f, y, fmt='%10.4g') # save
432 | os.system('zip -r study.zip study_*.txt')
433 | plot_val_study(x=x) # plot
434 |
435 |
436 | if __name__ == "__main__":
437 | opt = parse_opt()
438 | main(opt)
439 |
--------------------------------------------------------------------------------