├── CMakeLists.txt ├── README.md ├── images ├── .ipynb_checkpoints │ └── result-checkpoint.jpg ├── result.jpg └── test.jpg ├── main.cpp ├── models ├── yolov11n.onnx └── yolov11n.trt └── src ├── CNN.cpp ├── CNN.hpp ├── common ├── common.hpp └── logging.h ├── common_struct.hpp ├── kernels ├── .ipynb_checkpoints │ ├── get_nms_before_boxes-checkpoint.cu │ └── get_nms_before_boxes-checkpoint.cuh ├── get_nms_before_boxes.cu └── get_nms_before_boxes.cuh ├── postprocess_cuda.cpp └── postprocess_cuda.hpp /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.5) 2 | 3 | project(yolo_trt) 4 | 5 | set(CMAKE_CXX_STANDARD 14) 6 | 7 | # CUDA 8 | find_package(CUDA REQUIRED) 9 | 10 | message(STATUS "Find CUDA include at ${CUDA_INCLUDE_DIRS}") 11 | message(STATUS "Find CUDA libraries: ${CUDA_LIBRARIES}") 12 | 13 | # find_library(CUDA_NPP_LIBRARY NAMES npp PATHS ${CUDA_TOOLKIT_ROOT_DIR}/lib64) 14 | 15 | # TensorRT 16 | set(TENSORRT_ROOT /root/autodl-tmp/TensorRT-8.6.1.6) 17 | 18 | find_path(TENSORRT_INCLUDE_DIR NvInfer.h 19 | HINTS ${TENSORRT_ROOT} PATH_SUFFIXES include/) 20 | message(STATUS "Found TensorRT headers at ${TENSORRT_INCLUDE_DIR}") 21 | find_library(TENSORRT_LIBRARY_INFER nvinfer 22 | HINTS ${TENSORRT_ROOT} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR} 23 | PATH_SUFFIXES lib lib64 lib/x64) 24 | find_library(TENSORRT_LIBRARY_ONNXPARSER nvonnxparser 25 | HINTS ${TENSORRT_ROOT} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR} 26 | PATH_SUFFIXES lib lib64 lib/x64) 27 | set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_ONNXPARSER}) 28 | message(STATUS "Find TensorRT libs: ${TENSORRT_LIBRARY}") 29 | 30 | # OpenCV 31 | find_package(OpenCV REQUIRED) 32 | 33 | message(STATUS "Find OpenCV include at ${OpenCV_INCLUDE_DIRS}") 34 | message(STATUS "Find OpenCV libraries: ${OpenCV_LIBRARIES}") 35 | 36 | set(COMMON_INCLUDE ./common) 37 | set(KERNELS_INCLUDE ./kernels) 38 | 39 | include_directories(${CUDA_INCLUDE_DIRS} ${TENSORRT_INCLUDE_DIR} ${OpenCV_INCLUDE_DIRS} ${COMMON_INCLUDE} ${KERNELS_INCLUDE}) 40 | 41 | # set(CMAKE_CUDA_ARCHITECTURES 86) 42 | option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) 43 | set(CUDA_GEN_CODE "-gencode=arch=compute_89,code=sm_89") 44 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --default-stream=per-thread --threads 0 -O0 -Xcompiler -fPIC -g -w ${CUDA_GEN_CODE}") 45 | 46 | enable_language(CUDA) 47 | 48 | SET(CMAKE_PREFIX_PATH "/usr/local/cuda") 49 | 50 | add_executable(yolo_trt main.cpp src/CNN.cpp src/postprocess_cuda.cpp src/kernels/get_nms_before_boxes.cu) 51 | 52 | # 使用 gpu 进行预处理 打开这一行,关闭 后后 面一行 53 | # target_link_libraries(yolo_trt ${OpenCV_LIBRARIES} ${CUDA_LIBRARIES} ${TENSORRT_LIBRARY} ${CUDA_npp_LIBRARY} ${CUDA_cublas_LIBRARY}) 54 | 55 | target_link_libraries(yolo_trt ${OpenCV_LIBRARIES} ${CUDA_LIBRARIES} ${TENSORRT_LIBRARY}) 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # yolov11 tensorRT 的 C++ 部署,后处理用cuda实现的 2 | 3 | 本示例中,包含完整的代码、模型、测试图片、测试结果。 4 | 5 | 后处理部分用cuda 核函数实现,并不是全部后处理都用cuda实现;纯cpu实现后处理部分代码分支[【cpu实现后处理代码分支】](https://github.com/cqu20160901/yolov11_tensorRT_postprocess_cuda/tree/yolov11_postprocess_cpu) 6 | 7 | TensorRT版本:TensorRT-8.6.1.6 8 | 9 | ## 导出onnx模型 10 | 11 | 按照yolov11官方导出的方式如下: 12 | 13 | ```python 14 | from ultralytics import YOLO 15 | model = YOLO(model='yolov11n.pt') # load a pretrained model (recommended for training) 16 | results = model(task='detect', source=r'./bus.jpg', save=True) # predict on an image 17 | 18 | model.export(format="onnx", imgsz=640, simplify=True) 19 | 20 | ``` 21 | 22 | ## 编译 23 | 24 | 修改 CMakeLists.txt 对应的TensorRT位置 25 | 26 | ![image](https://github.com/user-attachments/assets/ac92b3d7-855a-40ac-9b5f-a3fabd262634) 27 | 28 | 29 | ```powershell 30 | cd yolov11_tensorRT_postprocess_cuda 31 | mkdir build 32 | cd build 33 | cmake .. 34 | make 35 | ``` 36 | 37 | ## 运行 38 | 39 | ```powershell 40 | # 运行时如果.trt模型存在则直接加载,若不存会自动先将onnx转换成 trt 模型,并存在给定的位置,然后运行推理。 41 | cd build 42 | ./yolo_trt 43 | ``` 44 | 45 | ## 测试效果 46 | 47 | onnx 测试效果 48 | 49 | ![image](https://github.com/user-attachments/assets/da904ce0-4e0c-414e-9339-39dca4747328) 50 | 51 | tensorRT 测试效果 52 | 53 | ![image](https://github.com/cqu20160901/yolov11_tensorRT_postprocess_cuda/blob/main/images/result.jpg) 54 | 55 | ### tensorRT 时耗(cuda实现部分后处理) 56 | 57 | 示例中用cpu对图像进行预处理、用rtx4090显卡、模型yolov11n(输入分辨率640x640,80个类别)、量化成FP16模型 58 | 59 | ![image](https://github.com/user-attachments/assets/4522185b-9064-4489-8022-8304c61ba82d) 60 | 61 | ### tensorRT 时耗(纯cpu实现后处理)[【cpu实现后处理代码分支】](https://github.com/cqu20160901/yolov11_tensorRT_postprocess_cuda/tree/yolov11_postprocess_cpu) 62 | ![image](https://github.com/user-attachments/assets/bbbc6777-d3e3-4349-b623-4f0f78e39910) 63 | 64 | 65 | 66 | ## 替换模型说明 67 | 68 | 修改相关的路径 69 | ```cpp 70 | 71 | int main() 72 | { 73 | std::string OnnxFile = "/root/autodl-tmp/yolov11_tensorRT_postprocess_cuda/models/yolov11n.onnx"; 74 | std::string SaveTrtFilePath = "/root/autodl-tmp/yolov11_tensorRT_postprocess_cuda/models/yolov11n.trt"; 75 | cv::Mat SrcImage = cv::imread("/root/autodl-tmp/yolov11_tensorRT_postprocess_cuda/images/test.jpg"); 76 | 77 | int img_width = SrcImage.cols; 78 | int img_height = SrcImage.rows; 79 | std::cout << "img_width: " << img_width << " img_height: " << img_height << std::endl; 80 | 81 | CNN YOLO(OnnxFile, SaveTrtFilePath, 1, 3, 640, 640); 82 | 83 | auto t_start = std::chrono::high_resolution_clock::now(); 84 | int Temp = 2000; 85 | 86 | int SleepTimes = 0; 87 | for (int i = 0; i < Temp; i++) 88 | { 89 | YOLO.Inference(SrcImage); 90 | std::this_thread::sleep_for(std::chrono::milliseconds(SleepTimes)); 91 | } 92 | auto t_end = std::chrono::high_resolution_clock::now(); 93 | float total_inf = std::chrono::duration(t_end - t_start).count(); 94 | std::cout << "Info: " << Temp << " times infer and postprocess ave cost: " << total_inf / float(Temp) - SleepTimes << " ms." << std::endl; 95 | 96 | 97 | for (int i = 0; i < YOLO.DetectiontRects_.size(); i += 6) 98 | { 99 | int classId = int(YOLO.DetectiontRects_[i + 0]); 100 | float conf = YOLO.DetectiontRects_[i + 1]; 101 | int xmin = int(YOLO.DetectiontRects_[i + 2] * float(img_width) + 0.5); 102 | int ymin = int(YOLO.DetectiontRects_[i + 3] * float(img_height) + 0.5); 103 | int xmax = int(YOLO.DetectiontRects_[i + 4] * float(img_width) + 0.5); 104 | int ymax = int(YOLO.DetectiontRects_[i + 5] * float(img_height) + 0.5); 105 | 106 | char text1[256]; 107 | sprintf(text1, "%d:%.2f", classId, conf); 108 | rectangle(SrcImage, cv::Point(xmin, ymin), cv::Point(xmax, ymax), cv::Scalar(255, 0, 0), 2); 109 | putText(SrcImage, text1, cv::Point(xmin, ymin + 15), cv::FONT_HERSHEY_SIMPLEX, 0.7, cv::Scalar(0, 0, 255), 2); 110 | } 111 | 112 | imwrite("/root/autodl-tmp/yolov11_tensorRT_postprocess_cuda/images/result.jpg", SrcImage); 113 | 114 | printf("== obj: %d \n", int(float(YOLO.DetectiontRects_.size()) / 6.0)); 115 | 116 | return 0; 117 | } 118 | 119 | ``` 120 | 121 | 122 | ## 预处理加速 123 | 124 | 如果环境支持CUDA_npp_LIBRARY进行预处理(如果有环境可以打开进一步加速(修改位置:CMakelist.txt、用CPU或GPU预处理打开对应的宏 #define USE_GPU_PREPROCESS 1)) 125 | 126 | **重新搭建了一个支持用gpu做处理操作:rtx4090显卡、模型yolov11n(输入分辨率640x640,80个类别)、量化成FP16模型** 127 | 128 | cpu做预处理+cpu做后处理 129 | 130 | ![image](https://github.com/user-attachments/assets/e3d44672-38cf-47f7-84e3-9436dc0e6c0c) 131 | 132 | 133 | cpu做预处理+gpu做后处理 134 | 135 | ![image](https://github.com/user-attachments/assets/482bb1cc-3454-454a-ae2e-362c59cb9eaa) 136 | 137 | gpu做预处理+gpu做后处理 138 | 139 | ![image](https://github.com/user-attachments/assets/a05a3fab-35d0-45ff-bbf3-e292093bb725) 140 | 141 | 142 | ## 后续优化点 143 | 1、把nms过程也用cuda实现,参加nms的框不多,但也是一个优化点,持续更新中 144 | -------------------------------------------------------------------------------- /images/.ipynb_checkpoints/result-checkpoint.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_tensorRT_postprocess_cuda/7c9d3e4fd4b50978fe68756d77d0c568893279c3/images/.ipynb_checkpoints/result-checkpoint.jpg -------------------------------------------------------------------------------- /images/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_tensorRT_postprocess_cuda/7c9d3e4fd4b50978fe68756d77d0c568893279c3/images/result.jpg -------------------------------------------------------------------------------- /images/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_tensorRT_postprocess_cuda/7c9d3e4fd4b50978fe68756d77d0c568893279c3/images/test.jpg -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include "src/CNN.hpp" 2 | #include 3 | #include 4 | #include 5 | 6 | int main() 7 | { 8 | std::string OnnxFile = "/root/autodl-tmp/yolov11_tensorRT_postprocess_cuda/models/yolov11n.onnx"; 9 | std::string SaveTrtFilePath = "/root/autodl-tmp/yolov11_tensorRT_postprocess_cuda/models/yolov11n.trt"; 10 | cv::Mat SrcImage = cv::imread("/root/autodl-tmp/yolov11_tensorRT_postprocess_cuda/images/test.jpg"); 11 | 12 | int img_width = SrcImage.cols; 13 | int img_height = SrcImage.rows; 14 | std::cout << "img_width: " << img_width << " img_height: " << img_height << std::endl; 15 | 16 | CNN YOLO(OnnxFile, SaveTrtFilePath, 1, 3, 640, 640); 17 | 18 | auto t_start = std::chrono::high_resolution_clock::now(); 19 | int Temp = 2000; 20 | 21 | int SleepTimes = 0; 22 | for (int i = 0; i < Temp; i++) 23 | { 24 | YOLO.Inference(SrcImage); 25 | std::this_thread::sleep_for(std::chrono::milliseconds(SleepTimes)); 26 | } 27 | auto t_end = std::chrono::high_resolution_clock::now(); 28 | float total_inf = std::chrono::duration(t_end - t_start).count(); 29 | std::cout << "Info: " << Temp << " times infer and postprocess ave cost: " << total_inf / float(Temp) - SleepTimes << " ms." << std::endl; 30 | 31 | 32 | for (int i = 0; i < YOLO.DetectiontRects_.size(); i += 6) 33 | { 34 | int classId = int(YOLO.DetectiontRects_[i + 0]); 35 | float conf = YOLO.DetectiontRects_[i + 1]; 36 | int xmin = int(YOLO.DetectiontRects_[i + 2] * float(img_width) + 0.5); 37 | int ymin = int(YOLO.DetectiontRects_[i + 3] * float(img_height) + 0.5); 38 | int xmax = int(YOLO.DetectiontRects_[i + 4] * float(img_width) + 0.5); 39 | int ymax = int(YOLO.DetectiontRects_[i + 5] * float(img_height) + 0.5); 40 | 41 | char text1[256]; 42 | sprintf(text1, "%d:%.2f", classId, conf); 43 | rectangle(SrcImage, cv::Point(xmin, ymin), cv::Point(xmax, ymax), cv::Scalar(255, 0, 0), 2); 44 | putText(SrcImage, text1, cv::Point(xmin, ymin + 15), cv::FONT_HERSHEY_SIMPLEX, 0.7, cv::Scalar(0, 0, 255), 2); 45 | } 46 | 47 | imwrite("/root/autodl-tmp/yolov11_tensorRT_postprocess_cuda/images/result.jpg", SrcImage); 48 | 49 | printf("== obj: %d \n", int(float(YOLO.DetectiontRects_.size()) / 6.0)); 50 | 51 | return 0; 52 | } 53 | -------------------------------------------------------------------------------- /models/yolov11n.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_tensorRT_postprocess_cuda/7c9d3e4fd4b50978fe68756d77d0c568893279c3/models/yolov11n.onnx -------------------------------------------------------------------------------- /models/yolov11n.trt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_tensorRT_postprocess_cuda/7c9d3e4fd4b50978fe68756d77d0c568893279c3/models/yolov11n.trt -------------------------------------------------------------------------------- /src/CNN.cpp: -------------------------------------------------------------------------------- 1 | #include "CNN.hpp" 2 | #include "common/common.hpp" 3 | #include 4 | #include 5 | #include "kernels/get_nms_before_boxes.cuh" 6 | 7 | #define USE_GPU_PREPROCESS 0 8 | 9 | CNN::CNN(const std::string &OnnxFilePath, const std::string &SaveTrtFilePath, int BatchSize, int InputChannel, int InputImageWidth, int InputImageHeight) 10 | { 11 | OnnxFilePath_ = OnnxFilePath; 12 | SaveTrtFilePath_ = SaveTrtFilePath; 13 | 14 | BatchSize_ = BatchSize; 15 | InputChannel_ = InputChannel; 16 | InputImageWidth_ = InputImageWidth; 17 | InputImageHeight_ = InputImageHeight; 18 | 19 | ModelInit(); 20 | } 21 | 22 | CNN::~CNN() 23 | { 24 | // release the stream and the Buffers 25 | cudaStreamDestroy(Stream_); 26 | 27 | for (int i = 0; i < BuffersDataSize_.size(); i ++) 28 | { 29 | cudaFree(Buffers_[i]); 30 | } 31 | 32 | // destroy the engine 33 | if (nullptr != PtrContext_) 34 | { 35 | PtrContext_->destroy(); 36 | } 37 | 38 | if (nullptr != PtrEngine_) 39 | { 40 | PtrEngine_->destroy(); 41 | } 42 | 43 | if (nullptr != GpuSrcImgBuf_) 44 | { 45 | cudaFree(GpuSrcImgBuf_); 46 | } 47 | 48 | if (nullptr != GpuImgResizeBuf_) 49 | { 50 | cudaFree(GpuImgResizeBuf_); 51 | } 52 | 53 | if (nullptr != GpuImgF32Buf_) 54 | { 55 | cudaFree(GpuImgF32Buf_); 56 | } 57 | 58 | if (nullptr != GpuDataPlanes_) 59 | { 60 | cudaFree(GpuDataPlanes_); 61 | } 62 | 63 | if (nullptr != GpuOutputCount_) 64 | { 65 | cudaFree(GpuOutputCount_); 66 | } 67 | 68 | if (nullptr != CpuOutputCount_) 69 | { 70 | free(CpuOutputCount_); 71 | } 72 | if (nullptr != CpuOutputRects_) 73 | { 74 | free(CpuOutputRects_); 75 | } 76 | 77 | } 78 | 79 | void CNN::ModelInit() 80 | { 81 | std::fstream existEngine; 82 | existEngine.open(SaveTrtFilePath_, std::ios::in); 83 | if (existEngine) 84 | { 85 | ReadTrtFile(SaveTrtFilePath_, PtrEngine_); 86 | assert(PtrEngine_ != nullptr); 87 | } 88 | else 89 | { 90 | OnnxToTRTModel(OnnxFilePath_, SaveTrtFilePath_, PtrEngine_, BatchSize_); 91 | assert(PtrEngine_ != nullptr); 92 | } 93 | 94 | assert(PtrEngine_ != nullptr); 95 | PtrContext_ = PtrEngine_->createExecutionContext(); 96 | PtrContext_->setOptimizationProfile(0); 97 | auto InputDims = nvinfer1::Dims4 {BatchSize_, InputChannel_, InputImageHeight_, InputImageWidth_}; 98 | PtrContext_->setBindingDimensions(0, InputDims); 99 | 100 | cudaStreamCreate(&Stream_); 101 | 102 | int64_t TotalSize = 0; 103 | int nbBindings = PtrEngine_->getNbBindings(); 104 | BuffersDataSize_.resize(nbBindings); 105 | for (int i = 0; i < nbBindings; ++ i) 106 | { 107 | nvinfer1::Dims dims = PtrEngine_->getBindingDimensions(i); 108 | nvinfer1::DataType dtype = PtrEngine_->getBindingDataType(i); 109 | TotalSize = Volume(dims) * 1 * GetElementSize(dtype); 110 | BuffersDataSize_[i] = TotalSize; 111 | cudaMalloc(&Buffers_[i], TotalSize); 112 | 113 | if (0 == i) 114 | { 115 | std::cout << "input node name: "<< PtrEngine_->getBindingName(i) << ", dims: " << dims.nbDims << std::endl; 116 | } 117 | else 118 | { 119 | std::cout << "output node" << i - 1 << " name: "<< PtrEngine_->getBindingName(i) << ", dims: " << dims.nbDims << std::endl; 120 | } 121 | 122 | for (int j = 0; j < dims.nbDims; j++) 123 | { 124 | std::cout << "demension[" << j << "], size = " << dims.d[j] << std::endl; 125 | } 126 | } 127 | 128 | cudaMalloc(&GpuOutputCount_, sizeof(int)); 129 | cudaMalloc(&GpuOutputRects_, sizeof(DetectRect) * NmsBeforeMaxNum_); 130 | 131 | CpuOutputCount_ = (int *)malloc(sizeof(int)); 132 | CpuOutputRects_ = (DetectRect *)malloc(sizeof(DetectRect) * NmsBeforeMaxNum_); 133 | 134 | #if USE_GPU_PREPROCESS 135 | #else 136 | PreprocessResult_.resize(BatchSize_ * InputImageWidth_ * InputImageHeight_ * InputChannel_); 137 | #endif 138 | 139 | 140 | } 141 | 142 | 143 | void CNN::Inference(cv::Mat &SrcImage) 144 | { 145 | DetectiontRects_.clear(); 146 | if(PtrContext_ == nullptr) 147 | { 148 | std::cout << "Error, PtrContext_" << std::endl; 149 | } 150 | 151 | #if USE_GPU_PREPROCESS 152 | PrepareImage(SrcImage, Buffers_[0]); 153 | #else 154 | PrepareImage(SrcImage, PreprocessResult_); 155 | cudaMemcpyAsync(Buffers_[0], PreprocessResult_.data(), BuffersDataSize_[0], cudaMemcpyHostToDevice, Stream_); 156 | #endif 157 | 158 | PtrContext_->enqueueV2(Buffers_, Stream_, nullptr); 159 | 160 | cudaMemsetAsync(GpuOutputCount_, 0, 4, Stream_); 161 | GetNmsBeforeBoxes((float*)Buffers_[1], Postprocess_.CoordIndex, Postprocess_.ClassNum, Postprocess_.ObjectThresh, NmsBeforeMaxNum_, 162 | GpuOutputRects_, GpuOutputCount_, Stream_); 163 | 164 | cudaMemcpyAsync(CpuOutputCount_, GpuOutputCount_, sizeof(int), cudaMemcpyDeviceToHost, Stream_); 165 | cudaMemcpyAsync(CpuOutputRects_, GpuOutputRects_, sizeof(DetectRect) * NmsBeforeMaxNum_, cudaMemcpyDeviceToHost, Stream_); 166 | 167 | cudaStreamSynchronize(Stream_); 168 | 169 | // Postprocess 170 | int ret = Postprocess_.GetConvDetectionResult(CpuOutputRects_, CpuOutputCount_, DetectiontRects_); 171 | } 172 | 173 | void CNN::PrepareImage(cv::Mat &SrcImage, std::vector &PreprocessResult) 174 | { 175 | float *Imagedata = PreprocessResult.data(); 176 | 177 | cv::Mat rsz_img; 178 | cv::resize(SrcImage, rsz_img, cv::Size(InputImageWidth_, InputImageHeight_)); 179 | rsz_img.convertTo(rsz_img, CV_32FC3, 1.0 / 255); 180 | 181 | // HWC TO CHW 182 | int channelLength = InputImageWidth_ * InputImageHeight_; 183 | std::vector split_img = {cv::Mat(InputImageHeight_, InputImageWidth_, CV_32FC1, Imagedata + channelLength * 2), 184 | cv::Mat(InputImageHeight_, InputImageWidth_, CV_32FC1, Imagedata + channelLength * 1), 185 | cv::Mat(InputImageHeight_, InputImageWidth_, CV_32FC1, Imagedata + channelLength * 0)}; 186 | 187 | cv::split(rsz_img, split_img); 188 | } 189 | 190 | 191 | void CNN::PrepareImage(cv::Mat &SrcImage, void *InputBuffer) 192 | { 193 | #if USE_GPU_PREPROCESS 194 | int src_width = SrcImage.cols; 195 | int src_height = SrcImage.rows; 196 | int src_channel = SrcImage.channels(); 197 | 198 | NppiSize dstSize = {InputImageWidth_, InputImageHeight_}; 199 | NppiRect dstROI = {0, 0, InputImageWidth_, InputImageHeight_}; 200 | if (GpuImgResizeBuf_ == nullptr) 201 | { 202 | cudaMalloc(&GpuImgResizeBuf_, InputImageWidth_ * InputImageHeight_ * src_channel * sizeof(uchar)); 203 | cudaMalloc(&GpuImgF32Buf_, InputImageWidth_ * InputImageHeight_ * src_channel * sizeof(float)); 204 | cudaMalloc(&GpuDataPlanes_, InputImageWidth_ * InputImageHeight_ * src_channel * sizeof(float)); 205 | } 206 | 207 | 208 | NppiSize srcSize = {src_width, src_height}; 209 | NppiRect srcROI = {0, 0, src_width, src_height}; 210 | if(GpuSrcImgBuf_ == nullptr) 211 | { 212 | cudaMalloc(&GpuSrcImgBuf_, src_width * src_height * src_channel * sizeof(uchar)); 213 | } 214 | 215 | DstPlanes_[0] = GpuDataPlanes_; 216 | DstPlanes_[1] = GpuDataPlanes_ + InputImageWidth_ * InputImageHeight_; 217 | DstPlanes_[2] = GpuDataPlanes_ + InputImageWidth_ * InputImageHeight_ * 2; 218 | 219 | // 将cpu图像拷贝到gpu 220 | cudaMemcpy(GpuSrcImgBuf_, (void *)SrcImage.data, src_width * src_height * src_channel, cudaMemcpyHostToDevice); 221 | 222 | // resize 223 | nppiResize_8u_C3R(GpuSrcImgBuf_, src_width * src_channel, srcSize, srcROI, GpuImgResizeBuf_, InputImageWidth_ * src_channel, dstSize, dstROI, NPPI_INTER_LINEAR); 224 | 225 | // bgr 转 rgb 226 | nppiSwapChannels_8u_C3IR(GpuImgResizeBuf_, InputImageWidth_ * src_channel, dstSize, DstOrder_); 227 | 228 | // int8(uchar) 转 f32 229 | nppiConvert_8u32f_C3R(GpuImgResizeBuf_, InputImageWidth_ * src_channel, GpuImgF32Buf_, InputImageWidth_ * src_channel * sizeof(float), dstSize); 230 | 231 | // 减均值、除方差 232 | nppiMulC_32f_C3IR(MeanScale_, GpuImgF32Buf_, InputImageWidth_ * src_channel * sizeof(float), dstSize); 233 | 234 | nppiCopy_32f_C3P3R(GpuImgF32Buf_, InputImageWidth_ * src_channel * sizeof(float), DstPlanes_, InputImageWidth_ * sizeof(float), dstSize); 235 | 236 | cudaMemcpyAsync(InputBuffer, GpuDataPlanes_, src_channel * InputImageWidth_ * InputImageHeight_ * sizeof(float), cudaMemcpyDeviceToDevice, Stream_); 237 | #endif 238 | } 239 | -------------------------------------------------------------------------------- /src/CNN.hpp: -------------------------------------------------------------------------------- 1 | #ifndef CNN_HPP 2 | #define CNN_HPP 3 | 4 | #include "NvInfer.h" 5 | #include "postprocess_cuda.hpp" 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include "common_struct.hpp" 14 | 15 | class CNN 16 | { 17 | public: 18 | CNN(const std::string &OnnxFilePath, const std::string &SaveTrtFilePath, int BatchSize, int InputChannel, int InputImageWidth, int InputImageHeight); 19 | ~CNN(); 20 | 21 | void Inference(cv::Mat &SrcImage); 22 | 23 | std::vector DetectiontRects_; 24 | 25 | private: 26 | void ModelInit(); 27 | void PrepareImage(cv::Mat &SrcImage, std::vector &PreprocessResult); 28 | void PrepareImage(cv::Mat &vec_img, void *InputBuffer); 29 | 30 | std::string OnnxFilePath_; 31 | std::string SaveTrtFilePath_; 32 | 33 | int BatchSize_ = 0; 34 | int InputChannel_ = 0; 35 | int InputImageWidth_ = 0; 36 | int InputImageHeight_ = 0; 37 | int ModelOutputSize_ = 0; 38 | 39 | nvinfer1::ICudaEngine *PtrEngine_ = nullptr; 40 | nvinfer1::IExecutionContext *PtrContext_ = nullptr; 41 | cudaStream_t Stream_; 42 | 43 | void *Buffers_[10]; 44 | std::vector BuffersDataSize_; 45 | std::vector PreprocessResult_; 46 | 47 | Npp8u *GpuSrcImgBuf_ = nullptr; // gpu:装 src 图像 48 | Npp8u *GpuImgResizeBuf_ = nullptr; // gpu 装 resize后的图像 49 | Npp32f *GpuImgF32Buf_ = nullptr; // gpu: int8 转 F32 50 | Npp32f *GpuDataPlanes_ = nullptr; 51 | 52 | Npp32f MeanScale_[3] = {0.00392157, 0.00392157, 0.00392157 }; 53 | int DstOrder_[3] = {2, 1, 0 }; 54 | Npp32f* DstPlanes_[3]; 55 | 56 | GetResultRectYolov11 Postprocess_; 57 | const int NmsBeforeMaxNum_ = 512; 58 | int* GpuOutputCount_ = nullptr; 59 | DetectRect *GpuOutputRects_ = nullptr; 60 | 61 | int* CpuOutputCount_ = nullptr; 62 | DetectRect *CpuOutputRects_ = nullptr; 63 | }; 64 | 65 | #endif 66 | -------------------------------------------------------------------------------- /src/common/common.hpp: -------------------------------------------------------------------------------- 1 | #ifndef COMMON_HPP 2 | #define COMMON_HPP 3 | 4 | #include "NvOnnxParser.h" 5 | #include "logging.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | // These is necessary if we want to be able to write 1_GiB instead of 1.0_GiB. 12 | // Since the return type is signed, -1_GiB will work as expected. 13 | constexpr long long int operator"" _GiB(long long unsigned int val) 14 | { 15 | return val * (1 << 30); 16 | } 17 | constexpr long long int operator"" _MiB(long long unsigned int val) 18 | { 19 | return val * (1 << 20); 20 | } 21 | constexpr long long int operator"" _KiB(long long unsigned int val) 22 | { 23 | return val * (1 << 10); 24 | } 25 | 26 | inline unsigned int GetElementSize(nvinfer1::DataType t) 27 | { 28 | switch (t) 29 | { 30 | case nvinfer1::DataType::kINT32: 31 | return 4; 32 | case nvinfer1::DataType::kFLOAT: 33 | return 4; 34 | case nvinfer1::DataType::kHALF: 35 | return 2; 36 | case nvinfer1::DataType::kBOOL: 37 | case nvinfer1::DataType::kINT8: 38 | return 1; 39 | } 40 | throw std::runtime_error("Invalid DataType."); 41 | return 0; 42 | } 43 | 44 | inline int64_t Volume(const nvinfer1::Dims &d) 45 | { 46 | return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies()); 47 | } 48 | 49 | Logger gLogger{Logger::Severity::kINFO}; 50 | LogStreamConsumer gLogVerbose{LOG_VERBOSE(gLogger)}; 51 | LogStreamConsumer gLogInfo{LOG_INFO(gLogger)}; 52 | LogStreamConsumer gLogWarning{LOG_WARN(gLogger)}; 53 | LogStreamConsumer gLogError{LOG_ERROR(gLogger)}; 54 | LogStreamConsumer gLogFatal{LOG_FATAL(gLogger)}; 55 | 56 | void setReportableSeverity(Logger::Severity severity) 57 | { 58 | gLogger.setReportableSeverity(severity); 59 | gLogVerbose.setReportableSeverity(severity); 60 | gLogInfo.setReportableSeverity(severity); 61 | gLogWarning.setReportableSeverity(severity); 62 | gLogError.setReportableSeverity(severity); 63 | gLogFatal.setReportableSeverity(severity); 64 | } 65 | 66 | bool ReadTrtFile(const std::string &engineFile, nvinfer1::ICudaEngine *&engine) 67 | { 68 | std::string cached_engine; 69 | std::fstream file; 70 | std::cout << "loading filename from:" << engineFile << std::endl; 71 | nvinfer1::IRuntime *trtRuntime; 72 | file.open(engineFile, std::ios::binary | std::ios::in); 73 | 74 | if (!file.is_open()) 75 | { 76 | std::cout << "read file error: " << engineFile << std::endl; 77 | cached_engine = ""; 78 | } 79 | 80 | while (file.peek() != EOF) 81 | { 82 | std::stringstream buffer; 83 | buffer << file.rdbuf(); 84 | cached_engine.append(buffer.str()); 85 | } 86 | file.close(); 87 | 88 | trtRuntime = nvinfer1::createInferRuntime(gLogger.getTRTLogger()); 89 | engine = trtRuntime->deserializeCudaEngine(cached_engine.data(), cached_engine.size(), nullptr); 90 | std::cout << "deserialize done" << std::endl; 91 | 92 | return true; 93 | } 94 | 95 | void OnnxToTRTModel(const std::string &modelFile, const std::string &filename, nvinfer1::ICudaEngine *&engine, const int &BATCH_SIZE) 96 | { 97 | // create the builder 98 | nvinfer1::IBuilder *builder = nvinfer1::createInferBuilder(gLogger.getTRTLogger()); 99 | assert(builder != nullptr); 100 | 101 | const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); 102 | auto network = builder->createNetworkV2(explicitBatch); 103 | auto config = builder->createBuilderConfig(); 104 | 105 | auto parser = nvonnxparser::createParser(*network, gLogger.getTRTLogger()); 106 | if (!parser->parseFromFile(modelFile.c_str(), static_cast(gLogger.getReportableSeverity()))) 107 | { 108 | gLogError << "Failure while parsing ONNX file" << std::endl; 109 | } 110 | 111 | // Build the engine 112 | builder->setMaxBatchSize(BATCH_SIZE); 113 | config->setMaxWorkspaceSize(1_GiB); 114 | config->setFlag(nvinfer1::BuilderFlag::kFP16); 115 | 116 | std::cout << "start building engine" << std::endl; 117 | engine = builder->buildEngineWithConfig(*network, *config); 118 | std::cout << "build engine done" << std::endl; 119 | assert(engine); 120 | 121 | // we can destroy the parser 122 | parser->destroy(); 123 | // save engine 124 | nvinfer1::IHostMemory *data = engine->serialize(); 125 | std::ofstream file; 126 | file.open(filename, std::ios::binary | std::ios::out); 127 | std::cout << "writing engine file..." << std::endl; 128 | file.write((const char *)data->data(), data->size()); 129 | std::cout << "save engine file done" << std::endl; 130 | file.close(); 131 | 132 | // then close everything down 133 | network->destroy(); 134 | builder->destroy(); 135 | } 136 | 137 | #endif 138 | -------------------------------------------------------------------------------- /src/common/logging.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef TENSORRT_LOGGING_H 18 | #define TENSORRT_LOGGING_H 19 | 20 | #include "NvInferRuntimeCommon.h" 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | using Severity = nvinfer1::ILogger::Severity; 30 | 31 | class LogStreamConsumerBuffer : public std::stringbuf 32 | { 33 | public: 34 | LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) 35 | : mOutput(stream) 36 | , mPrefix(prefix) 37 | , mShouldLog(shouldLog) 38 | { 39 | } 40 | 41 | LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) 42 | : mOutput(other.mOutput) 43 | { 44 | } 45 | 46 | ~LogStreamConsumerBuffer() 47 | { 48 | // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence 49 | // std::streambuf::pptr() gives a pointer to the current position of the output sequence 50 | // if the pointer to the beginning is not equal to the pointer to the current position, 51 | // call putOutput() to log the output to the stream 52 | if (pbase() != pptr()) 53 | { 54 | putOutput(); 55 | } 56 | } 57 | 58 | // synchronizes the stream buffer and returns 0 on success 59 | // synchronizing the stream buffer consists of inserting the buffer contents into the stream, 60 | // resetting the buffer and flushing the stream 61 | virtual int sync() 62 | { 63 | putOutput(); 64 | return 0; 65 | } 66 | 67 | void putOutput() 68 | { 69 | if (mShouldLog) 70 | { 71 | // prepend timestamp 72 | std::time_t timestamp = std::time(nullptr); 73 | tm* tm_local = std::localtime(×tamp); 74 | std::cout << "["; 75 | std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/"; 76 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/"; 77 | std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-"; 78 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":"; 79 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":"; 80 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] "; 81 | // std::stringbuf::str() gets the string contents of the buffer 82 | // insert the buffer contents pre-appended by the appropriate prefix into the stream 83 | mOutput << mPrefix << str(); 84 | // set the buffer to empty 85 | str(""); 86 | // flush the stream 87 | mOutput.flush(); 88 | } 89 | } 90 | 91 | void setShouldLog(bool shouldLog) 92 | { 93 | mShouldLog = shouldLog; 94 | } 95 | 96 | private: 97 | std::ostream& mOutput; 98 | std::string mPrefix; 99 | bool mShouldLog; 100 | }; 101 | 102 | //! 103 | //! \class LogStreamConsumerBase 104 | //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer 105 | //! 106 | class LogStreamConsumerBase 107 | { 108 | public: 109 | LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) 110 | : mBuffer(stream, prefix, shouldLog) 111 | { 112 | } 113 | 114 | protected: 115 | LogStreamConsumerBuffer mBuffer; 116 | }; 117 | 118 | //! 119 | //! \class LogStreamConsumer 120 | //! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages. 121 | //! Order of base classes is LogStreamConsumerBase and then std::ostream. 122 | //! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field 123 | //! in LogStreamConsumer and then the address of the buffer is passed to std::ostream. 124 | //! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. 125 | //! Please do not change the order of the parent classes. 126 | //! 127 | class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream 128 | { 129 | public: 130 | //! \brief Creates a LogStreamConsumer which logs messages with level severity. 131 | //! Reportable severity determines if the messages are severe enough to be logged. 132 | LogStreamConsumer(Severity reportableSeverity, Severity severity) 133 | : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) 134 | , std::ostream(&mBuffer) // links the stream buffer with the stream 135 | , mShouldLog(severity <= reportableSeverity) 136 | , mSeverity(severity) 137 | { 138 | } 139 | 140 | LogStreamConsumer(LogStreamConsumer&& other) 141 | : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) 142 | , std::ostream(&mBuffer) // links the stream buffer with the stream 143 | , mShouldLog(other.mShouldLog) 144 | , mSeverity(other.mSeverity) 145 | { 146 | } 147 | 148 | void setReportableSeverity(Severity reportableSeverity) 149 | { 150 | mShouldLog = mSeverity <= reportableSeverity; 151 | mBuffer.setShouldLog(mShouldLog); 152 | } 153 | 154 | private: 155 | static std::ostream& severityOstream(Severity severity) 156 | { 157 | return severity >= Severity::kINFO ? std::cout : std::cerr; 158 | } 159 | 160 | static std::string severityPrefix(Severity severity) 161 | { 162 | switch (severity) 163 | { 164 | case Severity::kINTERNAL_ERROR: return "[F] "; 165 | case Severity::kERROR: return "[E] "; 166 | case Severity::kWARNING: return "[W] "; 167 | case Severity::kINFO: return "[I] "; 168 | case Severity::kVERBOSE: return "[V] "; 169 | default: assert(0); return ""; 170 | } 171 | } 172 | 173 | bool mShouldLog; 174 | Severity mSeverity; 175 | }; 176 | 177 | //! \class Logger 178 | //! 179 | //! \brief Class which manages logging of TensorRT tools and samples 180 | //! 181 | //! \details This class provides a common interface for TensorRT tools and samples to log information to the console, 182 | //! and supports logging two types of messages: 183 | //! 184 | //! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) 185 | //! - Test pass/fail messages 186 | //! 187 | //! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is 188 | //! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. 189 | //! 190 | //! In the future, this class could be extended to support dumping test results to a file in some standard format 191 | //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). 192 | //! 193 | //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger 194 | //! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT 195 | //! library and messages coming from the sample. 196 | //! 197 | //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the 198 | //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger 199 | //! object. 200 | 201 | class Logger : public nvinfer1::ILogger 202 | { 203 | public: 204 | Logger(Severity severity = Severity::kWARNING) 205 | : mReportableSeverity(severity) 206 | { 207 | } 208 | 209 | //! 210 | //! \enum TestResult 211 | //! \brief Represents the state of a given test 212 | //! 213 | enum class TestResult 214 | { 215 | kRUNNING, //!< The test is running 216 | kPASSED, //!< The test passed 217 | kFAILED, //!< The test failed 218 | kWAIVED //!< The test was waived 219 | }; 220 | 221 | //! 222 | //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger 223 | //! \return The nvinfer1::ILogger associated with this Logger 224 | //! 225 | //! TODO Once all samples are updated to use this method to register the logger with TensorRT, 226 | //! we can eliminate the inheritance of Logger from ILogger 227 | //! 228 | nvinfer1::ILogger& getTRTLogger() 229 | { 230 | return *this; 231 | } 232 | 233 | //! 234 | //! \brief Implementation of the nvinfer1::ILogger::log() virtual method 235 | //! 236 | //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the 237 | //! inheritance from nvinfer1::ILogger 238 | //! 239 | void log(Severity severity, const char* msg) noexcept 240 | { 241 | LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; 242 | } 243 | 244 | //! 245 | //! \brief Method for controlling the verbosity of logging output 246 | //! 247 | //! \param severity The logger will only emit messages that have severity of this level or higher. 248 | //! 249 | void setReportableSeverity(Severity severity) 250 | { 251 | mReportableSeverity = severity; 252 | } 253 | 254 | //! 255 | //! \brief Opaque handle that holds logging information for a particular test 256 | //! 257 | //! This object is an opaque handle to information used by the Logger to print test results. 258 | //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used 259 | //! with Logger::reportTest{Start,End}(). 260 | //! 261 | class TestAtom 262 | { 263 | public: 264 | TestAtom(TestAtom&&) = default; 265 | 266 | private: 267 | friend class Logger; 268 | 269 | TestAtom(bool started, const std::string& name, const std::string& cmdline) 270 | : mStarted(started) 271 | , mName(name) 272 | , mCmdline(cmdline) 273 | { 274 | } 275 | 276 | bool mStarted; 277 | std::string mName; 278 | std::string mCmdline; 279 | }; 280 | 281 | //! 282 | //! \brief Define a test for logging 283 | //! 284 | //! \param[in] name The name of the test. This should be a string starting with 285 | //! "TensorRT" and containing dot-separated strings containing 286 | //! the characters [A-Za-z0-9_]. 287 | //! For example, "TensorRT.sample_googlenet" 288 | //! \param[in] cmdline The command line used to reproduce the test 289 | // 290 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 291 | //! 292 | static TestAtom defineTest(const std::string& name, const std::string& cmdline) 293 | { 294 | return TestAtom(false, name, cmdline); 295 | } 296 | 297 | //! 298 | //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments 299 | //! as input 300 | //! 301 | //! \param[in] name The name of the test 302 | //! \param[in] argc The number of command-line arguments 303 | //! \param[in] argv The array of command-line arguments (given as C strings) 304 | //! 305 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 306 | static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) 307 | { 308 | auto cmdline = genCmdlineString(argc, argv); 309 | return defineTest(name, cmdline); 310 | } 311 | 312 | //! 313 | //! \brief Report that a test has started. 314 | //! 315 | //! \pre reportTestStart() has not been called yet for the given testAtom 316 | //! 317 | //! \param[in] testAtom The handle to the test that has started 318 | //! 319 | static void reportTestStart(TestAtom& testAtom) 320 | { 321 | reportTestResult(testAtom, TestResult::kRUNNING); 322 | assert(!testAtom.mStarted); 323 | testAtom.mStarted = true; 324 | } 325 | 326 | //! 327 | //! \brief Report that a test has ended. 328 | //! 329 | //! \pre reportTestStart() has been called for the given testAtom 330 | //! 331 | //! \param[in] testAtom The handle to the test that has ended 332 | //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, 333 | //! TestResult::kFAILED, TestResult::kWAIVED 334 | //! 335 | static void reportTestEnd(const TestAtom& testAtom, TestResult result) 336 | { 337 | assert(result != TestResult::kRUNNING); 338 | assert(testAtom.mStarted); 339 | reportTestResult(testAtom, result); 340 | } 341 | 342 | static int reportPass(const TestAtom& testAtom) 343 | { 344 | reportTestEnd(testAtom, TestResult::kPASSED); 345 | return EXIT_SUCCESS; 346 | } 347 | 348 | static int reportFail(const TestAtom& testAtom) 349 | { 350 | reportTestEnd(testAtom, TestResult::kFAILED); 351 | return EXIT_FAILURE; 352 | } 353 | 354 | static int reportWaive(const TestAtom& testAtom) 355 | { 356 | reportTestEnd(testAtom, TestResult::kWAIVED); 357 | return EXIT_SUCCESS; 358 | } 359 | 360 | static int reportTest(const TestAtom& testAtom, bool pass) 361 | { 362 | return pass ? reportPass(testAtom) : reportFail(testAtom); 363 | } 364 | 365 | Severity getReportableSeverity() const 366 | { 367 | return mReportableSeverity; 368 | } 369 | 370 | private: 371 | //! 372 | //! \brief returns an appropriate string for prefixing a log message with the given severity 373 | //! 374 | static const char* severityPrefix(Severity severity) 375 | { 376 | switch (severity) 377 | { 378 | case Severity::kINTERNAL_ERROR: return "[F] "; 379 | case Severity::kERROR: return "[E] "; 380 | case Severity::kWARNING: return "[W] "; 381 | case Severity::kINFO: return "[I] "; 382 | case Severity::kVERBOSE: return "[V] "; 383 | default: assert(0); return ""; 384 | } 385 | } 386 | 387 | //! 388 | //! \brief returns an appropriate string for prefixing a test result message with the given result 389 | //! 390 | static const char* testResultString(TestResult result) 391 | { 392 | switch (result) 393 | { 394 | case TestResult::kRUNNING: return "RUNNING"; 395 | case TestResult::kPASSED: return "PASSED"; 396 | case TestResult::kFAILED: return "FAILED"; 397 | case TestResult::kWAIVED: return "WAIVED"; 398 | default: assert(0); return ""; 399 | } 400 | } 401 | 402 | //! 403 | //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity 404 | //! 405 | static std::ostream& severityOstream(Severity severity) 406 | { 407 | return severity >= Severity::kINFO ? std::cout : std::cerr; 408 | } 409 | 410 | //! 411 | //! \brief method that implements logging test results 412 | //! 413 | static void reportTestResult(const TestAtom& testAtom, TestResult result) 414 | { 415 | severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " 416 | << testAtom.mCmdline << std::endl; 417 | } 418 | 419 | //! 420 | //! \brief generate a command line string from the given (argc, argv) values 421 | //! 422 | static std::string genCmdlineString(int argc, char const* const* argv) 423 | { 424 | std::stringstream ss; 425 | for (int i = 0; i < argc; i++) 426 | { 427 | if (i > 0) 428 | ss << " "; 429 | ss << argv[i]; 430 | } 431 | return ss.str(); 432 | } 433 | 434 | Severity mReportableSeverity; 435 | }; 436 | 437 | namespace 438 | { 439 | 440 | //! 441 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE 442 | //! 443 | //! Example usage: 444 | //! 445 | //! LOG_VERBOSE(logger) << "hello world" << std::endl; 446 | //! 447 | inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) 448 | { 449 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); 450 | } 451 | 452 | //! 453 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO 454 | //! 455 | //! Example usage: 456 | //! 457 | //! LOG_INFO(logger) << "hello world" << std::endl; 458 | //! 459 | inline LogStreamConsumer LOG_INFO(const Logger& logger) 460 | { 461 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); 462 | } 463 | 464 | //! 465 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING 466 | //! 467 | //! Example usage: 468 | //! 469 | //! LOG_WARN(logger) << "hello world" << std::endl; 470 | //! 471 | inline LogStreamConsumer LOG_WARN(const Logger& logger) 472 | { 473 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); 474 | } 475 | 476 | //! 477 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR 478 | //! 479 | //! Example usage: 480 | //! 481 | //! LOG_ERROR(logger) << "hello world" << std::endl; 482 | //! 483 | inline LogStreamConsumer LOG_ERROR(const Logger& logger) 484 | { 485 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); 486 | } 487 | 488 | //! 489 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR 490 | // ("fatal" severity) 491 | //! 492 | //! Example usage: 493 | //! 494 | //! LOG_FATAL(logger) << "hello world" << std::endl; 495 | //! 496 | inline LogStreamConsumer LOG_FATAL(const Logger& logger) 497 | { 498 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); 499 | } 500 | 501 | } // anonymous namespace 502 | 503 | #endif // TENSORRT_LOGGING_H 504 | -------------------------------------------------------------------------------- /src/common_struct.hpp: -------------------------------------------------------------------------------- 1 | #ifndef COMMON_STRUCT_HPP 2 | #define COMMON_STRUCT_HPP 3 | 4 | struct DetectRect 5 | { 6 | float classId; 7 | float score; 8 | float xmin; 9 | float ymin; 10 | float xmax; 11 | float ymax; 12 | }; 13 | 14 | #endif 15 | -------------------------------------------------------------------------------- /src/kernels/.ipynb_checkpoints/get_nms_before_boxes-checkpoint.cu: -------------------------------------------------------------------------------- 1 | #include "get_nms_before_boxes.cuh" 2 | 3 | 4 | __global__ void GetNmsBeforeBoxesKernel(float *SrcInput, int AnchorCount, int ClassNum, float ObjectThresh, int NmsBeforeMaxNum, DetectRect* OutputRects, int *OutputCount) 5 | { 6 | int ThreadId = blockIdx.x * blockDim.x + threadIdx.x; 7 | 8 | if (ThreadId >= AnchorCount) 9 | { 10 | return; 11 | } 12 | 13 | float* XywhConf = SrcInput + ThreadId; 14 | float CenterX = 0, CenterY = 0, CenterW = 0, CenterH = 0; 15 | 16 | float MaxScore = 0; 17 | int MaxIndex = 0; 18 | 19 | DetectRect TempRect; 20 | for (int j = 4; j < ClassNum + 4; j ++) 21 | { 22 | if (4 == j) 23 | { 24 | MaxScore = XywhConf[j * AnchorCount]; 25 | MaxIndex = j; 26 | } 27 | else 28 | { 29 | if (MaxScore < XywhConf[j * AnchorCount]) 30 | { 31 | MaxScore = XywhConf[j * AnchorCount]; 32 | MaxIndex = j; 33 | } 34 | } 35 | } 36 | 37 | if (MaxScore > ObjectThresh) 38 | { 39 | int index = atomicAdd(OutputCount, 1); 40 | 41 | if (index > NmsBeforeMaxNum) 42 | { 43 | return; 44 | } 45 | 46 | CenterX = XywhConf[0 * AnchorCount]; 47 | CenterY = XywhConf[1 * AnchorCount]; 48 | CenterW = XywhConf[2 * AnchorCount]; 49 | CenterH = XywhConf[3 * AnchorCount ]; 50 | 51 | TempRect.classId = MaxIndex - 4; 52 | TempRect.score = MaxScore; 53 | TempRect.xmin = CenterX - 0.5 * CenterW; 54 | TempRect.ymin = CenterY - 0.5 * CenterH; 55 | TempRect.xmax = CenterX + 0.5 * CenterW; 56 | TempRect.ymax = CenterY + 0.5 * CenterH; 57 | 58 | OutputRects[index] = TempRect; 59 | } 60 | } 61 | 62 | 63 | void GetNmsBeforeBoxes(float *SrcInput, int AnchorCount, int ClassNum, float ObjectThresh, int NmsBeforeMaxNum, DetectRect* OutputRects, int *OutputCount, cudaStream_t Stream) 64 | { 65 | int Block = 512; 66 | int Grid = (AnchorCount + Block - 1) / Block; 67 | 68 | GetNmsBeforeBoxesKernel<<>>(SrcInput, AnchorCount, ClassNum, ObjectThresh, NmsBeforeMaxNum, OutputRects, OutputCount); 69 | return; 70 | } 71 | 72 | 73 | -------------------------------------------------------------------------------- /src/kernels/.ipynb_checkpoints/get_nms_before_boxes-checkpoint.cuh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov11_tensorRT_postprocess_cuda/7c9d3e4fd4b50978fe68756d77d0c568893279c3/src/kernels/.ipynb_checkpoints/get_nms_before_boxes-checkpoint.cuh -------------------------------------------------------------------------------- /src/kernels/get_nms_before_boxes.cu: -------------------------------------------------------------------------------- 1 | #include "get_nms_before_boxes.cuh" 2 | 3 | 4 | __global__ void GetNmsBeforeBoxesKernel(float *SrcInput, int AnchorCount, int ClassNum, float ObjectThresh, int NmsBeforeMaxNum, DetectRect* OutputRects, int *OutputCount) 5 | { 6 | int ThreadId = blockIdx.x * blockDim.x + threadIdx.x; 7 | 8 | if (ThreadId >= AnchorCount) 9 | { 10 | return; 11 | } 12 | 13 | float* XywhConf = SrcInput + ThreadId; 14 | float CenterX = 0, CenterY = 0, CenterW = 0, CenterH = 0; 15 | 16 | float MaxScore = 0; 17 | int MaxIndex = 0; 18 | 19 | DetectRect TempRect; 20 | for (int j = 4; j < ClassNum + 4; j ++) 21 | { 22 | if (4 == j) 23 | { 24 | MaxScore = XywhConf[j * AnchorCount]; 25 | MaxIndex = j; 26 | } 27 | else 28 | { 29 | if (MaxScore < XywhConf[j * AnchorCount]) 30 | { 31 | MaxScore = XywhConf[j * AnchorCount]; 32 | MaxIndex = j; 33 | } 34 | } 35 | } 36 | 37 | if (MaxScore > ObjectThresh) 38 | { 39 | int index = atomicAdd(OutputCount, 1); 40 | 41 | if (index > NmsBeforeMaxNum) 42 | { 43 | return; 44 | } 45 | 46 | CenterX = XywhConf[0 * AnchorCount]; 47 | CenterY = XywhConf[1 * AnchorCount]; 48 | CenterW = XywhConf[2 * AnchorCount]; 49 | CenterH = XywhConf[3 * AnchorCount ]; 50 | 51 | TempRect.classId = MaxIndex - 4; 52 | TempRect.score = MaxScore; 53 | TempRect.xmin = CenterX - 0.5 * CenterW; 54 | TempRect.ymin = CenterY - 0.5 * CenterH; 55 | TempRect.xmax = CenterX + 0.5 * CenterW; 56 | TempRect.ymax = CenterY + 0.5 * CenterH; 57 | 58 | OutputRects[index] = TempRect; 59 | } 60 | } 61 | 62 | 63 | void GetNmsBeforeBoxes(float *SrcInput, int AnchorCount, int ClassNum, float ObjectThresh, int NmsBeforeMaxNum, DetectRect* OutputRects, int *OutputCount, cudaStream_t Stream) 64 | { 65 | int Block = 512; 66 | int Grid = (AnchorCount + Block - 1) / Block; 67 | 68 | GetNmsBeforeBoxesKernel<<>>(SrcInput, AnchorCount, ClassNum, ObjectThresh, NmsBeforeMaxNum, OutputRects, OutputCount); 69 | return; 70 | } 71 | 72 | 73 | -------------------------------------------------------------------------------- /src/kernels/get_nms_before_boxes.cuh: -------------------------------------------------------------------------------- 1 | #ifndef GET_NMS_BEFORE_BOXES_CUH__ 2 | #define GET_NMS_BEFORE_BOXES_CUH__ 3 | 4 | #include 5 | #include "../common_struct.hpp" 6 | 7 | void GetNmsBeforeBoxes(float *SrcInput, int AnchorCount, int ClassNum, float ObjectThresh, int NMSBeforeMaxNum, DetectRect* OutputRects, int *OutputCount, cudaStream_t Stream); 8 | 9 | 10 | #endif 11 | 12 | -------------------------------------------------------------------------------- /src/postprocess_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include "postprocess_cuda.hpp" 2 | #include 3 | #include 4 | 5 | #define ZQ_MAX(a, b) ((a) > (b) ? (a) : (b)) 6 | #define ZQ_MIN(a, b) ((a) < (b) ? (a) : (b)) 7 | 8 | 9 | static inline float IOU(float XMin1, float YMin1, float XMax1, float YMax1, float XMin2, float YMin2, float XMax2, float YMax2) 10 | { 11 | float Inter = 0; 12 | float Total = 0; 13 | float XMin = 0; 14 | float YMin = 0; 15 | float XMax = 0; 16 | float YMax = 0; 17 | float Area1 = 0; 18 | float Area2 = 0; 19 | float InterWidth = 0; 20 | float InterHeight = 0; 21 | 22 | XMin = ZQ_MAX(XMin1, XMin2); 23 | YMin = ZQ_MAX(YMin1, YMin2); 24 | XMax = ZQ_MIN(XMax1, XMax2); 25 | YMax = ZQ_MIN(YMax1, YMax2); 26 | 27 | InterWidth = XMax - XMin; 28 | InterHeight = YMax - YMin; 29 | 30 | InterWidth = (InterWidth >= 0) ? InterWidth : 0; 31 | InterHeight = (InterHeight >= 0) ? InterHeight : 0; 32 | 33 | Inter = InterWidth * InterHeight; 34 | 35 | Area1 = (XMax1 - XMin1) * (YMax1 - YMin1); 36 | Area2 = (XMax2 - XMin2) * (YMax2 - YMin2); 37 | 38 | Total = Area1 + Area2 - Inter; 39 | 40 | return float(Inter) / float(Total); 41 | } 42 | 43 | /****** yolov11 ****/ 44 | GetResultRectYolov11::GetResultRectYolov11() 45 | { 46 | CoordIndex = MapSize[0][0] * MapSize[0][1] + MapSize[1][0] * MapSize[1][1] + MapSize[2][0] * MapSize[2][1]; 47 | } 48 | 49 | GetResultRectYolov11::~GetResultRectYolov11() 50 | { 51 | } 52 | 53 | 54 | int GetResultRectYolov11::GetConvDetectionResult(DetectRect *OutputRects, int *OutputCount, std::vector &DetectiontRects) 55 | { 56 | int ret = 0; 57 | std::vector detectRects; 58 | float xmin = 0, ymin = 0, xmax = 0, ymax = 0; 59 | 60 | DetectRect temp; 61 | for (int i = 0; i < *OutputCount; i ++) 62 | { 63 | xmin = OutputRects[i].xmin; 64 | ymin = OutputRects[i].ymin; 65 | xmax = OutputRects[i].xmax; 66 | ymax = OutputRects[i].ymax; 67 | 68 | xmin = xmin > 0 ? xmin : 0; 69 | ymin = ymin > 0 ? ymin : 0; 70 | xmax = xmax < InputW ? xmax : InputW; 71 | ymax = ymax < InputH ? ymax : InputH; 72 | 73 | temp.xmin = xmin / InputW; 74 | temp.ymin = ymin / InputH; 75 | temp.xmax = xmax / InputW; 76 | temp.ymax = ymax / InputH; 77 | temp.classId = OutputRects[i].classId; 78 | temp.score = OutputRects[i].score; 79 | detectRects.push_back(temp); 80 | } 81 | 82 | std::sort(detectRects.begin(), detectRects.end(), [](DetectRect &Rect1, DetectRect &Rect2) -> bool 83 | { return (Rect1.score > Rect2.score); }); 84 | 85 | // std::cout << "NMS Before num :" << detectRects.size() << std::endl; 86 | for (int i = 0; i < detectRects.size(); ++i) 87 | { 88 | float xmin1 = detectRects[i].xmin; 89 | float ymin1 = detectRects[i].ymin; 90 | float xmax1 = detectRects[i].xmax; 91 | float ymax1 = detectRects[i].ymax; 92 | int classId = detectRects[i].classId; 93 | float score = detectRects[i].score; 94 | 95 | if (classId != -1) 96 | { 97 | DetectiontRects.push_back(float(classId)); 98 | DetectiontRects.push_back(float(score)); 99 | DetectiontRects.push_back(float(xmin1)); 100 | DetectiontRects.push_back(float(ymin1)); 101 | DetectiontRects.push_back(float(xmax1)); 102 | DetectiontRects.push_back(float(ymax1)); 103 | 104 | for (int j = i + 1; j < detectRects.size(); ++j) 105 | { 106 | float xmin2 = detectRects[j].xmin; 107 | float ymin2 = detectRects[j].ymin; 108 | float xmax2 = detectRects[j].xmax; 109 | float ymax2 = detectRects[j].ymax; 110 | float iou = IOU(xmin1, ymin1, xmax1, ymax1, xmin2, ymin2, xmax2, ymax2); 111 | if (iou > NmsThresh) 112 | { 113 | detectRects[j].classId = -1; 114 | } 115 | } 116 | } 117 | } 118 | 119 | return ret; 120 | } 121 | -------------------------------------------------------------------------------- /src/postprocess_cuda.hpp: -------------------------------------------------------------------------------- 1 | #ifndef _POSTPROCESS_CUDA_HPP_ 2 | #define _POSTPROCESS_CUDA_HPP_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include "common_struct.hpp" 9 | 10 | typedef signed char int8_t; 11 | typedef unsigned int uint32_t; 12 | 13 | 14 | // yolov11 15 | class GetResultRectYolov11 16 | { 17 | public: 18 | GetResultRectYolov11(); 19 | 20 | ~GetResultRectYolov11(); 21 | 22 | int GetConvDetectionResult(DetectRect *OutputRects, int *OutputCount, std::vector &DetectiontRects); 23 | 24 | 25 | public: 26 | 27 | const int ClassNum = 80; 28 | 29 | int InputW = 640; 30 | int InputH = 640; 31 | 32 | int MapSize[3][2] = {{80, 80}, {40, 40}, {20, 20}}; 33 | int CoordIndex = 0; 34 | 35 | float NmsThresh = 0.45; 36 | float ObjectThresh = 0.5; 37 | }; 38 | 39 | #endif 40 | --------------------------------------------------------------------------------