# Build script for the yolo_trt demo: a YOLOv8 detector running on TensorRT,
# with CUDA/NPP pre-processing and OpenCV for image I/O.
cmake_minimum_required(VERSION 3.5)

project(yolo_trt)

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# CUDA
find_package(CUDA REQUIRED)
message(STATUS "Find CUDA include at ${CUDA_INCLUDE_DIRS}")
message(STATUS "Find CUDA libraries: ${CUDA_LIBRARIES}")

# TensorRT
# The default below is the original author's install location. It is a cache
# variable so it can be overridden without editing this file:
#   cmake -DTENSORRT_ROOT=/path/to/TensorRT ..
set(TENSORRT_ROOT /home/zq/Detect/TensorRT-8.6.1.6 CACHE PATH "Root directory of the TensorRT installation")

find_path(TENSORRT_INCLUDE_DIR NvInfer.h
        HINTS ${TENSORRT_ROOT} PATH_SUFFIXES include)
message(STATUS "Found TensorRT headers at ${TENSORRT_INCLUDE_DIR}")
find_library(TENSORRT_LIBRARY_INFER nvinfer
        HINTS ${TENSORRT_ROOT} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR}
        PATH_SUFFIXES lib lib64 lib/x64)
find_library(TENSORRT_LIBRARY_ONNXPARSER nvonnxparser
        HINTS ${TENSORRT_ROOT} ${TENSORRT_BUILD} ${CUDA_TOOLKIT_ROOT_DIR}
        PATH_SUFFIXES lib lib64 lib/x64)

# Fail at configure time instead of with an obscure compile/link error later.
if(NOT TENSORRT_INCLUDE_DIR OR NOT TENSORRT_LIBRARY_INFER OR NOT TENSORRT_LIBRARY_ONNXPARSER)
    message(FATAL_ERROR "TensorRT was not found under '${TENSORRT_ROOT}'; set TENSORRT_ROOT to your TensorRT installation")
endif()

set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_ONNXPARSER})
message(STATUS "Find TensorRT libs: ${TENSORRT_LIBRARY}")

# OpenCV
find_package(OpenCV REQUIRED)
message(STATUS "Find OpenCV include at ${OpenCV_INCLUDE_DIRS}")
message(STATUS "Find OpenCV libraries: ${OpenCV_LIBRARIES}")

# The shared headers (common.hpp, logging.h) live in src/common, not ./common.
set(COMMON_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/src/common)

include_directories(${CUDA_INCLUDE_DIRS} ${TENSORRT_INCLUDE_DIR} ${OpenCV_INCLUDE_DIRS} ${COMMON_INCLUDE})

add_executable(yolo_trt main.cpp src/CNN.cpp src/postprocess.cpp)
target_link_libraries(yolo_trt ${OpenCV_LIBRARIES} ${CUDA_LIBRARIES} ${TENSORRT_LIBRARY} ${CUDA_npp_LIBRARY} ${CUDA_cublas_LIBRARY})
-------------------------------------------------------------------------------- 1 | # yolov8 tensorRT 的 C++ 部署 2 | 3 | 本示例中,包含完整的代码、模型、测试图片、测试结果。 4 | 5 | TensorRT版本:TensorRT-7.1.3.4 6 | 7 | ## 导出onnx模型 8 | 9 | 导出适配本实例的onnx模型参考[【yolov8 导出onnx-2023年11月15日版本】](https://blog.csdn.net/zhangqian_1/article/details/134438275)。 10 | 11 | 12 | ## 编译 13 | 14 | 修改 CMakeLists.txt 对应的TensorRT位置 15 | 16 | ![17012439652981](https://github.com/cqu20160901/yolov8_tensorRT_Cplusplus/assets/22290931/68586322-556a-42ef-8444-1064b42c86f9) 17 | 18 | ```powershell 19 | cd yolov8_tensorRT_Cplusplus 20 | mkdir build 21 | cd build 22 | cmake .. 23 | make 24 | ``` 25 | 26 | ## 运行 27 | 28 | ```powershell 29 | # 运行时如果.trt模型存在则直接加载,若不存在则会自动先将onnx转换成 trt 模型,并保存在给定的位置,然后运行推理。 30 | cd build 31 | ./yolo_trt 32 | ``` 33 | 34 | ## 测试效果 35 | 36 | onnx 测试效果 37 | 38 | ![image](https://github.com/cqu20160901/yolov8_tensorRT_Cplusplus/assets/22290931/8574c0ce-fc56-4b3c-9c7e-ec31e29b01ed) 39 | 40 | tensorRT 测试效果 41 | 42 | ![result](https://github.com/cqu20160901/yolov8_tensorRT_Cplusplus/assets/22290931/29a8115d-a5ce-4c58-9b1a-c48766cdfcd5) 43 | 44 | tensorRT 时耗 45 | 46 | ![image](https://github.com/user-attachments/assets/6490069e-4d8a-48f3-88e7-58eb70ae3abe) 47 | 48 | 49 | 50 | ## 替换模型说明 51 | 52 | 1)按照本实例给的导出onnx方式导出对应的onnx;导出的onnx模型建议simplify后再转trt模型。 53 | 54 | 2)注意修改后处理相关 postprocess.hpp 中相关的参数(类别、输入分辨率等)。 55 | 56 | 修改相关的路径 57 | ```cpp 58 | std::string OnnxFile = "/zhangqian/workspaces1/TensorRT/yolov8_trt_Cplusplus/models/yolov8n_ZQ.onnx"; 59 | std::string SaveTrtFilePath = "/zhangqian/workspaces1/TensorRT/yolov8_trt_Cplusplus/models/yolov8n_ZQ.trt"; 60 | cv::Mat SrcImage = cv::imread("/zhangqian/workspaces1/TensorRT/yolov8_trt_Cplusplus/images/test.jpg"); 61 | 62 | int img_width = SrcImage.cols; 63 | int img_height = SrcImage.rows; 64 | 65 | CNN YOLO(OnnxFile, SaveTrtFilePath, 1, 3, 640, 640); 66 | YOLO.Inference(SrcImage); 67 | 68 | for (int i = 0; i < YOLO.DetectiontRects_.size(); i += 6) 69 | { 70 | 
int classId = int(YOLO.DetectiontRects_[i + 0]); 71 | float conf = YOLO.DetectiontRects_[i + 1]; 72 | int xmin = int(YOLO.DetectiontRects_[i + 2] * float(img_width) + 0.5); 73 | int ymin = int(YOLO.DetectiontRects_[i + 3] * float(img_height) + 0.5); 74 | int xmax = int(YOLO.DetectiontRects_[i + 4] * float(img_width) + 0.5); 75 | int ymax = int(YOLO.DetectiontRects_[i + 5] * float(img_height) + 0.5); 76 | 77 | char text1[256]; 78 | sprintf(text1, "%d:%.2f", classId, conf); 79 | rectangle(SrcImage, cv::Point(xmin, ymin), cv::Point(xmax, ymax), cv::Scalar(255, 0, 0), 2); 80 | putText(SrcImage, text1, cv::Point(xmin, ymin + 15), cv::FONT_HERSHEY_SIMPLEX, 0.7, cv::Scalar(0, 0, 255), 2); 81 | } 82 | 83 | imwrite("/zhangqian/workspaces1/TensorRT/yolov8_trt_Cplusplus/images/result.jpg", SrcImage); 84 | 85 | printf("== obj: %d \n", int(float(YOLO.DetectiontRects_.size()) / 6.0)); 86 | 87 | ``` 88 | 89 | ## 特别说明 90 | 91 | 本示例只是用来测试流程,模型效果并不保证,且代码整理的布局合理性没有做过多的考虑。 92 | 93 | ## 相关链接 94 | 95 | [yolov8 瑞芯微 RKNN 的 C++部署](https://github.com/cqu20160901/yolov8n_onnx_tensorRT_rknn_horizon) 96 | 97 | [yolov8 瑞芯微RKNN和地平线Horizon芯片仿真测试部署-2023年11月15日版本](https://blog.csdn.net/zhangqian_1/article/details/134438275) 98 | 99 | [yolov8 瑞芯微RKNN和地平线Horizon芯片仿真测试部署](https://blog.csdn.net/zhangqian_1/article/details/128918268) 100 | 101 | 102 | ## 2024-10-06 103 | 104 | ### 1)预处理优化 105 | 106 | 原来为:用 opencv 进行预处理(resize-转rgb-转float-减均值除方差) 107 | 108 | 修改为:用 cuda 提供的 nppi 库进行预处理(resize-转rgb-转float-减均值除方差) 109 | 110 | 优化效果:10FPS情况下 CPU 占用减少 62% 111 | -------------------------------------------------------------------------------- /images/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov8_tensorRT_Cplusplus/ee8bac0498648c414e180ac9eae63b21c257cbdc/images/result.jpg -------------------------------------------------------------------------------- /images/test.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov8_tensorRT_Cplusplus/ee8bac0498648c414e180ac9eae63b21c257cbdc/images/test.jpg -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include "src/CNN.hpp" 2 | #include 3 | #include 4 | #include 5 | 6 | int main() 7 | { 8 | std::string OnnxFile = "/home/zq/Detect/yolov8_tensorRT_Cplusplus-master/models/yolov8n_ZQ.onnx"; 9 | std::string SaveTrtFilePath = "/home/zq/Detect/yolov8_tensorRT_Cplusplus-master/models/yolov8n_ZQ.trt"; 10 | cv::Mat SrcImage = cv::imread("/home/zq/Detect/yolov8_tensorRT_Cplusplus-master/images/test.jpg"); 11 | 12 | int img_width = SrcImage.cols; 13 | int img_height = SrcImage.rows; 14 | 15 | CNN YOLO(OnnxFile, SaveTrtFilePath, 1, 3, 640, 640); 16 | 17 | auto t_start = std::chrono::high_resolution_clock::now(); 18 | int Temp = 1000; 19 | 20 | for (int i = 0; i < Temp; i++) 21 | { 22 | YOLO.Inference(SrcImage); 23 | // std::this_thread::sleep_for(std::chrono::milliseconds(95)); 24 | } 25 | auto t_end = std::chrono::high_resolution_clock::now(); 26 | float total_inf = std::chrono::duration(t_end - t_start).count(); 27 | std::cout << "Info: " << Temp << " times ave cost: " << total_inf / float(Temp) << " ms." 
<< std::endl; 28 | 29 | 30 | 31 | for (int i = 0; i < YOLO.DetectiontRects_.size(); i += 6) 32 | { 33 | int classId = int(YOLO.DetectiontRects_[i + 0]); 34 | float conf = YOLO.DetectiontRects_[i + 1]; 35 | int xmin = int(YOLO.DetectiontRects_[i + 2] * float(img_width) + 0.5); 36 | int ymin = int(YOLO.DetectiontRects_[i + 3] * float(img_height) + 0.5); 37 | int xmax = int(YOLO.DetectiontRects_[i + 4] * float(img_width) + 0.5); 38 | int ymax = int(YOLO.DetectiontRects_[i + 5] * float(img_height) + 0.5); 39 | 40 | char text1[256]; 41 | sprintf(text1, "%d:%.2f", classId, conf); 42 | rectangle(SrcImage, cv::Point(xmin, ymin), cv::Point(xmax, ymax), cv::Scalar(255, 0, 0), 2); 43 | putText(SrcImage, text1, cv::Point(xmin, ymin + 15), cv::FONT_HERSHEY_SIMPLEX, 0.7, cv::Scalar(0, 0, 255), 2); 44 | } 45 | 46 | imwrite("/home/zq/Detect/yolov8_tensorRT_Cplusplus-master/images/result.jpg", SrcImage); 47 | 48 | printf("== obj: %d \n", int(float(YOLO.DetectiontRects_.size()) / 6.0)); 49 | 50 | return 0; 51 | } 52 | -------------------------------------------------------------------------------- /models/yolov8n_ZQ.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cqu20160901/yolov8_tensorRT_Cplusplus/ee8bac0498648c414e180ac9eae63b21c257cbdc/models/yolov8n_ZQ.onnx -------------------------------------------------------------------------------- /src/CNN.cpp: -------------------------------------------------------------------------------- 1 | #include "CNN.hpp" 2 | #include "common/common.hpp" 3 | #include 4 | #include 5 | 6 | 7 | CNN::CNN(const std::string &OnnxFilePath, const std::string &SaveTrtFilePath, int BatchSize, int InputChannel, int InputImageWidth, int InputImageHeight) 8 | { 9 | OnnxFilePath_ = OnnxFilePath; 10 | SaveTrtFilePath_ = SaveTrtFilePath; 11 | 12 | BatchSize_ = BatchSize; 13 | InputChannel_ = InputChannel; 14 | InputImageWidth_ = InputImageWidth; 15 | InputImageHeight_ = InputImageHeight; 16 
| 17 | ModelInit(); 18 | } 19 | 20 | CNN::~CNN() 21 | { 22 | // release the stream and the Buffers 23 | cudaStreamDestroy(Stream_); 24 | 25 | for (int i = 0; i < BuffersDataSize_.size(); i ++) 26 | { 27 | cudaFree(Buffers_[i]); 28 | if (i >= 1) 29 | { 30 | delete OutputData_[i-1]; 31 | } 32 | } 33 | 34 | // destroy the engine 35 | if (nullptr != PtrContext_) 36 | { 37 | PtrContext_->destroy(); 38 | } 39 | 40 | if (nullptr != PtrEngine_) 41 | { 42 | PtrEngine_->destroy(); 43 | } 44 | 45 | if (nullptr != GpuSrcImgBuf_) 46 | { 47 | cudaFree(GpuSrcImgBuf_); 48 | } 49 | 50 | if (nullptr != GpuImgResizeBuf_) 51 | { 52 | cudaFree(GpuImgResizeBuf_); 53 | } 54 | 55 | if (nullptr != GpuImgF32Buf_) 56 | { 57 | cudaFree(GpuImgF32Buf_); 58 | } 59 | 60 | if (nullptr != GpuDataPlanes_) 61 | { 62 | cudaFree(GpuDataPlanes_); 63 | } 64 | 65 | } 66 | 67 | void CNN::ModelInit() 68 | { 69 | std::fstream existEngine; 70 | existEngine.open(SaveTrtFilePath_, std::ios::in); 71 | if (existEngine) 72 | { 73 | ReadTrtFile(SaveTrtFilePath_, PtrEngine_); 74 | assert(PtrEngine_ != nullptr); 75 | } 76 | else 77 | { 78 | OnnxToTRTModel(OnnxFilePath_, SaveTrtFilePath_, PtrEngine_, BatchSize_); 79 | assert(PtrEngine_ != nullptr); 80 | } 81 | 82 | assert(PtrEngine_ != nullptr); 83 | PtrContext_ = PtrEngine_->createExecutionContext(); 84 | PtrContext_->setOptimizationProfile(0); 85 | auto InputDims = nvinfer1::Dims4 {BatchSize_, InputChannel_, InputImageHeight_, InputImageWidth_}; 86 | PtrContext_->setBindingDimensions(0, InputDims); 87 | 88 | cudaStreamCreate(&Stream_); 89 | 90 | int64_t TotalSize = 0; 91 | int nbBindings = PtrEngine_->getNbBindings(); 92 | BuffersDataSize_.resize(nbBindings); 93 | OutputData_.resize(nbBindings - 1); 94 | for (int i = 0; i < nbBindings; ++ i) 95 | { 96 | nvinfer1::Dims dims = PtrEngine_->getBindingDimensions(i); 97 | nvinfer1::DataType dtype = PtrEngine_->getBindingDataType(i); 98 | TotalSize = Volume(dims) * 1 * GetElementSize(dtype); 99 | 100 | BuffersDataSize_[i] 
= TotalSize; 101 | cudaMalloc(&Buffers_[i], TotalSize); 102 | if (i >= 1) 103 | { 104 | OutputData_[i - 1] = new float[int(TotalSize / sizeof(float))]; 105 | } 106 | 107 | if (0 == i) 108 | { 109 | std::cout << "input node name: "<< PtrEngine_->getBindingName(i) << ", dims: " << dims.nbDims << std::endl; 110 | } 111 | else 112 | { 113 | std::cout << "output node" << i - 1 << " name: "<< PtrEngine_->getBindingName(i) << ", dims: " << dims.nbDims << std::endl; 114 | } 115 | 116 | for (int j = 0; j < dims.nbDims; j++) 117 | { 118 | std::cout << "demension[" << j << "], size = " << dims.d[j] << std::endl; 119 | } 120 | } 121 | 122 | PreprocessResult_.resize(BatchSize_ * InputImageWidth_ * InputImageHeight_ * InputChannel_); 123 | } 124 | 125 | 126 | void CNN::Inference(cv::Mat &SrcImage) 127 | { 128 | DetectiontRects_.clear(); 129 | if(PtrContext_ == nullptr) 130 | { 131 | std::cout << "Error, PtrContext_" << std::endl; 132 | } 133 | 134 | // PrepareImage(SrcImage, PreprocessResult_); 135 | // cudaMemcpyAsync(Buffers_[0], PreprocessResult_.data(), BuffersDataSize_[0], cudaMemcpyHostToDevice, Stream_); 136 | PrepareImage(SrcImage, Buffers_[0]); 137 | 138 | PtrContext_->enqueueV2(Buffers_, Stream_, nullptr); 139 | 140 | for (int i = 1; i < BuffersDataSize_.size(); i++) 141 | { 142 | cudaMemcpyAsync(OutputData_[i - 1], Buffers_[i], BuffersDataSize_[i], cudaMemcpyDeviceToHost, Stream_); 143 | } 144 | 145 | cudaStreamSynchronize(Stream_); 146 | 147 | // Postprocess 148 | static GetResultRectYolov8 Postprocess; 149 | int ret = Postprocess.GetConvDetectionResult(OutputData_, DetectiontRects_); 150 | } 151 | 152 | void CNN::PrepareImage(cv::Mat &SrcImage, std::vector &PreprocessResult) 153 | { 154 | float *Imagedata = PreprocessResult.data(); 155 | 156 | cv::Mat rsz_img; 157 | cv::resize(SrcImage, rsz_img, cv::Size(InputImageWidth_, InputImageHeight_)); 158 | rsz_img.convertTo(rsz_img, CV_32FC3, 1.0 / 255); 159 | 160 | // HWC TO CHW 161 | int channelLength = InputImageWidth_ * 
InputImageHeight_; 162 | std::vector split_img = {cv::Mat(InputImageHeight_, InputImageWidth_, CV_32FC1, Imagedata + channelLength * 2), 163 | cv::Mat(InputImageHeight_, InputImageWidth_, CV_32FC1, Imagedata + channelLength * 1), 164 | cv::Mat(InputImageHeight_, InputImageWidth_, CV_32FC1, Imagedata + channelLength * 0)}; 165 | 166 | cv::split(rsz_img, split_img); 167 | } 168 | 169 | 170 | void CNN::PrepareImage(cv::Mat &SrcImage, void *InputBuffer) 171 | { 172 | int src_width = SrcImage.cols; 173 | int src_height = SrcImage.rows; 174 | int src_channel = SrcImage.channels(); 175 | 176 | NppiSize dstSize = {InputImageWidth_, InputImageHeight_}; 177 | NppiRect dstROI = {0, 0, InputImageWidth_, InputImageHeight_}; 178 | if (GpuImgResizeBuf_ == nullptr) 179 | { 180 | cudaMalloc(&GpuImgResizeBuf_, InputImageWidth_ * InputImageHeight_ * src_channel * sizeof(uchar)); 181 | cudaMalloc(&GpuImgF32Buf_, InputImageWidth_ * InputImageHeight_ * src_channel * sizeof(float)); 182 | cudaMalloc(&GpuDataPlanes_, InputImageWidth_ * InputImageHeight_ * src_channel * sizeof(float)); 183 | } 184 | 185 | 186 | NppiSize srcSize = {src_width, src_height}; 187 | NppiRect srcROI = {0, 0, src_width, src_height}; 188 | if(GpuSrcImgBuf_ == nullptr) 189 | { 190 | cudaMalloc(&GpuSrcImgBuf_, src_width * src_height * src_channel * sizeof(uchar)); 191 | } 192 | 193 | DstPlanes_[0] = GpuDataPlanes_; 194 | DstPlanes_[1] = GpuDataPlanes_ + InputImageWidth_ * InputImageHeight_; 195 | DstPlanes_[2] = GpuDataPlanes_ + InputImageWidth_ * InputImageHeight_ * 2; 196 | 197 | // 将cpu图像拷贝到gpu 198 | cudaMemcpy(GpuSrcImgBuf_, (void *)SrcImage.data, src_width * src_height * src_channel, cudaMemcpyHostToDevice); 199 | 200 | // resize 201 | nppiResize_8u_C3R(GpuSrcImgBuf_, src_width * src_channel, srcSize, srcROI, GpuImgResizeBuf_, InputImageWidth_ * src_channel, dstSize, dstROI, NPPI_INTER_LINEAR); 202 | 203 | // bgr 转 rgb 204 | nppiSwapChannels_8u_C3IR(GpuImgResizeBuf_, InputImageWidth_ * src_channel, dstSize, 
DstOrder_); 205 | 206 | // int8(uchar) 转 f32 207 | nppiConvert_8u32f_C3R(GpuImgResizeBuf_, InputImageWidth_ * src_channel, GpuImgF32Buf_, InputImageWidth_ * src_channel * sizeof(float), dstSize); 208 | 209 | // 减均值、除方差 210 | nppiMulC_32f_C3IR(MeanScale_, GpuImgF32Buf_, InputImageWidth_ * src_channel * sizeof(float), dstSize); 211 | 212 | nppiCopy_32f_C3P3R(GpuImgF32Buf_, InputImageWidth_ * src_channel * sizeof(float), DstPlanes_, InputImageWidth_ * sizeof(float), dstSize); 213 | 214 | cudaMemcpyAsync(InputBuffer, GpuDataPlanes_, src_channel * InputImageWidth_ * InputImageHeight_ * sizeof(float), cudaMemcpyDeviceToDevice, Stream_); 215 | } 216 | -------------------------------------------------------------------------------- /src/CNN.hpp: -------------------------------------------------------------------------------- 1 | #ifndef CNN_HPP 2 | #define CNN_HPP 3 | 4 | #include "NvInfer.h" 5 | #include "postprocess.hpp" 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | 14 | class CNN 15 | { 16 | public: 17 | CNN(const std::string &OnnxFilePath, const std::string &SaveTrtFilePath, int BatchSize, int InputChannel, int InputImageWidth, int InputImageHeight); 18 | ~CNN(); 19 | 20 | void Inference(cv::Mat &SrcImage); 21 | 22 | std::vector DetectiontRects_; 23 | 24 | private: 25 | void ModelInit(); 26 | void PrepareImage(cv::Mat &vec_img, std::vector &PreprocessResult); 27 | void PrepareImage(cv::Mat &vec_img, void *InputBuffer); 28 | 29 | std::string OnnxFilePath_; 30 | std::string SaveTrtFilePath_; 31 | 32 | int BatchSize_ = 0; 33 | int InputChannel_ = 0; 34 | int InputImageWidth_ = 0; 35 | int InputImageHeight_ = 0; 36 | int ModelOutputSize_ = 0; 37 | 38 | nvinfer1::ICudaEngine *PtrEngine_ = nullptr; 39 | nvinfer1::IExecutionContext *PtrContext_ = nullptr; 40 | cudaStream_t Stream_; 41 | 42 | void *Buffers_[10]; 43 | std::vector BuffersDataSize_; 44 | std::vector OutputData_; 45 | std::vector PreprocessResult_; 46 | 47 | Npp8u *GpuSrcImgBuf_ 
= nullptr; // gpu:装 src 图像 48 | Npp8u *GpuImgResizeBuf_ = nullptr; // gpu 装 resize后的图像 49 | Npp32f *GpuImgF32Buf_ = nullptr; // gpu: int8 转 F32 50 | Npp32f *GpuDataPlanes_ = nullptr; 51 | 52 | Npp32f MeanScale_[3] = {0.00392157, 0.00392157, 0.00392157}; 53 | int DstOrder_[3] = {2, 1, 0}; 54 | Npp32f* DstPlanes_[3]; 55 | }; 56 | 57 | #endif 58 | -------------------------------------------------------------------------------- /src/common/common.hpp: -------------------------------------------------------------------------------- 1 | #ifndef COMMON_HPP 2 | #define COMMON_HPP 3 | 4 | #include "NvOnnxParser.h" 5 | #include "logging.h" 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | // These is necessary if we want to be able to write 1_GiB instead of 1.0_GiB. 12 | // Since the return type is signed, -1_GiB will work as expected. 13 | constexpr long long int operator"" _GiB(long long unsigned int val) 14 | { 15 | return val * (1 << 30); 16 | } 17 | constexpr long long int operator"" _MiB(long long unsigned int val) 18 | { 19 | return val * (1 << 20); 20 | } 21 | constexpr long long int operator"" _KiB(long long unsigned int val) 22 | { 23 | return val * (1 << 10); 24 | } 25 | 26 | inline unsigned int GetElementSize(nvinfer1::DataType t) 27 | { 28 | switch (t) 29 | { 30 | case nvinfer1::DataType::kINT32: 31 | return 4; 32 | case nvinfer1::DataType::kFLOAT: 33 | return 4; 34 | case nvinfer1::DataType::kHALF: 35 | return 2; 36 | case nvinfer1::DataType::kBOOL: 37 | case nvinfer1::DataType::kINT8: 38 | return 1; 39 | } 40 | throw std::runtime_error("Invalid DataType."); 41 | return 0; 42 | } 43 | 44 | inline int64_t Volume(const nvinfer1::Dims &d) 45 | { 46 | return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies()); 47 | } 48 | 49 | Logger gLogger{Logger::Severity::kINFO}; 50 | LogStreamConsumer gLogVerbose{LOG_VERBOSE(gLogger)}; 51 | LogStreamConsumer gLogInfo{LOG_INFO(gLogger)}; 52 | LogStreamConsumer gLogWarning{LOG_WARN(gLogger)}; 53 | 
LogStreamConsumer gLogError{LOG_ERROR(gLogger)}; 54 | LogStreamConsumer gLogFatal{LOG_FATAL(gLogger)}; 55 | 56 | void setReportableSeverity(Logger::Severity severity) 57 | { 58 | gLogger.setReportableSeverity(severity); 59 | gLogVerbose.setReportableSeverity(severity); 60 | gLogInfo.setReportableSeverity(severity); 61 | gLogWarning.setReportableSeverity(severity); 62 | gLogError.setReportableSeverity(severity); 63 | gLogFatal.setReportableSeverity(severity); 64 | } 65 | 66 | bool ReadTrtFile(const std::string &engineFile, nvinfer1::ICudaEngine *&engine) 67 | { 68 | std::string cached_engine; 69 | std::fstream file; 70 | std::cout << "loading filename from:" << engineFile << std::endl; 71 | nvinfer1::IRuntime *trtRuntime; 72 | file.open(engineFile, std::ios::binary | std::ios::in); 73 | 74 | if (!file.is_open()) 75 | { 76 | std::cout << "read file error: " << engineFile << std::endl; 77 | cached_engine = ""; 78 | } 79 | 80 | while (file.peek() != EOF) 81 | { 82 | std::stringstream buffer; 83 | buffer << file.rdbuf(); 84 | cached_engine.append(buffer.str()); 85 | } 86 | file.close(); 87 | 88 | trtRuntime = nvinfer1::createInferRuntime(gLogger.getTRTLogger()); 89 | engine = trtRuntime->deserializeCudaEngine(cached_engine.data(), cached_engine.size(), nullptr); 90 | std::cout << "deserialize done" << std::endl; 91 | 92 | return true; 93 | } 94 | 95 | void OnnxToTRTModel(const std::string &modelFile, const std::string &filename, nvinfer1::ICudaEngine *&engine, const int &BATCH_SIZE) 96 | { 97 | // create the builder 98 | nvinfer1::IBuilder *builder = nvinfer1::createInferBuilder(gLogger.getTRTLogger()); 99 | assert(builder != nullptr); 100 | 101 | const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); 102 | auto network = builder->createNetworkV2(explicitBatch); 103 | auto config = builder->createBuilderConfig(); 104 | 105 | auto parser = nvonnxparser::createParser(*network, gLogger.getTRTLogger()); 106 | if 
(!parser->parseFromFile(modelFile.c_str(), static_cast(gLogger.getReportableSeverity()))) 107 | { 108 | gLogError << "Failure while parsing ONNX file" << std::endl; 109 | } 110 | 111 | // Build the engine 112 | builder->setMaxBatchSize(BATCH_SIZE); 113 | config->setMaxWorkspaceSize(1_GiB); 114 | config->setFlag(nvinfer1::BuilderFlag::kFP16); 115 | 116 | std::cout << "start building engine" << std::endl; 117 | engine = builder->buildEngineWithConfig(*network, *config); 118 | std::cout << "build engine done" << std::endl; 119 | assert(engine); 120 | 121 | // we can destroy the parser 122 | parser->destroy(); 123 | // save engine 124 | nvinfer1::IHostMemory *data = engine->serialize(); 125 | std::ofstream file; 126 | file.open(filename, std::ios::binary | std::ios::out); 127 | std::cout << "writing engine file..." << std::endl; 128 | file.write((const char *)data->data(), data->size()); 129 | std::cout << "save engine file done" << std::endl; 130 | file.close(); 131 | 132 | // then close everything down 133 | network->destroy(); 134 | builder->destroy(); 135 | } 136 | 137 | #endif 138 | -------------------------------------------------------------------------------- /src/common/logging.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
//!
//! \class LogStreamConsumerBuffer
//! \brief String buffer that, every time it is synchronized (flushed), forwards
//!        its contents to a target stream as "[MM/DD/YYYY-HH:MM:SS] <prefix><text>".
//!
//! When logging is disabled the buffered text is discarded on flush.
//!
class LogStreamConsumerBuffer : public std::stringbuf
{
public:
    //! \param stream    stream that receives the formatted log lines
    //! \param prefix    text inserted between the timestamp and each message
    //! \param shouldLog when false, flushed messages are dropped
    LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
        : mOutput(stream)
        , mPrefix(prefix)
        , mShouldLog(shouldLog)
    {
    }

    //! Carries over the target stream, prefix and log flag. Text still pending
    //! in \p other stays there and is emitted by its own destructor.
    LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
        : mOutput(other.mOutput)
        , mPrefix(other.mPrefix)       // previously left empty
        , mShouldLog(other.mShouldLog) // previously left uninitialized
    {
    }

    ~LogStreamConsumerBuffer() override
    {
        // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
        // std::streambuf::pptr() gives a pointer to the current position of the output sequence
        // if the pointer to the beginning is not equal to the pointer to the current position,
        // call putOutput() to log the output to the stream
        if (pbase() != pptr())
        {
            putOutput();
        }
    }

    // synchronizes the stream buffer and returns 0 on success
    // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
    // resetting the buffer and flushing the stream
    int sync() override
    {
        putOutput();
        return 0;
    }

    //! Emits the buffered text (when logging is enabled) and empties the buffer.
    void putOutput()
    {
        if (mShouldLog)
        {
            // Prepend the timestamp on the SAME stream as the message, so the
            // two cannot end up split between stdout and stderr. std::put_time
            // leaves the stream's fill/width state untouched.
            std::time_t timestamp = std::time(nullptr);
            const std::tm* tm_local = std::localtime(&timestamp);
            if (tm_local != nullptr)
            {
                mOutput << "[" << std::put_time(tm_local, "%m/%d/%Y-%H:%M:%S") << "] ";
            }
            // std::stringbuf::str() gets the string contents of the buffer
            // insert the buffer contents pre-appended by the appropriate prefix into the stream
            mOutput << mPrefix << str();
            // flush the stream
            mOutput.flush();
        }
        // Always empty the buffer: suppressed messages must neither pile up
        // nor reappear once logging is enabled later.
        str("");
    }

    void setShouldLog(bool shouldLog)
    {
        mShouldLog = shouldLog;
    }

private:
    std::ostream& mOutput;
    std::string mPrefix;
    bool mShouldLog;
};
132 | LogStreamConsumer(Severity reportableSeverity, Severity severity) 133 | : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) 134 | , std::ostream(&mBuffer) // links the stream buffer with the stream 135 | , mShouldLog(severity <= reportableSeverity) 136 | , mSeverity(severity) 137 | { 138 | } 139 | 140 | LogStreamConsumer(LogStreamConsumer&& other) 141 | : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) 142 | , std::ostream(&mBuffer) // links the stream buffer with the stream 143 | , mShouldLog(other.mShouldLog) 144 | , mSeverity(other.mSeverity) 145 | { 146 | } 147 | 148 | void setReportableSeverity(Severity reportableSeverity) 149 | { 150 | mShouldLog = mSeverity <= reportableSeverity; 151 | mBuffer.setShouldLog(mShouldLog); 152 | } 153 | 154 | private: 155 | static std::ostream& severityOstream(Severity severity) 156 | { 157 | return severity >= Severity::kINFO ? std::cout : std::cerr; 158 | } 159 | 160 | static std::string severityPrefix(Severity severity) 161 | { 162 | switch (severity) 163 | { 164 | case Severity::kINTERNAL_ERROR: return "[F] "; 165 | case Severity::kERROR: return "[E] "; 166 | case Severity::kWARNING: return "[W] "; 167 | case Severity::kINFO: return "[I] "; 168 | case Severity::kVERBOSE: return "[V] "; 169 | default: assert(0); return ""; 170 | } 171 | } 172 | 173 | bool mShouldLog; 174 | Severity mSeverity; 175 | }; 176 | 177 | //! \class Logger 178 | //! 179 | //! \brief Class which manages logging of TensorRT tools and samples 180 | //! 181 | //! \details This class provides a common interface for TensorRT tools and samples to log information to the console, 182 | //! and supports logging two types of messages: 183 | //! 184 | //! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) 185 | //! - Test pass/fail messages 186 | //! 187 | //! 
The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is 188 | //! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. 189 | //! 190 | //! In the future, this class could be extended to support dumping test results to a file in some standard format 191 | //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). 192 | //! 193 | //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger 194 | //! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT 195 | //! library and messages coming from the sample. 196 | //! 197 | //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the 198 | //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger 199 | //! object. 200 | 201 | class Logger : public nvinfer1::ILogger 202 | { 203 | public: 204 | Logger(Severity severity = Severity::kWARNING) 205 | : mReportableSeverity(severity) 206 | { 207 | } 208 | 209 | //! 210 | //! \enum TestResult 211 | //! \brief Represents the state of a given test 212 | //! 213 | enum class TestResult 214 | { 215 | kRUNNING, //!< The test is running 216 | kPASSED, //!< The test passed 217 | kFAILED, //!< The test failed 218 | kWAIVED //!< The test was waived 219 | }; 220 | 221 | //! 222 | //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger 223 | //! \return The nvinfer1::ILogger associated with this Logger 224 | //! 225 | //! TODO Once all samples are updated to use this method to register the logger with TensorRT, 226 | //! we can eliminate the inheritance of Logger from ILogger 227 | //! 
228 | nvinfer1::ILogger& getTRTLogger() 229 | { 230 | return *this; 231 | } 232 | 233 | //! 234 | //! \brief Implementation of the nvinfer1::ILogger::log() virtual method 235 | //! 236 | //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the 237 | //! inheritance from nvinfer1::ILogger 238 | //! 239 | void log(Severity severity, const char* msg) noexcept 240 | { 241 | LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; 242 | } 243 | 244 | //! 245 | //! \brief Method for controlling the verbosity of logging output 246 | //! 247 | //! \param severity The logger will only emit messages that have severity of this level or higher. 248 | //! 249 | void setReportableSeverity(Severity severity) 250 | { 251 | mReportableSeverity = severity; 252 | } 253 | 254 | //! 255 | //! \brief Opaque handle that holds logging information for a particular test 256 | //! 257 | //! This object is an opaque handle to information used by the Logger to print test results. 258 | //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used 259 | //! with Logger::reportTest{Start,End}(). 260 | //! 261 | class TestAtom 262 | { 263 | public: 264 | TestAtom(TestAtom&&) = default; 265 | 266 | private: 267 | friend class Logger; 268 | 269 | TestAtom(bool started, const std::string& name, const std::string& cmdline) 270 | : mStarted(started) 271 | , mName(name) 272 | , mCmdline(cmdline) 273 | { 274 | } 275 | 276 | bool mStarted; 277 | std::string mName; 278 | std::string mCmdline; 279 | }; 280 | 281 | //! 282 | //! \brief Define a test for logging 283 | //! 284 | //! \param[in] name The name of the test. This should be a string starting with 285 | //! "TensorRT" and containing dot-separated strings containing 286 | //! the characters [A-Za-z0-9_]. 287 | //! For example, "TensorRT.sample_googlenet" 288 | //! 
\param[in] cmdline The command line used to reproduce the test 289 | // 290 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 291 | //! 292 | static TestAtom defineTest(const std::string& name, const std::string& cmdline) 293 | { 294 | return TestAtom(false, name, cmdline); 295 | } 296 | 297 | //! 298 | //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments 299 | //! as input 300 | //! 301 | //! \param[in] name The name of the test 302 | //! \param[in] argc The number of command-line arguments 303 | //! \param[in] argv The array of command-line arguments (given as C strings) 304 | //! 305 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 306 | static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) 307 | { 308 | auto cmdline = genCmdlineString(argc, argv); 309 | return defineTest(name, cmdline); 310 | } 311 | 312 | //! 313 | //! \brief Report that a test has started. 314 | //! 315 | //! \pre reportTestStart() has not been called yet for the given testAtom 316 | //! 317 | //! \param[in] testAtom The handle to the test that has started 318 | //! 319 | static void reportTestStart(TestAtom& testAtom) 320 | { 321 | reportTestResult(testAtom, TestResult::kRUNNING); 322 | assert(!testAtom.mStarted); 323 | testAtom.mStarted = true; 324 | } 325 | 326 | //! 327 | //! \brief Report that a test has ended. 328 | //! 329 | //! \pre reportTestStart() has been called for the given testAtom 330 | //! 331 | //! \param[in] testAtom The handle to the test that has ended 332 | //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, 333 | //! TestResult::kFAILED, TestResult::kWAIVED 334 | //! 
335 | static void reportTestEnd(const TestAtom& testAtom, TestResult result) 336 | { 337 | assert(result != TestResult::kRUNNING); 338 | assert(testAtom.mStarted); 339 | reportTestResult(testAtom, result); 340 | } 341 | 342 | static int reportPass(const TestAtom& testAtom) 343 | { 344 | reportTestEnd(testAtom, TestResult::kPASSED); 345 | return EXIT_SUCCESS; 346 | } 347 | 348 | static int reportFail(const TestAtom& testAtom) 349 | { 350 | reportTestEnd(testAtom, TestResult::kFAILED); 351 | return EXIT_FAILURE; 352 | } 353 | 354 | static int reportWaive(const TestAtom& testAtom) 355 | { 356 | reportTestEnd(testAtom, TestResult::kWAIVED); 357 | return EXIT_SUCCESS; 358 | } 359 | 360 | static int reportTest(const TestAtom& testAtom, bool pass) 361 | { 362 | return pass ? reportPass(testAtom) : reportFail(testAtom); 363 | } 364 | 365 | Severity getReportableSeverity() const 366 | { 367 | return mReportableSeverity; 368 | } 369 | 370 | private: 371 | //! 372 | //! \brief returns an appropriate string for prefixing a log message with the given severity 373 | //! 374 | static const char* severityPrefix(Severity severity) 375 | { 376 | switch (severity) 377 | { 378 | case Severity::kINTERNAL_ERROR: return "[F] "; 379 | case Severity::kERROR: return "[E] "; 380 | case Severity::kWARNING: return "[W] "; 381 | case Severity::kINFO: return "[I] "; 382 | case Severity::kVERBOSE: return "[V] "; 383 | default: assert(0); return ""; 384 | } 385 | } 386 | 387 | //! 388 | //! \brief returns an appropriate string for prefixing a test result message with the given result 389 | //! 390 | static const char* testResultString(TestResult result) 391 | { 392 | switch (result) 393 | { 394 | case TestResult::kRUNNING: return "RUNNING"; 395 | case TestResult::kPASSED: return "PASSED"; 396 | case TestResult::kFAILED: return "FAILED"; 397 | case TestResult::kWAIVED: return "WAIVED"; 398 | default: assert(0); return ""; 399 | } 400 | } 401 | 402 | //! 403 | //! 
\brief returns an appropriate output stream (cout or cerr) to use with the given severity 404 | //! 405 | static std::ostream& severityOstream(Severity severity) 406 | { 407 | return severity >= Severity::kINFO ? std::cout : std::cerr; 408 | } 409 | 410 | //! 411 | //! \brief method that implements logging test results 412 | //! 413 | static void reportTestResult(const TestAtom& testAtom, TestResult result) 414 | { 415 | severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " 416 | << testAtom.mCmdline << std::endl; 417 | } 418 | 419 | //! 420 | //! \brief generate a command line string from the given (argc, argv) values 421 | //! 422 | static std::string genCmdlineString(int argc, char const* const* argv) 423 | { 424 | std::stringstream ss; 425 | for (int i = 0; i < argc; i++) 426 | { 427 | if (i > 0) 428 | ss << " "; 429 | ss << argv[i]; 430 | } 431 | return ss.str(); 432 | } 433 | 434 | Severity mReportableSeverity; 435 | }; 436 | 437 | namespace 438 | { 439 | 440 | //! 441 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE 442 | //! 443 | //! Example usage: 444 | //! 445 | //! LOG_VERBOSE(logger) << "hello world" << std::endl; 446 | //! 447 | inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) 448 | { 449 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); 450 | } 451 | 452 | //! 453 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO 454 | //! 455 | //! Example usage: 456 | //! 457 | //! LOG_INFO(logger) << "hello world" << std::endl; 458 | //! 459 | inline LogStreamConsumer LOG_INFO(const Logger& logger) 460 | { 461 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); 462 | } 463 | 464 | //! 465 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING 466 | //! 467 | //! Example usage: 468 | //! 469 | //! 
LOG_WARN(logger) << "hello world" << std::endl; 470 | //! 471 | inline LogStreamConsumer LOG_WARN(const Logger& logger) 472 | { 473 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); 474 | } 475 | 476 | //! 477 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR 478 | //! 479 | //! Example usage: 480 | //! 481 | //! LOG_ERROR(logger) << "hello world" << std::endl; 482 | //! 483 | inline LogStreamConsumer LOG_ERROR(const Logger& logger) 484 | { 485 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); 486 | } 487 | 488 | //! 489 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR 490 | // ("fatal" severity) 491 | //! 492 | //! Example usage: 493 | //! 494 | //! LOG_FATAL(logger) << "hello world" << std::endl; 495 | //! 496 | inline LogStreamConsumer LOG_FATAL(const Logger& logger) 497 | { 498 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); 499 | } 500 | 501 | } // anonymous namespace 502 | 503 | #endif // TENSORRT_LOGGING_H 504 | -------------------------------------------------------------------------------- /src/postprocess.cpp: -------------------------------------------------------------------------------- 1 | #include "postprocess.hpp" 2 | #include 3 | #include 4 | 5 | #define ZQ_MAX(a, b) ((a) > (b) ? (a) : (b)) 6 | #define ZQ_MIN(a, b) ((a) < (b) ? 
(a) : (b)) 7 | 8 | static inline float fast_exp(float x) 9 | { 10 | // return exp(x); 11 | union 12 | { 13 | uint32_t i; 14 | float f; 15 | } v; 16 | v.i = (12102203.1616540672 * x + 1064807160.56887296); 17 | return v.f; 18 | } 19 | 20 | static inline float IOU(float XMin1, float YMin1, float XMax1, float YMax1, float XMin2, float YMin2, float XMax2, float YMax2) 21 | { 22 | float Inter = 0; 23 | float Total = 0; 24 | float XMin = 0; 25 | float YMin = 0; 26 | float XMax = 0; 27 | float YMax = 0; 28 | float Area1 = 0; 29 | float Area2 = 0; 30 | float InterWidth = 0; 31 | float InterHeight = 0; 32 | 33 | XMin = ZQ_MAX(XMin1, XMin2); 34 | YMin = ZQ_MAX(YMin1, YMin2); 35 | XMax = ZQ_MIN(XMax1, XMax2); 36 | YMax = ZQ_MIN(YMax1, YMax2); 37 | 38 | InterWidth = XMax - XMin; 39 | InterHeight = YMax - YMin; 40 | 41 | InterWidth = (InterWidth >= 0) ? InterWidth : 0; 42 | InterHeight = (InterHeight >= 0) ? InterHeight : 0; 43 | 44 | Inter = InterWidth * InterHeight; 45 | 46 | Area1 = (XMax1 - XMin1) * (YMax1 - YMin1); 47 | Area2 = (XMax2 - XMin2) * (YMax2 - YMin2); 48 | 49 | Total = Area1 + Area2 - Inter; 50 | 51 | return float(Inter) / float(Total); 52 | } 53 | 54 | /****** yolov8 ****/ 55 | GetResultRectYolov8::GetResultRectYolov8() 56 | { 57 | } 58 | 59 | GetResultRectYolov8::~GetResultRectYolov8() 60 | { 61 | } 62 | 63 | float GetResultRectYolov8::sigmoid(float x) 64 | { 65 | return 1 / (1 + fast_exp(-x)); 66 | } 67 | 68 | int GetResultRectYolov8::GenerateMeshgrid() 69 | { 70 | int ret = 0; 71 | if (headNum == 0) 72 | { 73 | printf("=== yolov8 Meshgrid Generate failed! \n"); 74 | } 75 | 76 | for (int index = 0; index < headNum; index++) 77 | { 78 | for (int i = 0; i < mapSize[index][0]; i++) 79 | { 80 | for (int j = 0; j < mapSize[index][1]; j++) 81 | { 82 | meshgrid.push_back(float(j + 0.5)); 83 | meshgrid.push_back(float(i + 0.5)); 84 | } 85 | } 86 | } 87 | 88 | // printf("=== yolov8 Meshgrid Generate success! 
\n"); 89 | 90 | return ret; 91 | } 92 | 93 | int GetResultRectYolov8::GetConvDetectionResult(std::vector &BlobPtr, std::vector &DetectiontRects) 94 | { 95 | int ret = 0; 96 | if (meshgrid.empty()) 97 | { 98 | ret = GenerateMeshgrid(); 99 | } 100 | 101 | int gridIndex = -2; 102 | float xmin = 0, ymin = 0, xmax = 0, ymax = 0; 103 | float cls_val = 0; 104 | float cls_max = 0; 105 | int cls_index = 0; 106 | 107 | DetectRect temp; 108 | std::vector detectRects; 109 | 110 | for (int index = 0; index < headNum; index++) 111 | { 112 | float *cls = BlobPtr[index * 2 + 0]; 113 | float *reg = BlobPtr[index * 2 + 1]; 114 | 115 | for (int h = 0; h < mapSize[index][0]; h++) 116 | { 117 | for (int w = 0; w < mapSize[index][1]; w++) 118 | { 119 | gridIndex += 2; 120 | 121 | for (int cl = 0; cl < class_num; cl++) 122 | { 123 | cls_val = sigmoid(cls[cl * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w]); 124 | 125 | if (0 == cl) 126 | { 127 | cls_max = cls_val; 128 | cls_index = cl; 129 | } 130 | else 131 | { 132 | if (cls_val > cls_max) 133 | { 134 | cls_max = cls_val; 135 | cls_index = cl; 136 | } 137 | } 138 | } 139 | 140 | if (cls_max > objectThresh) 141 | { 142 | xmin = (meshgrid[gridIndex + 0] - reg[0 * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w]) * strides[index]; 143 | ymin = (meshgrid[gridIndex + 1] - reg[1 * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w]) * strides[index]; 144 | xmax = (meshgrid[gridIndex + 0] + reg[2 * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w]) * strides[index]; 145 | ymax = (meshgrid[gridIndex + 1] + reg[3 * mapSize[index][0] * mapSize[index][1] + h * mapSize[index][1] + w]) * strides[index]; 146 | 147 | xmin = xmin > 0 ? xmin : 0; 148 | ymin = ymin > 0 ? ymin : 0; 149 | xmax = xmax < input_w ? xmax : input_w; 150 | ymax = ymax < input_h ? 
ymax : input_h; 151 | 152 | if (xmin >= 0 && ymin >= 0 && xmax <= input_w && ymax <= input_h) 153 | { 154 | temp.xmin = xmin / input_w; 155 | temp.ymin = ymin / input_h; 156 | temp.xmax = xmax / input_w; 157 | temp.ymax = ymax / input_h; 158 | temp.classId = cls_index; 159 | temp.score = cls_max; 160 | detectRects.push_back(temp); 161 | } 162 | } 163 | } 164 | } 165 | } 166 | 167 | std::sort(detectRects.begin(), detectRects.end(), [](DetectRect &Rect1, DetectRect &Rect2) -> bool 168 | { return (Rect1.score > Rect2.score); }); 169 | 170 | std::cout << "NMS Before num :" << detectRects.size() << std::endl; 171 | for (int i = 0; i < detectRects.size(); ++i) 172 | { 173 | float xmin1 = detectRects[i].xmin; 174 | float ymin1 = detectRects[i].ymin; 175 | float xmax1 = detectRects[i].xmax; 176 | float ymax1 = detectRects[i].ymax; 177 | int classId = detectRects[i].classId; 178 | float score = detectRects[i].score; 179 | 180 | if (classId != -1) 181 | { 182 | // 将检测结果按照classId、score、xmin1、ymin1、xmax1、ymax1 的格式存放在vector中 183 | DetectiontRects.push_back(float(classId)); 184 | DetectiontRects.push_back(float(score)); 185 | DetectiontRects.push_back(float(xmin1)); 186 | DetectiontRects.push_back(float(ymin1)); 187 | DetectiontRects.push_back(float(xmax1)); 188 | DetectiontRects.push_back(float(ymax1)); 189 | 190 | for (int j = i + 1; j < detectRects.size(); ++j) 191 | { 192 | float xmin2 = detectRects[j].xmin; 193 | float ymin2 = detectRects[j].ymin; 194 | float xmax2 = detectRects[j].xmax; 195 | float ymax2 = detectRects[j].ymax; 196 | float iou = IOU(xmin1, ymin1, xmax1, ymax1, xmin2, ymin2, xmax2, ymax2); 197 | if (iou > nmsThresh) 198 | { 199 | detectRects[j].classId = -1; 200 | } 201 | } 202 | } 203 | } 204 | 205 | return ret; 206 | } 207 | -------------------------------------------------------------------------------- /src/postprocess.hpp: -------------------------------------------------------------------------------- 1 | #ifndef _POSTPROCESS_H_ 2 | #define 
_POSTPROCESS_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | typedef struct 10 | { 11 | float xmin; 12 | float ymin; 13 | float xmax; 14 | float ymax; 15 | float score; 16 | int classId; 17 | } DetectRect; 18 | 19 | // yolov8 20 | class GetResultRectYolov8 21 | { 22 | public: 23 | GetResultRectYolov8(); 24 | 25 | ~GetResultRectYolov8(); 26 | 27 | int GenerateMeshgrid(); 28 | 29 | int GetConvDetectionResult(std::vector &BlobPtr, std::vector &DetectiontRects); 30 | 31 | float sigmoid(float x); 32 | 33 | private: 34 | std::vector meshgrid; 35 | 36 | const int class_num = 80; 37 | int headNum = 3; 38 | 39 | int input_w = 640; 40 | int input_h = 640; 41 | int strides[3] = {8, 16, 32}; 42 | int mapSize[3][2] = {{80, 80}, {40, 40}, {20, 20}}; 43 | 44 | float nmsThresh = 0.45; 45 | float objectThresh = 0.5; 46 | }; 47 | 48 | #endif 49 | --------------------------------------------------------------------------------