├── README.md ├── __pycache__ ├── postprocess.cpython-37.pyc └── warpaffine.cpython-37.pyc ├── img ├── bus.jpg ├── yolov5_test.jpg └── yolov8_test.jpg ├── postprocess.py ├── tool └── add_transpose_node.py ├── trt_inference.py ├── warpaffine.py └── weights ├── yolov5s.engine ├── yolov5s.onnx ├── yolov8n_transpose.engine └── yolov8n_transpose.onnx /README.md: -------------------------------------------------------------------------------- 1 | ___ 2 | # 基于pycuda的yolo前后处理 3 | 4 | 1、使用pycuda对yolo前后处理进行gpu加速。前处理包含操作:缩放、补边、bgr->rgb,转换维度(1,640,640,3)->(1,3,640,640),除255。后处理支持yolov5、yolov8. 5 | 2、yolov8模型需要./tool/add_transpose_node.py 将官方onnx模型输出进行转换和重命名(output:1,8400,84),方便处理。 6 | 7 | 8 | 9 | ### 使用例子 10 | ``` 11 | yolov8_inference = TRT_inference("./weights/yolov8n_transpose.onnx",model="yolov8") 12 | img = cv2.imread("./img/bus.jpg") 13 | img1 = copy.deepcopy(img) 14 | 15 | boxs = yolov8_inference(img) 16 | for box in boxs: 17 | cv2.rectangle(img,(int(box[0]),int(box[1])),(int(box[2]),int(box[3])),(255,0,0),2) 18 | cv2.imwrite("./img/yolov8_test.jpg",img) 19 | 20 | yolov5_inference = TRT_inference("./weights/yolov5s.onnx",model="yolov5") 21 | boxs = yolov5_inference(img1) 22 | for box in boxs: 23 | cv2.rectangle(img1,(int(box[0]),int(box[1])),(int(box[2]),int(box[3])),(255,0,0),2) 24 | cv2.imwrite("./img/yolov5_test.jpg",img1) 25 | ``` -------------------------------------------------------------------------------- /__pycache__/postprocess.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangzongmou/Pycuda-warpaffine-trt-post/c9e81ce04147cbd32cae856b04b0f0be12e04bb1/__pycache__/postprocess.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/warpaffine.cpython-37.pyc: -------------------------------------------------------------------------------- 
# Repository-dump metadata that shares these source lines (kept verbatim):
# https://raw.githubusercontent.com/huangzongmou/Pycuda-warpaffine-trt-post/c9e81ce04147cbd32cae856b04b0f0be12e04bb1/__pycache__/warpaffine.cpython-37.pyc
# --------------------------------------------------------------------------------
# /img/bus.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/huangzongmou/Pycuda-warpaffine-trt-post/c9e81ce04147cbd32cae856b04b0f0be12e04bb1/img/bus.jpg
# --------------------------------------------------------------------------------
# /img/yolov5_test.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/huangzongmou/Pycuda-warpaffine-trt-post/c9e81ce04147cbd32cae856b04b0f0be12e04bb1/img/yolov5_test.jpg
# --------------------------------------------------------------------------------
# /img/yolov8_test.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/huangzongmou/Pycuda-warpaffine-trt-post/c9e81ce04147cbd32cae856b04b0f0be12e04bb1/img/yolov8_test.jpg
# --------------------------------------------------------------------------------
# /postprocess.py:
# --------------------------------------------------------------------------------
import pycuda.driver as cuda
import pycuda.autoinit
import pycuda.gpuarray as gpuarray
from pycuda.compiler import SourceModule

import numpy as np
import torch
import cv2
import torchvision
import time
import operator
from warpaffine import Warpaffine


class gpu_decode(object):
    """Decode YOLO predictions and run NMS on the GPU through PyCUDA kernels.

    One CUDA thread per candidate box filters by confidence, converts
    (cx, cy, w, h) to corner coordinates, projects the corners through the
    supplied inverse affine matrix and appends the box to an output buffer.
    A second kernel then clears the keep flag of every box that overlaps a
    higher-scoring box of the same class by more than the NMS threshold.

    Each output row holds seven float32 values:
    left, top, right, bottom, confidence, class, keepflag.
    """

    def __init__(self, rows, cols, confidence_threshold = 0.6,nms_threshold = 0.45,model="yolov5",stream=None):
        """Compile the kernels and allocate the reusable host/device buffers.

        Args:
            rows: Number of candidate boxes in one prediction (one CUDA thread
                is launched per candidate).
            cols: Number of values per candidate. The class count is
                ``cols - 5`` for "yolov5" (box + objectness + class scores)
                and ``cols - 4`` for "yolov8" (box + class scores).
            confidence_threshold: Minimum score for a candidate to be kept.
            nms_threshold: IoU above which the lower-scoring box of the same
                class is suppressed.
            model: Either "yolov5" or "yolov8"; selects the decode kernel.
            stream: Optional ``pycuda.driver.Stream`` to launch on; a new
                stream is created when omitted.

        Raises:
            ValueError: If ``model`` is not one of the supported names.
        """
        super(gpu_decode, self).__init__()
        # Fail early: an unknown model previously surfaced much later as an
        # UnboundLocalError inside cuda_func().
        if model not in ("yolov5", "yolov8"):
            raise ValueError(
                "unsupported model %r (expected 'yolov5' or 'yolov8')" % (model,))

        self.rows = rows
        self.cols = cols
        self.model = model

        # Launch shape of the decode kernel: one thread per candidate box.
        decode_block = 512 if rows > 512 else rows
        decode_grid = (rows + decode_block - 1) // decode_block
        self.block = (decode_block,1,1)
        self.grid = (decode_grid,1,1)

        max_objects = 1000
        num_box_element = 7

        # Scalars are handed to the kernels as one-element arrays (the kernels
        # take pointers), wrapped in cuda.In so PyCUDA uploads them per launch.
        self.num_bboxes = cuda.In(np.array([rows]).astype(np.int32))
        if self.model == "yolov5":
            self.num_classes = cuda.In(np.array([cols-5]).astype(np.int32))
        else:
            self.num_classes = cuda.In(np.array([cols-4]).astype(np.int32))

        self.confidence_threshold = cuda.In(np.array([confidence_threshold]).astype(np.float32))
        self.nms_threshold = cuda.In(np.array([nms_threshold]).astype(np.float32))

        # Launch shape of the NMS kernel: one thread per stored box.
        # Floor division is required -- grid dimensions must be integers and
        # "/" yields a float on Python 3.
        nms_block = 512 if max_objects > 512 else max_objects
        nms_grid = (max_objects + nms_block - 1) // nms_block
        self.nms_block = (nms_block,1,1)
        self.nms_grid = (nms_grid,1,1)

        if stream is None:
            self.stream = cuda.Stream()
        else:
            self.stream = stream

        # Page-locked host buffer and matching device buffer for the boxes.
        self.output_host = cuda.pagelocked_empty_like(
            np.empty((max_objects, num_box_element), dtype=np.float32))
        self.output_device_nbytes = self.output_host.nbytes
        self.output_device = cuda.mem_alloc(self.output_device_nbytes)
        # mem_alloc does not clear memory. The result is filtered on the
        # keep-flag column of the *whole* buffer, so stale bytes could show up
        # as bogus boxes on the very first call unless we zero it here.
        cuda.memset_d8(self.output_device, 0, self.output_device_nbytes)

        # NOTE: from here on these two attributes are cuda.In wrappers,
        # not plain ints (kept for compatibility with the original).
        self.max_objects = cuda.In(np.array([max_objects]).astype(np.int32))
        self.NUM_BOX_ELEMENT = cuda.In(np.array([num_box_element]).astype(np.int32))

        # Number of boxes that passed the first (confidence) filter.
        self.filter_boxs = np.array([0]).astype(np.uint32)

        self.decode_kernel,self.fast_nms_kernel = self.cuda_func()

    def cuda_func(self):
        """Compile the CUDA source and return ``(decode_kernel, fast_nms_kernel)``.

        The decode kernel is ``decode_kernelv5`` or ``decode_kernelv8``
        depending on ``self.model``.

        Raises:
            ValueError: If ``self.model`` is not "yolov5" or "yolov8".
        """
        mod = SourceModule("""
        __device__ void affine_project(float* matrix, float x, float y, float* ox, float* oy){
            *ox = matrix[0] * x + matrix[1] * y + matrix[2];
            *oy = matrix[3] * x + matrix[4] * y + matrix[5];
        }

        __global__ void decode_kernelv5(
            float* predict, int* num_bboxes, int* num_classes, float* confidence_threshold,
            float* invert_affine_matrix, float* parray, int* max_objects, int* filter_boxs, int* NUM_BOX_ELEMENT
        )
        {
            int position = blockDim.x * blockIdx.x + threadIdx.x;
            if (position >= *num_bboxes) return;

            float* pitem = predict + (5 + *num_classes) * position;
            float objectness = pitem[4];
            if(objectness < *confidence_threshold)
                return;

            float* class_confidence = pitem + 5;
            float confidence = *class_confidence++;
            int label = 0;
            for(int i = 1; i < *num_classes; ++i, ++class_confidence){
                if(*class_confidence > confidence){
                    confidence = *class_confidence;
                    label = i;
                }
            }

            confidence *= objectness;
            if(confidence < *confidence_threshold)
                return;

            int index = atomicAdd(filter_boxs, 1);
            if(index >= *max_objects)
                return;

            float cx = *pitem++;
            float cy = *pitem++;
            float width = *pitem++;
            float height = *pitem++;
            float left = cx - width * 0.5f;
            float top = cy - height * 0.5f;
            float right = cx + width * 0.5f;
            float bottom = cy + height * 0.5f;

            affine_project(invert_affine_matrix, left, top, &left, &top);
            affine_project(invert_affine_matrix, right, bottom, &right, &bottom);
            // left, top, right, bottom, confidence, class, keepflag
            float* pout_item = parray + index * (*NUM_BOX_ELEMENT);
            *pout_item++ = left;
            *pout_item++ = top;
            *pout_item++ = right;
            *pout_item++ = bottom;
            *pout_item++ = confidence;
            *pout_item++ = label;
            *pout_item++ = 1; // 1 = keep, 0 = ignore
        }

        __global__ void decode_kernelv8(
            float* predict, int* num_bboxes, int* num_classes, float* confidence_threshold,
            float* invert_affine_matrix, float* parray, int* max_objects, int* filter_boxs, int* NUM_BOX_ELEMENT
        )
        {
            int position = blockDim.x * blockIdx.x + threadIdx.x;
            if (position >= *num_bboxes) return;

            float* pitem = predict + (4 + *num_classes) * position;


            float* class_confidence = pitem + 4;
            float confidence = *class_confidence++;
            int label = 0;
            for(int i = 1; i < *num_classes; ++i, ++class_confidence){
                if(*class_confidence > confidence){
                    confidence = *class_confidence;
                    label = i;
                }
            }

            if(confidence < *confidence_threshold)
                return;

            int index = atomicAdd(filter_boxs, 1);
            if(index >= *max_objects)
                return;

            float cx = *pitem++;
            float cy = *pitem++;
            float width = *pitem++;
            float height = *pitem++;
            float left = cx - width * 0.5f;
            float top = cy - height * 0.5f;
            float right = cx + width * 0.5f;
            float bottom = cy + height * 0.5f;

            affine_project(invert_affine_matrix, left, top, &left, &top);
            affine_project(invert_affine_matrix, right, bottom, &right, &bottom);
            // left, top, right, bottom, confidence, class, keepflag
            float* pout_item = parray + index * (*NUM_BOX_ELEMENT);
            *pout_item++ = left;
            *pout_item++ = top;
            *pout_item++ = right;
            *pout_item++ = bottom;
            *pout_item++ = confidence;
            *pout_item++ = label;
            *pout_item++ = 1; // 1 = keep, 0 = ignore
        }

        __device__ float box_iou(
            float aleft, float atop, float aright, float abottom,
            float bleft, float btop, float bright, float bbottom
        ){

            float cleft = max(aleft, bleft);
            float ctop = max(atop, btop);
            float cright = min(aright, bright);
            float cbottom = min(abottom, bbottom);

            float c_area = max(cright - cleft, 0.0f) * max(cbottom - ctop, 0.0f);
            if(c_area == 0.0f)
                return 0.0f;

            float a_area = max(0.0f, aright - aleft) * max(0.0f, abottom - atop);
            float b_area = max(0.0f, bright - bleft) * max(0.0f, bbottom - btop);
            return c_area / (a_area + b_area - c_area);
        }

        __global__ void fast_nms_kernel(float* bboxes, int*filter_boxs, int* max_objects, float* threshold, int* NUM_BOX_ELEMENT){

            int position = (blockDim.x * blockIdx.x + threadIdx.x);
            int count = min(*filter_boxs, *max_objects);


            if (position >= count)
                return;

            // left, top, right, bottom, confidence, class, keepflag
            float* pcurrent = bboxes + position * (*NUM_BOX_ELEMENT);
            for(int i = 0; i < count; ++i){
                float* pitem = bboxes + i * (*NUM_BOX_ELEMENT);
                if(i == position || pcurrent[5] != pitem[5]) continue;

                if(pitem[4] >= pcurrent[4]){
                    if(pitem[4] == pcurrent[4] && i < position)
                        continue;

                    float iou = box_iou(
                        pcurrent[0], pcurrent[1], pcurrent[2], pcurrent[3],
                        pitem[0], pitem[1], pitem[2], pitem[3]
                    );

                    if(iou > *threshold){
                        pcurrent[6] = 0; // 1=keep, 0=ignore
                        return;
                    }
                }
            }
        }

        """)
        if self.model == "yolov5":
            decode_kernel = mod.get_function("decode_kernelv5")
        elif self.model == "yolov8":
            decode_kernel = mod.get_function("decode_kernelv8")
        else:
            raise ValueError(
                "unsupported model %r (expected 'yolov5' or 'yolov8')" % (self.model,))

        return decode_kernel,mod.get_function("fast_nms_kernel")

    def decode_kernel_invoker(self,predict, affine):
        """Run decode + NMS on one prediction and return the surviving boxes.

        Args:
            predict: Passed straight to the kernel as ``float* predict``;
                presumably a device allocation holding ``rows * cols``
                float32 values -- confirm against the caller.
            affine: Host NumPy array uploaded with ``cuda.In``. The kernel
                reads elements 0..5 as ``float``, so it must hold at least six
                float32 values (a 2x3 matrix mapping network coordinates back
                to the source image).

        Returns:
            A new float32 array of shape ``(k, 7)`` with the rows whose keep
            flag is set: left, top, right, bottom, confidence, class, keepflag.
        """
        # Reset the box counter that the decode kernel bumps with atomicAdd.
        self.filter_boxs[0]=0
        self.decode_kernel(
            predict,
            self.num_bboxes,
            self.num_classes,
            self.confidence_threshold,
            cuda.In(affine),
            self.output_device,
            self.max_objects,
            cuda.InOut(self.filter_boxs),
            self.NUM_BOX_ELEMENT,
            stream=self.stream,block=self.block,grid=self.grid)

        # NMS only has to cover the stored boxes (at most max_objects), so it
        # uses its own launch shape rather than the decode kernel's.
        self.fast_nms_kernel(
            self.output_device,
            cuda.In(self.filter_boxs),
            self.max_objects,
            self.nms_threshold,
            self.NUM_BOX_ELEMENT,
            stream=self.stream,
            block=self.nms_block,
            grid=self.nms_grid)
        cuda.memcpy_dtoh_async(self.output_host, self.output_device, self.stream)
        self.stream.synchronize()

        # Clear the device buffer so stale keep flags cannot leak into the next call.
        cuda.memset_d8(self.output_device, 0, self.output_device_nbytes)

        # Boolean indexing copies, so the result stays valid after the next call.
        return self.output_host[self.output_host[:,6]>0]

    def __call__(self, predict, affine):
        """Shorthand for :meth:`decode_kernel_invoker`."""
        return self.decode_kernel_invoker(predict, affine)


if __name__ == "__main__":
    # NOTE(review): this demo looks stale relative to the class above and is
    # left functionally unchanged:
    #   * postprocess(predict) omits the `affine` argument that
    #     gpu_decode.__call__ requires, so it raises TypeError as written;
    #   * `predict` is a host NumPy array here while the kernel takes a raw
    #     `float*` -- presumably a device buffer is expected; confirm.
    device = "cpu"
    model = torch.jit.load("yolov5s.torchscript")
    warpaffine = Warpaffine(dst_size=(640,640))
    postprocess = gpu_decode(rows=25200, cols=85)
    img = cv2.imread("bus.jpg")
    pdst_img = warpaffine(img)
    pdst_img = torch.from_numpy(pdst_img).to(device)
    predict = model(pdst_img)[0].numpy()

    img1 = warpaffine(img)*255
    img1 = img1[0].transpose(1, 2, 0)
    print(img1.shape)
    img1 = cv2.cvtColor(img1,cv2.COLOR_RGB2BGR)

    t1 = time.time()
    for _ in range(1000):
        boxs = postprocess(predict)
    t2 = time.time()
    print(t2-t1)
    for box in boxs:
        cv2.rectangle(img1,(box[0],box[1]),(box[2],box[3]),(255,0,0),2)
    cv2.imwrite("test.jpg",img1)


# Repository-dump text for the next file (kept verbatim; it continues on the following source line):
# -------------------------------------------------------------------------------- /tool/add_transpose_node.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | from onnx import numpy_helper, helper, shape_inference 3 | import onnxruntime as ort 4 | import numpy as np 5 | 6 | model_path = './model/yolov8n.onnx' 7 | model = onnx.load(model_path) 8 | 9 | last_node = model.graph.node[-1] 10 | output_name = "output" 11 | perm = [0, 2, 1] 12 | 13 | transpose_node = helper.make_node('Transpose', inputs=["Transpose_output"], outputs=[output_name], perm=perm) 14 | last_node.output[0] = "Transpose_output" 15 | model.graph.node.extend([transpose_node]) 16 | 17 | last_node1 = model.graph.node[-1] 18 |
output_tensor1 = last_node.output[0] 19 | 20 | inferred_model = shape_inference.infer_shapes(model) 21 | output_shape = inferred_model.graph.output[0].type.tensor_type.shape.dim 22 | model.graph.output[0].name = output_name 23 | model.graph.output[0].type.tensor_type.shape.dim[1].dim_value = output_shape[2].dim_value 24 | model.graph.output[0].type.tensor_type.shape.dim[2].dim_value = output_shape[1].dim_value 25 | # # Save the modified ONNX model 26 | onnx.save(model, 'yolov8n_transpose.onnx') 27 | -------------------------------------------------------------------------------- /trt_inference.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import cv2 3 | from warpaffine import Warpaffine 4 | from postprocess import gpu_decode 5 | # import pycuda.autoinit #负责数据初始化,内存管理,销毁等 6 | import pycuda.driver as cuda #GPU CPU之间的数据传输 7 | import time 8 | import os 9 | import copy 10 | 11 | class TRT_inference(object): 12 | def __init__(self, model_path,model="yolov5"): 13 | super(TRT_inference, self).__init__() 14 | 15 | logger = trt.Logger(trt.Logger.WARNING) 16 | builder = trt.Builder(logger) 17 | network = builder.create_network(1 <= dst_width || dy >= dst_height) return; 82 | 83 | 84 | 85 | float c0 = fill_value[0], c1 = fill_value[1], c2 = fill_value[2]; 86 | float src_x = 0; float src_y = 0; 87 | 88 | affine_project(affine, dx, dy, &src_x, &src_y); 89 | 90 | if(src_x < -1 || src_x >= src_width || src_y < -1 || src_y >= src_height){ 91 | // out of range 92 | // src_x < -1时,其高位high_x < 0,超出范围 93 | // src_x >= -1时,其高位high_x >= 0,存在取值 94 | }else{ 95 | int y_low = floorf(src_y); 96 | int x_low = floorf(src_x); 97 | int y_high = y_low + 1; 98 | int x_high = x_low + 1; 99 | 100 | 101 | unsigned char const_values[] = {fill_value[0], fill_value[1], fill_value[2]}; 102 | float ly = src_y - y_low; 103 | float lx = src_x - x_low; 104 | float hy = 1 - ly; 105 | float hx = 1 - lx; 106 | float w1 = hy * hx, w2 = hy * lx, 
w3 = ly * hx, w4 = ly * lx; 107 | unsigned char* v1 = const_values; 108 | unsigned char* v2 = const_values; 109 | unsigned char* v3 = const_values; 110 | unsigned char* v4 = const_values; 111 | 112 | if(y_low >= 0){ 113 | if (x_low >= 0) 114 | v1 = src + y_low * src_line_size + x_low * 3; 115 | 116 | if (x_high < src_width) 117 | v2 = src + y_low * src_line_size + x_high * 3; 118 | } 119 | 120 | if(y_high < src_height){ 121 | if (x_low >= 0) 122 | v3 = src + y_high * src_line_size + x_low * 3; 123 | 124 | if (x_high < src_width) 125 | v4 = src + y_high * src_line_size + x_high * 3; 126 | } 127 | 128 | c0 = floorf(w1 * v1[0] + w2 * v2[0] + w3 * v3[0] + w4 * v4[0] + 0.5f); 129 | c1 = floorf(w1 * v1[1] + w2 * v2[1] + w3 * v3[1] + w4 * v4[1] + 0.5f); 130 | c2 = floorf(w1 * v1[2] + w2 * v2[2] + w3 * v3[2] + w4 * v4[2] + 0.5f); 131 | } 132 | 133 | //unsigned char * pdst = dst + dy * dst_line_size + dx * 3; 134 | //[0] = c0; pdst[1] = c1; pdst[2] = c2; 135 | 136 | float* pdst0 = dst + dy * dst_width + dx; 137 | float* pdst1 = dst + dy * dst_width + dx + dst_line_size; 138 | float* pdst2 = dst + dy * dst_width + dx + 2*dst_line_size; 139 | 140 | *pdst2 = c0/255.0; *pdst1 = c1/255.0; *pdst0 = c2/255.0; //rgb->bgr sp 141 | 142 | } 143 | """) 144 | return mod.get_function("warp_affine_bilinear_kernel") 145 | 146 | def up_information(self,img): 147 | self.src_size = (img.shape[1],img.shape[0]) 148 | self.img_device = cuda.mem_alloc(img.nbytes) 149 | self.img_host = cuda.pagelocked_empty_like(np.ones((self.src_size[1],self.src_size[0],3)).astype(np.uint8), mem_flags=0) 150 | # self.img_host = cuda.register_host_memory(np.ones((self.src_size[1],self.src_size[0],3)).astype(np.uint8)) 151 | self.src_info = np.array([self.src_size[0]*3,self.src_size[0],self.src_size[1]]).astype(np.int32) 152 | self.affine = affine_compute(self.src_size,self.dst_size) 153 | 154 | 155 | def preprocess(self,img): 156 | 157 | if operator.eq(self.src_size,(img.shape[1],img.shape[0])) is False: 158 | 
self.up_information(img) 159 | 160 | np.copyto(self.img_host,img.data) 161 | cuda.memcpy_htod_async(self.img_device, self.img_host, self.stream) 162 | self.func(self.img_device,cuda.In(self.src_info),self.pdst_device,cuda.In(self.dst_info),\ 163 | cuda.In(self.fill_value),cuda.In(self.affine),stream=self.stream,block=self.block_size,grid=self.grid_size) 164 | # cuda.memcpy_dtoh_async(self.pdst_host, self.pdst_device, self.stream) 165 | # self.stream.synchronize() 166 | 167 | return self.pdst_device,self.affine 168 | 169 | def __call__(self, img): 170 | return self.preprocess(img) 171 | 172 | 173 | if __name__ == "__main__": 174 | 175 | warpaffine = Warpaffine(dst_size=(640,384)) 176 | 177 | img =cv2.imread("dog1.jpg") 178 | pdst_img = warpaffine(img)*255 179 | pdst_img = pdst_img[0].transpose(1, 2, 0) 180 | print(pdst_img.shape) 181 | pdst_img = cv2.cvtColor(pdst_img,cv2.COLOR_RGB2BGR) 182 | cv2.imwrite("my.jpg",pdst_img) 183 | 184 | 185 | -------------------------------------------------------------------------------- /weights/yolov5s.engine: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangzongmou/Pycuda-warpaffine-trt-post/c9e81ce04147cbd32cae856b04b0f0be12e04bb1/weights/yolov5s.engine -------------------------------------------------------------------------------- /weights/yolov5s.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangzongmou/Pycuda-warpaffine-trt-post/c9e81ce04147cbd32cae856b04b0f0be12e04bb1/weights/yolov5s.onnx -------------------------------------------------------------------------------- /weights/yolov8n_transpose.engine: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangzongmou/Pycuda-warpaffine-trt-post/c9e81ce04147cbd32cae856b04b0f0be12e04bb1/weights/yolov8n_transpose.engine 
-------------------------------------------------------------------------------- /weights/yolov8n_transpose.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangzongmou/Pycuda-warpaffine-trt-post/c9e81ce04147cbd32cae856b04b0f0be12e04bb1/weights/yolov8n_transpose.onnx --------------------------------------------------------------------------------