├── README.md
├── yolov11n_onnx
│   ├── test.jpg
│   ├── test_onnx_result.jpg
│   ├── test_pytorch_result.jpg
│   ├── yolov11_onnx_demo.py
│   └── yolov11n_80class_ZQ.onnx
├── yolov11n_rknn
│   ├── data
│   │   └── test.jpg
│   ├── dataset.txt
│   ├── onnx2rknn_zq.py
│   ├── test.jpg
│   ├── test_rknn_result.jpg
│   ├── yolov11n_80class_ZQ.onnx
│   └── yolov11n_80class_ZQ.rknn
└── yolov11n_tensorrt
    ├── images
    │   ├── test1.jpg
    │   ├── test2.jpg
    │   ├── test3.jpg
    │   ├── test4.jpg
    │   ├── test5.jpg
    │   ├── test6.jpg
    │   ├── test7.jpg
    │   └── test8.jpg
    ├── onnx2trt.py
    ├── tensorrt_infer_demo.py
    ├── test.jpg
    ├── test_result_tensorRT.jpg
    ├── yolov11n.onnx
    ├── yolov11n_fp32.trt
    └── yolov11n_int8.trt

/README.md:
--------------------------------------------------------------------------------
1 | # yolov11_onnx_rknn
2 | A deployment-oriented build of yolov11 that moves the DFL decode into post-processing, which makes it easier to port to different platforms. The post-processing is written with a C++ deployment in mind, so timing the Python post-processing is not very meaningful.
3 | 
4 | 
5 | The ONNX export workflow is described in [yolov11 deployment on Rockchip rk3588: low-effort RKNN porting, fast model inference](https://blog.csdn.net/zhangqian_1/article/details/142722526)
6 | 
7 | # Directory layout
8 | 
9 | yolov11n_onnx: ONNX model, test image, test result, and the test demo script
10 | 
11 | yolov11n_rknn: RKNN model, test (quantization) image, test result, and the onnx2rknn conversion/test script (built with rknn_toolkit2-2.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl)
12 | 
13 | yolov11n_tensorrt: TensorRT engines, test images, test result, test demo script, ONNX model, and the onnx2tensorRT conversion script (TensorRT 8.6.1); fp32, fp16, and int8 are supported
14 | 
15 | # Test results
16 | 
17 | PyTorch result
18 | 
19 | ![image](https://github.com/cqu20160901/yolov11_onnx_rknn/blob/main/yolov11n_onnx/test_pytorch_result.jpg)
20 | 
21 | ONNX result
22 | 
23 | ![image](https://github.com/cqu20160901/yolov11_onnx_rknn/blob/main/yolov11n_onnx/test_onnx_result.jpg)
24 | 
25 | 
26 | # rk3588 deployment result
27 | 
28 | [Reference C++ deployment code for rk3588](https://github.com/cqu20160901/yolov11_dfl_rknn_Cplusplus)
29 | 
30 | ![image](https://github.com/cqu20160901/yolov11_dfl_rknn_Cplusplus/blob/main/examples/rknn_yolov11_demo_dfl_open/test_result.jpg)
31 | 
32 | Timing
33 | ![image](https://github.com/cqu20160901/yolov11_dfl_rknn_Cplusplus/blob/main/examples/rknn_yolov11_demo_dfl_open/yolov11_rk3588_costtime.png)
--------------------------------------------------------------------------------
/yolov11n_onnx/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_onnx/test.jpg
--------------------------------------------------------------------------------
/yolov11n_onnx/test_onnx_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_onnx/test_onnx_result.jpg
--------------------------------------------------------------------------------
/yolov11n_onnx/test_pytorch_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_onnx/test_pytorch_result.jpg
--------------------------------------------------------------------------------
/yolov11n_onnx/yolov11_onnx_demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding:utf-8 -*-
3 | import argparse
4 | import os
5 | import sys
6 | import os.path as osp
7 | import cv2
8 | import torch
9 | import numpy as np
10 | import onnxruntime as ort
11 | from math import exp
12 | 
13 | ROOT = os.getcwd()
14 | if str(ROOT) not in sys.path:
15 |     sys.path.append(str(ROOT))
16 | 
17 | 
18 | CLASSES = 
['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 19 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 20 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 21 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 22 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 23 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 24 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 25 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 26 | 'hair drier', 'toothbrush'] 27 | 28 | meshgrid = [] 29 | 30 | class_num = len(CLASSES) 31 | headNum = 3 32 | strides = [8, 16, 32] 33 | mapSize = [[80, 80], [40, 40], [20, 20]] 34 | nmsThresh = 0.45 35 | objectThresh = 0.5 36 | 37 | input_imgH = 640 38 | input_imgW = 640 39 | 40 | 41 | class DetectBox: 42 | def __init__(self, classId, score, xmin, ymin, xmax, ymax): 43 | self.classId = classId 44 | self.score = score 45 | self.xmin = xmin 46 | self.ymin = ymin 47 | self.xmax = xmax 48 | self.ymax = ymax 49 | 50 | 51 | def GenerateMeshgrid(): 52 | for index in range(headNum): 53 | for i in range(mapSize[index][0]): 54 | for j in range(mapSize[index][1]): 55 | meshgrid.append(j + 0.5) 56 | meshgrid.append(i + 0.5) 57 | 58 | 59 | def IOU(xmin1, ymin1, xmax1, ymax1, xmin2, ymin2, xmax2, ymax2): 60 | xmin = max(xmin1, xmin2) 61 | ymin = max(ymin1, ymin2) 62 | xmax = min(xmax1, xmax2) 63 | ymax = min(ymax1, ymax2) 64 | 65 | innerWidth = xmax - xmin 66 | innerHeight = ymax - ymin 67 | 68 | innerWidth = innerWidth if innerWidth > 0 else 0 69 | innerHeight = innerHeight if innerHeight > 0 else 0 70 | 71 | innerArea = innerWidth * innerHeight 72 | 73 | area1 = (xmax1 - xmin1) * (ymax1 - ymin1) 74 | area2 = (xmax2 - xmin2) * (ymax2 - ymin2) 75 | 76 | total = area1 + area2 - innerArea 77 | 78 | return innerArea / total 79 | 80 | 81 | def NMS(detectResult): 82 | predBoxs = [] 83 | 84 | sort_detectboxs = sorted(detectResult, key=lambda x: x.score, reverse=True) 85 | 86 | for i in range(len(sort_detectboxs)): 87 | xmin1 = sort_detectboxs[i].xmin 88 | ymin1 = sort_detectboxs[i].ymin 89 | xmax1 = sort_detectboxs[i].xmax 90 | ymax1 = sort_detectboxs[i].ymax 91 | classId = sort_detectboxs[i].classId 92 | 93 | if sort_detectboxs[i].classId != -1: 94 | predBoxs.append(sort_detectboxs[i]) 95 | for j in range(i + 1, len(sort_detectboxs), 1): 96 | if classId == sort_detectboxs[j].classId: 97 | xmin2 = sort_detectboxs[j].xmin 98 | ymin2 = sort_detectboxs[j].ymin 99 | xmax2 = sort_detectboxs[j].xmax 100 | ymax2 = sort_detectboxs[j].ymax 101 | iou = IOU(xmin1, ymin1, xmax1, ymax1, xmin2, ymin2, xmax2, ymax2) 102 | if iou > nmsThresh: 103 | sort_detectboxs[j].classId = -1 104 | return predBoxs 105 | 106 | 107 | def sigmoid(x): 108 | return 1 / (1 + exp(-x)) 109 | 110 | 111 | def postprocess(out, img_h, img_w): 112 | print('postprocess ... 
') 113 | 114 | detectResult = [] 115 | output = [] 116 | for i in range(len(out)): 117 | print(out[i].shape) 118 | output.append(out[i].reshape((-1))) 119 | 120 | scale_h = img_h / input_imgH 121 | scale_w = img_w / input_imgW 122 | 123 | gridIndex = -2 124 | cls_index = 0 125 | cls_max = 0 126 | 127 | for index in range(headNum): 128 | reg = output[index * 2 + 0] 129 | cls = output[index * 2 + 1] 130 | 131 | for h in range(mapSize[index][0]): 132 | for w in range(mapSize[index][1]): 133 | gridIndex += 2 134 | 135 | if 1 == class_num: 136 | cls_max = sigmoid(cls[0 * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w]) 137 | cls_index = 0 138 | else: 139 | for cl in range(class_num): 140 | cls_val = cls[cl * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w] 141 | if 0 == cl: 142 | cls_max = cls_val 143 | cls_index = cl 144 | else: 145 | if cls_val > cls_max: 146 | cls_max = cls_val 147 | cls_index = cl 148 | cls_max = sigmoid(cls_max) 149 | 150 | if cls_max > objectThresh: 151 | regdfl = [] 152 | for lc in range(4): 153 | sfsum = 0 154 | locval = 0 155 | for df in range(16): 156 | temp = exp(reg[((lc * 16) + df) * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w]) 157 | reg[((lc * 16) + df) * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w] = temp 158 | sfsum += temp 159 | 160 | for df in range(16): 161 | sfval = reg[((lc * 16) + df) * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w] / sfsum 162 | locval += sfval * df 163 | regdfl.append(locval) 164 | 165 | x1 = (meshgrid[gridIndex + 0] - regdfl[0]) * strides[index] 166 | y1 = (meshgrid[gridIndex + 1] - regdfl[1]) * strides[index] 167 | x2 = (meshgrid[gridIndex + 0] + regdfl[2]) * strides[index] 168 | y2 = (meshgrid[gridIndex + 1] + regdfl[3]) * strides[index] 169 | 170 | xmin = x1 * scale_w 171 | ymin = y1 * scale_h 172 | xmax = x2 * scale_w 173 | ymax = y2 * scale_h 174 | 175 | xmin = xmin if xmin > 0 else 0 176 | ymin = ymin if ymin > 0 else 0 177 | xmax = xmax if xmax < img_w else img_w 178 | ymax = ymax if ymax < img_h else img_h 179 | 180 | box = DetectBox(cls_index, cls_max, xmin, ymin, xmax, ymax) 181 | detectResult.append(box) 182 | # NMS 183 | print('detectResult:', len(detectResult)) 184 | predBox = NMS(detectResult) 185 | 186 | return predBox 187 | 188 | 189 | def precess_image(img_src, resize_w, resize_h): 190 | image = cv2.resize(img_src, (resize_w, resize_h), interpolation=cv2.INTER_LINEAR) 191 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 192 | image = image.astype(np.float32) 193 | image /= 255.0 194 | 195 | return image 196 | 197 | 198 | def detect(img_path): 199 | 200 | orig = cv2.imread(img_path) 201 | img_h, img_w = orig.shape[:2] 202 | image = precess_image(orig, input_imgW, input_imgH) 203 | 204 | image = image.transpose((2, 0, 1)) 205 | image = np.expand_dims(image, axis=0) 206 | 207 | # image = np.ones((1, 3, 384, 640), dtype=np.float32) 208 | # print(image.shape) 209 | 210 | ort_session = ort.InferenceSession('./yolov11n_80class_ZQ.onnx') 211 | pred_results = (ort_session.run(None, {'data': image})) 212 | 213 | out = [] 214 | for i in range(len(pred_results)): 215 | out.append(pred_results[i]) 216 | predbox = postprocess(out, img_h, img_w) 217 | 218 | print('obj num is :', len(predbox)) 219 | 220 | for i in range(len(predbox)): 221 | xmin = int(predbox[i].xmin) 222 | ymin = int(predbox[i].ymin) 223 | xmax = int(predbox[i].xmax) 224 | ymax = int(predbox[i].ymax) 225 | classId = predbox[i].classId 226 | score = predbox[i].score 
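        # boxes from postprocess() are already in original-image coordinates,
        # so they can be drawn on `orig` directly without rescaling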
227 | 228 | cv2.rectangle(orig, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2) 229 | ptext = (xmin, ymin) 230 | title = CLASSES[classId] + "%.2f" % score 231 | cv2.putText(orig, title, ptext, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2, cv2.LINE_AA) 232 | 233 | cv2.imwrite('./test_onnx_result.jpg', orig) 234 | 235 | 236 | if __name__ == '__main__': 237 | print('This is main ....') 238 | GenerateMeshgrid() 239 | img_path = './test.jpg' 240 | detect(img_path) -------------------------------------------------------------------------------- /yolov11n_onnx/yolov11n_80class_ZQ.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_onnx/yolov11n_80class_ZQ.onnx -------------------------------------------------------------------------------- /yolov11n_rknn/data/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_rknn/data/test.jpg -------------------------------------------------------------------------------- /yolov11n_rknn/dataset.txt: -------------------------------------------------------------------------------- 1 | ./data/test.jpg -------------------------------------------------------------------------------- /yolov11n_rknn/onnx2rknn_zq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib 3 | import traceback 4 | import time 5 | import sys 6 | import numpy as np 7 | import cv2 8 | from rknn.api import RKNN 9 | from math import exp 10 | 11 | ONNX_MODEL = './yolov11n_80class_ZQ.onnx' 12 | RKNN_MODEL = './yolov11n_80class_ZQ.rknn' 13 | DATASET = './dataset.txt' 14 | 15 | QUANTIZE_ON = True 16 | 17 | CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 18 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 19 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 20 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 21 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 22 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 23 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 24 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 25 | 'hair drier', 'toothbrush'] 26 | 27 | meshgrid = [] 28 | 29 | class_num = len(CLASSES) 30 | headNum = 3 31 | strides = [8, 16, 32] 32 | mapSize = [[80, 80], [40, 40], [20, 20]] 33 | nmsThresh = 0.5 34 | objectThresh = 0.5 35 | 36 | input_imgH = 640 37 | input_imgW = 640 38 | 39 | 40 | class DetectBox: 41 | def __init__(self, classId, score, xmin, ymin, xmax, ymax): 42 | self.classId = classId 43 | self.score = score 44 | self.xmin = xmin 45 | self.ymin = ymin 46 | self.xmax = xmax 47 | self.ymax = ymax 48 | 49 | def GenerateMeshgrid(): 50 | for index in range(headNum): 51 | for i in range(mapSize[index][0]): 52 | for j in range(mapSize[index][1]): 53 | meshgrid.append(j + 0.5) 54 | meshgrid.append(i + 0.5) 55 | 56 | 57 | def IOU(xmin1, ymin1, xmax1, ymax1, 
xmin2, ymin2, xmax2, ymax2): 58 | xmin = max(xmin1, xmin2) 59 | ymin = max(ymin1, ymin2) 60 | xmax = min(xmax1, xmax2) 61 | ymax = min(ymax1, ymax2) 62 | 63 | innerWidth = xmax - xmin 64 | innerHeight = ymax - ymin 65 | 66 | innerWidth = innerWidth if innerWidth > 0 else 0 67 | innerHeight = innerHeight if innerHeight > 0 else 0 68 | 69 | innerArea = innerWidth * innerHeight 70 | 71 | area1 = (xmax1 - xmin1) * (ymax1 - ymin1) 72 | area2 = (xmax2 - xmin2) * (ymax2 - ymin2) 73 | 74 | total = area1 + area2 - innerArea 75 | 76 | return innerArea / total 77 | 78 | 79 | def NMS(detectResult): 80 | predBoxs = [] 81 | 82 | sort_detectboxs = sorted(detectResult, key=lambda x: x.score, reverse=True) 83 | 84 | for i in range(len(sort_detectboxs)): 85 | xmin1 = sort_detectboxs[i].xmin 86 | ymin1 = sort_detectboxs[i].ymin 87 | xmax1 = sort_detectboxs[i].xmax 88 | ymax1 = sort_detectboxs[i].ymax 89 | classId = sort_detectboxs[i].classId 90 | 91 | if sort_detectboxs[i].classId != -1: 92 | predBoxs.append(sort_detectboxs[i]) 93 | for j in range(i + 1, len(sort_detectboxs), 1): 94 | if classId == sort_detectboxs[j].classId: 95 | xmin2 = sort_detectboxs[j].xmin 96 | ymin2 = sort_detectboxs[j].ymin 97 | xmax2 = sort_detectboxs[j].xmax 98 | ymax2 = sort_detectboxs[j].ymax 99 | iou = IOU(xmin1, ymin1, xmax1, ymax1, xmin2, ymin2, xmax2, ymax2) 100 | if iou > nmsThresh: 101 | sort_detectboxs[j].classId = -1 102 | return predBoxs 103 | 104 | 105 | def sigmoid(x): 106 | return 1 / (1 + exp(-x)) 107 | 108 | 109 | def postprocess(out, img_h, img_w): 110 | print('postprocess ... ') 111 | 112 | detectResult = [] 113 | output = [] 114 | for i in range(len(out)): 115 | print(out[i].shape) 116 | output.append(out[i].reshape((-1))) 117 | 118 | scale_h = img_h / input_imgH 119 | scale_w = img_w / input_imgW 120 | 121 | gridIndex = -2 122 | cls_index = 0 123 | cls_max = 0 124 | 125 | for index in range(headNum): 126 | cls = output[index * 2 + 0] 127 | reg = output[index * 2 + 1] 128 | 129 | for h in range(mapSize[index][0]): 130 | for w in range(mapSize[index][1]): 131 | gridIndex += 2 132 | 133 | if 1 == class_num: 134 | cls_max = sigmoid(cls[0 * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w]) 135 | cls_index = 0 136 | else: 137 | for cl in range(class_num): 138 | cls_val = cls[cl * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w] 139 | if 0 == cl: 140 | cls_max = cls_val 141 | cls_index = cl 142 | else: 143 | if cls_val > cls_max: 144 | cls_max = cls_val 145 | cls_index = cl 146 | cls_max = sigmoid(cls_max) 147 | 148 | if cls_max > objectThresh: 149 | regdfl = [] 150 | for lc in range(4): 151 | sfsum = 0 152 | locval = 0 153 | for df in range(16): 154 | temp = exp(reg[((lc * 16) + df) * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w]) 155 | reg[((lc * 16) + df) * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w] = temp 156 | sfsum += temp 157 | 158 | for df in range(16): 159 | sfval = reg[((lc * 16) + df) * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w] / sfsum 160 | locval += sfval * df 161 | regdfl.append(locval) 162 | 163 | x1 = (meshgrid[gridIndex + 0] - regdfl[0]) * strides[index] 164 | y1 = (meshgrid[gridIndex + 1] - regdfl[1]) * strides[index] 165 | x2 = (meshgrid[gridIndex + 0] + regdfl[2]) * strides[index] 166 | y2 = (meshgrid[gridIndex + 1] + regdfl[3]) * strides[index] 167 | 168 | xmin = x1 * scale_w 169 | ymin = y1 * scale_h 170 | xmax = x2 * scale_w 171 | ymax = y2 * scale_h 172 | 173 | xmin = xmin if xmin > 0 else 0 
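                    # the four clamps here keep the decoded box inside the original image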
174 | ymin = ymin if ymin > 0 else 0 175 | xmax = xmax if xmax < img_w else img_w 176 | ymax = ymax if ymax < img_h else img_h 177 | 178 | box = DetectBox(cls_index, cls_max, xmin, ymin, xmax, ymax) 179 | detectResult.append(box) 180 | # NMS 181 | print('detectResult:', len(detectResult)) 182 | predBox = NMS(detectResult) 183 | 184 | return predBox 185 | 186 | 187 | def export_rknn_inference(img): 188 | # Create RKNN object 189 | rknn = RKNN(verbose=False) 190 | 191 | # pre-process config 192 | print('--> Config model') 193 | rknn.config(mean_values=[[0, 0, 0]], std_values=[[255, 255, 255]], quantized_algorithm='normal', quantized_method='channel', target_platform='rk3588') 194 | print('done') 195 | 196 | # Load ONNX model 197 | print('--> Loading model') 198 | ret = rknn.load_onnx(model=ONNX_MODEL, outputs=['cls1', 'reg1', 'cls2', 'reg2', 'cls3', 'reg3']) 199 | if ret != 0: 200 | print('Load model failed!') 201 | exit(ret) 202 | print('done') 203 | 204 | # Build model 205 | print('--> Building model') 206 | ret = rknn.build(do_quantization=QUANTIZE_ON, dataset=DATASET, rknn_batch_size=1) 207 | if ret != 0: 208 | print('Build model failed!') 209 | exit(ret) 210 | print('done') 211 | 212 | # Export RKNN model 213 | print('--> Export rknn model') 214 | ret = rknn.export_rknn(RKNN_MODEL) 215 | if ret != 0: 216 | print('Export rknn model failed!') 217 | exit(ret) 218 | print('done') 219 | 220 | # Init runtime environment 221 | print('--> Init runtime environment') 222 | ret = rknn.init_runtime() 223 | # ret = rknn.init_runtime(target='rk3566') 224 | if ret != 0: 225 | print('Init runtime environment failed!') 226 | exit(ret) 227 | print('done') 228 | 229 | # Inference 230 | print('--> Running model') 231 | outputs = rknn.inference(inputs=[img]) 232 | rknn.release() 233 | print('done') 234 | 235 | return outputs 236 | 237 | 238 | if __name__ == '__main__': 239 | print('This is main ...') 240 | GenerateMeshgrid() 241 | 242 | img_path = './test.jpg' 243 | orig_img = cv2.imread(img_path) 244 | img_h, img_w = orig_img.shape[:2] 245 | 246 | 247 | origimg = cv2.resize(orig_img, (input_imgW, input_imgH), interpolation=cv2.INTER_LINEAR) 248 | origimg = cv2.cvtColor(origimg, cv2.COLOR_BGR2RGB) 249 | 250 | img = np.expand_dims(origimg, 0) 251 | 252 | outputs = export_rknn_inference(img) 253 | 254 | out = [] 255 | for i in range(len(outputs)): 256 | out.append(outputs[i]) 257 | 258 | predbox = postprocess(out, img_h, img_w) 259 | 260 | print(len(predbox)) 261 | 262 | for i in range(len(predbox)): 263 | xmin = int(predbox[i].xmin) 264 | ymin = int(predbox[i].ymin) 265 | xmax = int(predbox[i].xmax) 266 | ymax = int(predbox[i].ymax) 267 | classId = predbox[i].classId 268 | score = predbox[i].score 269 | 270 | cv2.rectangle(orig_img, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2) 271 | ptext = (xmin, ymin) 272 | title = CLASSES[classId] + ":%.2f" % (score) 273 | cv2.putText(orig_img, title, ptext, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2, cv2.LINE_AA) 274 | 275 | cv2.imwrite('./test_rknn_result.jpg', orig_img) 276 | # cv2.imshow("test", origimg) 277 | # cv2.waitKey(0) 278 | -------------------------------------------------------------------------------- /yolov11n_rknn/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_rknn/test.jpg -------------------------------------------------------------------------------- /yolov11n_rknn/test_rknn_result.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_rknn/test_rknn_result.jpg -------------------------------------------------------------------------------- /yolov11n_rknn/yolov11n_80class_ZQ.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_rknn/yolov11n_80class_ZQ.onnx -------------------------------------------------------------------------------- /yolov11n_rknn/yolov11n_80class_ZQ.rknn: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_rknn/yolov11n_80class_ZQ.rknn -------------------------------------------------------------------------------- /yolov11n_tensorrt/images/test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_tensorrt/images/test1.jpg -------------------------------------------------------------------------------- /yolov11n_tensorrt/images/test2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_tensorrt/images/test2.jpg -------------------------------------------------------------------------------- /yolov11n_tensorrt/images/test3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_tensorrt/images/test3.jpg -------------------------------------------------------------------------------- /yolov11n_tensorrt/images/test4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_tensorrt/images/test4.jpg -------------------------------------------------------------------------------- /yolov11n_tensorrt/images/test5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_tensorrt/images/test5.jpg -------------------------------------------------------------------------------- /yolov11n_tensorrt/images/test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_tensorrt/images/test6.jpg -------------------------------------------------------------------------------- /yolov11n_tensorrt/images/test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_tensorrt/images/test7.jpg -------------------------------------------------------------------------------- /yolov11n_tensorrt/images/test8.jpg: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_tensorrt/images/test8.jpg
--------------------------------------------------------------------------------
/yolov11n_tensorrt/onnx2trt.py:
--------------------------------------------------------------------------------
1 | #
2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
3 | # AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 | 
18 | import os
19 | import sys
20 | import logging
21 | import argparse
22 | from typing import Dict
23 | import cv2
24 | import numpy as np
25 | import tensorrt as trt
26 | import pycuda.driver as cuda
27 | # Use autoprimaryctx if available (pycuda >= 2021.1) to
28 | # prevent issues with other modules that rely on the primary
29 | # device context.
30 | try:
31 |     import pycuda.autoprimaryctx
32 | except ModuleNotFoundError:
33 |     import pycuda.autoinit
34 | 
35 | logging.basicConfig(level=logging.INFO)
36 | logging.getLogger("EngineBuilder").setLevel(logging.INFO)
37 | log = logging.getLogger("EngineBuilder")
38 | 
39 | 
40 | def preprocess(img_path):
41 |     img_src = cv2.imread(img_path)
42 |     image = cv2.resize(img_src, (640, 640), interpolation=cv2.INTER_LINEAR)
43 |     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
44 |     image = image.astype(np.float32)
45 |     image /= 255.0
46 |     image = image.transpose((2, 0, 1))
47 |     image = np.expand_dims(image, axis=0)
48 |     return image
49 | 
50 | 
51 | class ImageBatcher:
52 |     """
53 |     Creates batches of pre-processed images.
54 |     """
55 | 
56 |     def __init__(self, input, shape, dtype, max_num_images=None,
57 |                  exact_batches=False):
58 |         """
59 |         Args:
60 |             input: The input directory to read images from.
61 |             shape: The tensor shape of the batch to prepare,
62 |                 either in NCHW or NHWC format.
63 |             dtype: The (numpy) datatype to cast the batched data to.
64 |             max_num_images: The maximum number of images to read from
65 |                 the directory.
66 |             exact_batches: This defines how to handle a number of images that
67 |                 is not an exact multiple of the batch size. If false, it will
68 |                 pad the final batch with zeros to reach the batch size.
69 |                 If true, it will *remove* the last few images in excess of a
70 |                 batch size multiple, to guarantee batches are exact (useful
71 |                 for calibration). 
72 | """ 73 | # Find images in the given input path 74 | input = os.path.realpath(input) 75 | self.images = [] 76 | 77 | extensions = [".jpg", ".jpeg", ".png", ".bmp"] 78 | 79 | def is_image(path): 80 | return os.path.isfile(path) and os.path.splitext(path)[ 81 | 1].lower() in extensions 82 | 83 | if os.path.isdir(input): 84 | self.images = [os.path.join(input, f) for f in os.listdir(input) if is_image(os.path.join(input, f))] 85 | self.images.sort() 86 | elif os.path.isfile(input): 87 | if is_image(input): 88 | self.images.append(input) 89 | self.num_images = len(self.images) 90 | if self.num_images < 1: 91 | print("No valid {} images found in {}".format("/".join(extensions), input)) 92 | sys.exit(1) 93 | 94 | # Handle Tensor Shape 95 | self.dtype = dtype 96 | self.shape = shape 97 | assert len(self.shape) == 4 98 | self.batch_size = shape[0] 99 | assert self.batch_size > 0 100 | self.format = None 101 | self.width = -1 102 | self.height = -1 103 | if self.shape[1] == 3: 104 | self.format = "NCHW" 105 | self.height = self.shape[2] 106 | self.width = self.shape[3] 107 | elif self.shape[3] == 3: 108 | self.format = "NHWC" 109 | self.height = self.shape[1] 110 | self.width = self.shape[2] 111 | assert all([self.format, self.width > 0, self.height > 0]) 112 | 113 | # Adapt the number of images as needed 114 | if max_num_images and 0 < max_num_images < len(self.images): 115 | self.num_images = max_num_images 116 | if exact_batches: 117 | self.num_images = self.batch_size * (self.num_images // self.batch_size) 118 | if self.num_images < 1: 119 | print("Not enough images to create batches") 120 | sys.exit(1) 121 | self.images = self.images[0: self.num_images] 122 | print('') 123 | # Subdivide the list of images into batches 124 | self.num_batches = 1 + int((self.num_images - 1) / self.batch_size) 125 | self.batches = [] 126 | for i in range(self.num_batches): 127 | start = i * self.batch_size 128 | end = min(start + self.batch_size, self.num_images) 129 | self.batches.append(self.images[start:end]) 130 | # Indices 131 | self.image_index = 0 132 | self.batch_index = 0 133 | 134 | def get_batch(self): 135 | """ 136 | Retrieve the batches. This is a generator object, so you can use it 137 | within a loop as: for batch, images in batcher.get_batch(): ... Or 138 | outside of a batch with the next() function. 139 | 140 | Returns: 141 | A generator yielding two items per iteration: a numpy array holding 142 | a batch of images, and the list of paths to the images loaded 143 | within this batch. 144 | """ 145 | for _, batch_images in enumerate(self.batches): 146 | batch_data = np.zeros(self.shape, dtype=self.dtype) 147 | for i, image in enumerate(batch_images): 148 | self.image_index += 1 149 | batch_data[i] = preprocess(image) 150 | self.batch_index += 1 151 | yield batch_data, batch_images 152 | 153 | 154 | 155 | class EngineCalibrator(trt.IInt8EntropyCalibrator2): 156 | """ 157 | Implements the INT8 Entropy Calibrator 2. 158 | """ 159 | 160 | def __init__(self, cache_file): 161 | """ 162 | :param cache_file: The location of the cache file. 163 | """ 164 | super().__init__() 165 | self.cache_file = cache_file 166 | self.image_batcher = None 167 | self.batch_allocation = None 168 | self.batch_generator = None 169 | 170 | def set_image_batcher(self, image_batcher: ImageBatcher): 171 | """ 172 | Define the image batcher to use, if any. If using only the cache 173 | file, an image batcher doesn't need to be defined. 
:param 174 | image_batcher: The ImageBatcher object 175 | """ 176 | self.image_batcher = image_batcher 177 | size = int(np.dtype(self.image_batcher.dtype).itemsize * np.prod( 178 | self.image_batcher.shape)) 179 | self.batch_allocation = cuda.mem_alloc(size) 180 | self.batch_generator = self.image_batcher.get_batch() 181 | 182 | def get_batch_size(self): 183 | """ 184 | Overrides from trt.IInt8EntropyCalibrator2. 185 | Get the batch size to use for calibration. 186 | :return: Batch size. 187 | """ 188 | if self.image_batcher: 189 | return self.image_batcher.batch_size 190 | return 1 191 | 192 | def get_batch(self, names): 193 | """ 194 | Overrides from trt.IInt8EntropyCalibrator2. Get the next batch to 195 | use for calibration, as a list of device memory pointers. :param 196 | names: The names of the inputs, if useful to define the order of 197 | inputs. :return: A list of int-casted memory pointers. 198 | """ 199 | if not self.image_batcher: 200 | return None 201 | try: 202 | batch, _ = next(self.batch_generator) 203 | log.info("Calibrating image {} / {}".format( 204 | self.image_batcher.image_index, self.image_batcher.num_images)) 205 | cuda.memcpy_htod(self.batch_allocation, 206 | np.ascontiguousarray(batch)) 207 | return [int(self.batch_allocation)] 208 | except StopIteration: 209 | log.info("Finished calibration batches") 210 | return None 211 | 212 | def read_calibration_cache(self): 213 | """ 214 | Overrides from trt.IInt8EntropyCalibrator2. 215 | Read the calibration cache file stored on disk, if it exists. 216 | :return: The contents of the cache file, if any. 217 | """ 218 | if os.path.exists(self.cache_file): 219 | with open(self.cache_file, "rb") as f: 220 | log.info("Using calibration cache file: {}".format(self.cache_file)) 221 | return f.read() 222 | 223 | def write_calibration_cache(self, cache): 224 | """ 225 | Overrides from trt.IInt8EntropyCalibrator2. 226 | Store the calibration cache to a file on disk. 227 | :param cache: The contents of the calibration cache to store. 228 | """ 229 | with open(self.cache_file, "wb") as f: 230 | log.info("Writing calibration cache data to: {}".format( 231 | self.cache_file)) 232 | f.write(cache) 233 | 234 | 235 | class EngineBuilder: 236 | """ 237 | Parses an ONNX graph and builds a TensorRT engine from it. 238 | """ 239 | 240 | def __init__(self, verbose=False): 241 | """ 242 | :param verbose: If enabled, a higher verbosity level will be set on 243 | the TensorRT logger. 244 | """ 245 | self.trt_logger = trt.Logger(trt.Logger.INFO) 246 | if verbose: 247 | self.trt_logger.min_severity = trt.Logger.Severity.VERBOSE 248 | 249 | trt.init_libnvinfer_plugins(self.trt_logger, namespace="") 250 | 251 | self.builder = trt.Builder(self.trt_logger) 252 | self.config = self.builder.create_builder_config() 253 | self.config.max_workspace_size = 4 * (2 ** 30) # 4 GB 254 | 255 | self.batch_size = None 256 | self.network = None 257 | self.parser = None 258 | 259 | def create_network(self, onnx_path, input_shapes: Dict = None): 260 | """ 261 | Parse the ONNX graph and create the corresponding TensorRT network 262 | definition. :param onnx_path: The path to the ONNX graph to load. 
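        :param input_shapes: Dict mapping each input tensor name to its
            'min_shape', 'opt_shape' and 'max_shape' lists, which are used to
            build the TensorRT optimization profile below.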
263 | 264 | """ 265 | network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) 266 | 267 | self.network = self.builder.create_network(network_flags) 268 | self.parser = trt.OnnxParser(self.network, self.trt_logger) 269 | 270 | onnx_path = os.path.realpath(onnx_path) 271 | with open(onnx_path, "rb") as f: 272 | if not self.parser.parse(f.read()): 273 | log.error("Failed to load ONNX file: {}".format(onnx_path)) 274 | for error in range(self.parser.num_errors): 275 | log.error(self.parser.get_error(error)) 276 | sys.exit(1) 277 | 278 | inputs = [self.network.get_input(i) for i in 279 | range(self.network.num_inputs)] 280 | outputs = [self.network.get_output(i) for i in 281 | range(self.network.num_outputs)] 282 | 283 | log.info("Network Description") 284 | for input in inputs: 285 | self.batch_size = input.shape[0] 286 | log.info("Input '{}' with shape {} and dtype {}".format(input.name, 287 | input.shape, 288 | input.dtype)) 289 | for output in outputs: 290 | print(output.name, output.shape, output.dtype) 291 | log.info( 292 | "Output '{}' with shape {} and dtype {}".format(output.name, 293 | output.shape, 294 | output.dtype)) 295 | profile = self.builder.create_optimization_profile() 296 | for input_name, param in input_shapes.items(): 297 | min_shape = param['min_shape'] 298 | opt_shape = param['opt_shape'] 299 | max_shape = param['max_shape'] 300 | profile.set_shape(input_name, min_shape, opt_shape, max_shape) 301 | if self.config.add_optimization_profile(profile) < 0: 302 | log.warning(f'Invalid optimization profile {profile}.') 303 | 304 | # assert self.batch_size > 0 305 | # self.builder.max_batch_size = self.batch_size 306 | 307 | def create_engine( 308 | self, 309 | engine_path, 310 | precision, 311 | calib_input=None, 312 | calib_cache=None, 313 | calib_num_images=25000, 314 | calib_batch_size=8 315 | ): 316 | """ 317 | Build the TensorRT engine and serialize it to disk. 318 | 319 | Args: 320 | engine_path: The path where to serialize the engine to. 321 | precision: The datatype to use for the engine, either 'fp32', 322 | 'fp16' or 'int8'. 323 | calib_input: The path to a directory, holding the calibration 324 | images. 325 | calib_cache: The path where to write the calibration cache to, 326 | or if it already exists, load it from. 327 | calib_num_images: The maximum number of images to use for 328 | calibration. 329 | calib_batch_size: The batch size to use for the calibration 330 | process. 
331 |         """
332 |         engine_path = os.path.realpath(engine_path)
333 |         engine_dir = os.path.dirname(engine_path)
334 |         os.makedirs(engine_dir, exist_ok=True)
335 |         log.info("Building {} Engine in {}".format(precision, engine_path))
336 | 
337 |         inputs = [self.network.get_input(i) for i in
338 |                   range(self.network.num_inputs)]
339 | 
340 |         if precision == "fp16":
341 |             if not self.builder.platform_has_fast_fp16:
342 |                 log.warning(
343 |                     "FP16 is not supported natively on this platform/device")
344 |             else:
345 |                 self.config.set_flag(trt.BuilderFlag.FP16)
346 |         elif precision == "int8":
347 |             if not self.builder.platform_has_fast_int8:
348 |                 log.warning(
349 |                     "INT8 is not supported natively on this platform/device")
350 |             else:
351 |                 self.config.set_flag(trt.BuilderFlag.INT8)
352 |                 self.config.int8_calibrator = EngineCalibrator(calib_cache)
353 |                 if not os.path.exists(calib_cache):
354 |                     calib_shape = [calib_batch_size] + list(
355 |                         inputs[0].shape[1:])
356 |                     calib_dtype = trt.nptype(inputs[0].dtype)
357 |                     self.config.int8_calibrator.set_image_batcher(
358 |                         ImageBatcher(
359 |                             calib_input,
360 |                             calib_shape,
361 |                             calib_dtype,
362 |                             max_num_images=calib_num_images,
363 |                             exact_batches=True
364 |                         )
365 |                     )
366 | 
367 | 
368 |         engine = self.builder.build_engine(self.network, self.config)
369 |         if engine is None:
370 |             print("ERROR: Failed to build the TensorRT engine.")
371 |             exit(1)
372 |         with open(engine_path, "wb") as f:
373 |             f.write(engine.serialize())
374 | 
375 | 
376 | def main(args):
377 |     builder = EngineBuilder(args.verbose)
378 |     builder.create_network(args.onnx, input_shapes=dict(
379 |         input=dict(min_shape=[1, 3, 640, 640],
380 |                    opt_shape=[2, 3, 640, 640],
381 |                    max_shape=[4, 3, 640, 640])))
382 |     builder.create_engine(
383 |         args.engine,
384 |         args.precision,
385 |         args.calib_input,
386 |         args.calib_cache,
387 |         args.calib_num_images,
388 |         args.calib_batch_size
389 |     )
390 | 
391 | 
392 | if __name__ == "__main__":
393 |     parser = argparse.ArgumentParser()
394 |     parser.add_argument("--onnx", default="./yolov11n.onnx", help="The input ONNX model file to load")
395 |     parser.add_argument("--engine", default="./yolov11n.trt", help="The output path for the TRT engine")
396 |     parser.add_argument(
397 |         "-p",
398 |         "--precision",
399 |         default="fp32",
400 |         choices=["fp32", "fp16", "int8"],
401 |         help="The precision mode to build in, either 'fp32', 'fp16' or "
402 |              "'int8', default: 'fp32'",
403 |     )
404 |     parser.add_argument("--verbose", action="store_true", help="Enable more verbose log output")
405 |     parser.add_argument("--calib_input", default="./images", help="The directory holding images to use for calibration")
406 |     parser.add_argument(
407 |         "--calib_cache",
408 |         default="./calibration.cache",
409 |         help="The file path for INT8 calibration cache to use, default: ./calibration.cache",
410 |     )
411 |     parser.add_argument(
412 |         "--calib_num_images",
413 |         default=8,
414 |         type=int,
415 |         help="The maximum number of images to use for calibration, default: "
416 |              "8",
417 |     )
418 |     parser.add_argument(
419 |         "--calib_batch_size", default=2, type=int,
420 |         help="The batch size for the calibration process, default: 2"
421 |     )
422 |     args = parser.parse_args()
423 |     if not all([args.onnx, args.engine]):
424 |         parser.print_help()
425 |         log.error("These arguments are required: --onnx and --engine")
426 |         sys.exit(1)
427 |     if args.precision == "int8" and not any(
428 |             [args.calib_input, args.calib_cache]):
429 |         parser.print_help()
430 |         log.error(
431 |             "When building in int8 precision, either 
--calib_input or " 432 | "--calib_cache are required") 433 | sys.exit(1) 434 | main(args) -------------------------------------------------------------------------------- /yolov11n_tensorrt/tensorrt_infer_demo.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import tensorrt as trt 4 | import pycuda.driver as cuda 5 | import pycuda.autoinit 6 | from math import exp 7 | from math import sqrt 8 | import time 9 | 10 | TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) 11 | 12 | 13 | CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 14 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 15 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 16 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 17 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 18 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 19 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 20 | 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 21 | 'hair drier', 'toothbrush'] 22 | 23 | meshgrid = [] 24 | 25 | class_num = len(CLASSES) 26 | headNum = 3 27 | strides = [8, 16, 32] 28 | mapSize = [[80, 80], [40, 40], [20, 20]] 29 | nmsThresh = 0.45 30 | objectThresh = 0.35 31 | 32 | input_imgH = 640 33 | input_imgW = 640 34 | 35 | # Simple helper data class that's a little nicer to use than a 2-tuple. 36 | class HostDeviceMem(object): 37 | def __init__(self, host_mem, device_mem): 38 | self.host = host_mem 39 | self.device = device_mem 40 | 41 | def __str__(self): 42 | return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) 43 | 44 | def __repr__(self): 45 | return self.__str__() 46 | 47 | 48 | def allocate_buffers(engine): 49 | inputs = [] 50 | outputs = [] 51 | bindings = [] 52 | stream = cuda.Stream() 53 | for binding in engine: 54 | size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size 55 | dtype = trt.nptype(engine.get_binding_dtype(binding)) 56 | # Allocate host and device buffers 57 | host_mem = cuda.pagelocked_empty(size, dtype) 58 | device_mem = cuda.mem_alloc(host_mem.nbytes) 59 | # Append the device buffer to device bindings. 60 | bindings.append(int(device_mem)) 61 | # Append to the appropriate list. 62 | if engine.binding_is_input(binding): 63 | inputs.append(HostDeviceMem(host_mem, device_mem)) 64 | else: 65 | outputs.append(HostDeviceMem(host_mem, device_mem)) 66 | return inputs, outputs, bindings, stream 67 | 68 | 69 | def get_engine_from_bin(engine_file_path): 70 | print('Reading engine from file {}'.format(engine_file_path)) 71 | with open(engine_file_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime: 72 | return runtime.deserialize_cuda_engine(f.read()) 73 | 74 | 75 | # This function is generalized for multiple inputs/outputs. 76 | # inputs and outputs are expected to be lists of HostDeviceMem objects. 77 | def do_inference(context, bindings, inputs, outputs, stream, batch_size=1): 78 | # Transfer input data to the GPU. 79 | [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs] 80 | # Run inference. 
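    # the htod copies above, this launch, and the dtoh copies below are all
    # queued on one CUDA stream, so the stream.synchronize() call is what
    # actually waits for the results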
81 | context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle) 82 | # Transfer predictions back from the GPU. 83 | [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs] 84 | # Synchronize the stream 85 | stream.synchronize() 86 | # Return only the host outputs. 87 | return [out.host for out in outputs] 88 | 89 | 90 | class DetectBox: 91 | def __init__(self, classId, score, xmin, ymin, xmax, ymax): 92 | self.classId = classId 93 | self.score = score 94 | self.xmin = xmin 95 | self.ymin = ymin 96 | self.xmax = xmax 97 | self.ymax = ymax 98 | 99 | 100 | def GenerateMeshgrid(): 101 | for index in range(headNum): 102 | for i in range(mapSize[index][0]): 103 | for j in range(mapSize[index][1]): 104 | meshgrid.append(j + 0.5) 105 | meshgrid.append(i + 0.5) 106 | 107 | 108 | def IOU(xmin1, ymin1, xmax1, ymax1, xmin2, ymin2, xmax2, ymax2): 109 | xmin = max(xmin1, xmin2) 110 | ymin = max(ymin1, ymin2) 111 | xmax = min(xmax1, xmax2) 112 | ymax = min(ymax1, ymax2) 113 | 114 | innerWidth = xmax - xmin 115 | innerHeight = ymax - ymin 116 | 117 | innerWidth = innerWidth if innerWidth > 0 else 0 118 | innerHeight = innerHeight if innerHeight > 0 else 0 119 | 120 | innerArea = innerWidth * innerHeight 121 | 122 | area1 = (xmax1 - xmin1) * (ymax1 - ymin1) 123 | area2 = (xmax2 - xmin2) * (ymax2 - ymin2) 124 | 125 | total = area1 + area2 - innerArea 126 | 127 | return innerArea / total 128 | 129 | 130 | def NMS(detectResult): 131 | predBoxs = [] 132 | 133 | sort_detectboxs = sorted(detectResult, key=lambda x: x.score, reverse=True) 134 | 135 | for i in range(len(sort_detectboxs)): 136 | xmin1 = sort_detectboxs[i].xmin 137 | ymin1 = sort_detectboxs[i].ymin 138 | xmax1 = sort_detectboxs[i].xmax 139 | ymax1 = sort_detectboxs[i].ymax 140 | classId = sort_detectboxs[i].classId 141 | 142 | if sort_detectboxs[i].classId != -1: 143 | predBoxs.append(sort_detectboxs[i]) 144 | for j in range(i + 1, len(sort_detectboxs), 1): 145 | # if classId == sort_detectboxs[j].classId: 146 | xmin2 = sort_detectboxs[j].xmin 147 | ymin2 = sort_detectboxs[j].ymin 148 | xmax2 = sort_detectboxs[j].xmax 149 | ymax2 = sort_detectboxs[j].ymax 150 | iou = IOU(xmin1, ymin1, xmax1, ymax1, xmin2, ymin2, xmax2, ymax2) 151 | if iou > nmsThresh: 152 | sort_detectboxs[j].classId = -1 153 | return predBoxs 154 | 155 | def sigmoid(x): 156 | return 1 / (1 + exp(-x)) 157 | 158 | 159 | def postprocess(output, img_h, img_w): 160 | print('postprocess ... 
')
161 |     output = output[0]
162 | 
163 |     detectResult = []
164 | 
165 |     scale_h = img_h / input_imgH
166 |     scale_w = img_w / input_imgW
167 | 
168 |     coord_index = mapSize[0][0] * mapSize[0][1] + mapSize[1][0] * mapSize[1][1] + mapSize[2][0] * mapSize[2][1]
169 | 
170 |     for i in range(coord_index):
171 |         for j in range(4, len(CLASSES) + 4, 1):
172 |             if output[coord_index * j + i] > objectThresh:
173 |                 classId = j - 4
174 |                 score = output[coord_index * j + i]
175 | 
176 |                 cx = output[coord_index * 0 + i]
177 |                 cy = output[coord_index * 1 + i]
178 |                 cw = output[coord_index * 2 + i]
179 |                 ch = output[coord_index * 3 + i]
180 | 
181 |                 xmin = (cx - cw * 0.5) * scale_w
182 |                 ymin = (cy - ch * 0.5) * scale_h
183 |                 xmax = (cx + cw * 0.5) * scale_w
184 |                 ymax = (cy + ch * 0.5) * scale_h
185 |                 detectResult.append(DetectBox(classId, score, xmin, ymin, xmax, ymax))
186 | 
187 |     print('detectResult:', len(detectResult))
188 |     predBox = NMS(detectResult)
189 | 
190 |     return predBox
191 | 
192 | def preprocess(src):
193 |     img = cv2.resize(src, (input_imgW, input_imgH)).astype(np.float32)
194 |     img = img * 0.00392156
195 |     img = img.transpose(2, 0, 1)
196 |     img_input = img.copy()
197 |     return img_input
198 | 
199 | 
200 | def main():
201 |     engine_file_path = 'yolov11n.trt'
202 |     input_image_path = 'test.jpg'
203 | 
204 |     orig_image = cv2.imread(input_image_path)
205 |     orig = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB)
206 |     img_h, img_w = orig.shape[:2]
207 |     image = preprocess(orig)
208 | 
209 |     with get_engine_from_bin(engine_file_path) as engine, engine.create_execution_context() as context:
210 |         inputs, outputs, bindings, stream = allocate_buffers(engine)
211 | 
212 |         inputs[0].host = image
213 |         t1 = time.time()
214 |         trt_outputs = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream, batch_size=1)
215 |         t2 = time.time()
216 |         print('inference time:', (t2 - t1))
217 | 
218 |         print('output heads num:', len(trt_outputs))
219 | 
220 |         out = []
221 |         for i in range(len(trt_outputs)):
222 |             out.append(trt_outputs[i])
223 | 
224 |         predbox = postprocess(out, img_h, img_w)
225 | 
226 |         print('obj num is:', len(predbox))
227 | 
228 |         for i in range(len(predbox)):
229 |             xmin = int(predbox[i].xmin)
230 |             ymin = int(predbox[i].ymin)
231 |             xmax = int(predbox[i].xmax)
232 |             ymax = int(predbox[i].ymax)
233 |             classId = predbox[i].classId
234 |             score = predbox[i].score
235 | 
236 |             cv2.rectangle(orig_image, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
237 |             ptext = (xmin, ymin)
238 |             title = CLASSES[classId] + "%.2f" % score
239 |             cv2.putText(orig_image, title, ptext, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2, cv2.LINE_AA)
240 | 
241 |         cv2.imwrite('./test_result_tensorRT.jpg', orig_image)
242 |         # cv2.imshow("test", orig_image)
243 |         # cv2.waitKey(0)
244 | 
245 | 
246 | if __name__ == '__main__':
247 |     print('This is main ...')
248 |     GenerateMeshgrid()
249 |     main()
--------------------------------------------------------------------------------
/yolov11n_tensorrt/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_tensorrt/test.jpg
--------------------------------------------------------------------------------
/yolov11n_tensorrt/test_result_tensorRT.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_tensorrt/test_result_tensorRT.jpg -------------------------------------------------------------------------------- /yolov11n_tensorrt/yolov11n.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_tensorrt/yolov11n.onnx -------------------------------------------------------------------------------- /yolov11n_tensorrt/yolov11n_fp32.trt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_tensorrt/yolov11n_fp32.trt -------------------------------------------------------------------------------- /yolov11n_tensorrt/yolov11n_int8.trt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_onnx_rknn_tensorRT/520454f197571166b885387b749e1e1b12065ac7/yolov11n_tensorrt/yolov11n_int8.trt --------------------------------------------------------------------------------
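--------------------------------------------------------------------------------

Note: the demo scripts decode the DFL distributions with explicit per-pixel loops on purpose, so the logic maps one-to-one onto a C++ port. For a pure-Python reference, a vectorized NumPy equivalent of the same decode might look like the sketch below; `decode_dfl_head` is an illustrative name (not part of the repository code), and it assumes the `(4*16, H, W)` regression layout of the 6-output (`cls1/reg1 ... cls3/reg3`) ONNX model used by the ONNX and RKNN demos.

```python
import numpy as np

def decode_dfl_head(reg, stride, reg_max=16):
    """Vectorized DFL decode for one head; reg has shape (4*reg_max, H, W)."""
    _, h, w = reg.shape
    reg = reg.reshape(4, reg_max, h, w)
    # softmax over the reg_max bins, then the expectation sum(prob * bin index):
    # exactly what the nested loops in the demo postprocess() functions compute
    e = np.exp(reg - reg.max(axis=1, keepdims=True))
    prob = e / e.sum(axis=1, keepdims=True)
    dist = (prob * np.arange(reg_max).reshape(1, reg_max, 1, 1)).sum(axis=1)
    # grid centers (j + 0.5, i + 0.5), matching GenerateMeshgrid()
    ys, xs = np.meshgrid(np.arange(h) + 0.5, np.arange(w) + 0.5, indexing="ij")
    x1 = (xs - dist[0]) * stride
    y1 = (ys - dist[1]) * stride
    x2 = (xs + dist[2]) * stride
    y2 = (ys + dist[3]) * stride
    return np.stack([x1, y1, x2, y2])  # (4, H, W), in network-input pixels
```

In the demos this decode only runs for cells whose best class score passes `objectThresh`, which keeps the loop versions cheap; a vectorized decode like the above trades that early exit for NumPy throughput.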