├── .gitignore
├── CMakeLists.txt
├── inference.cpp
├── inference.h
├── main.cpp
└── scripts
    ├── build.sh
    ├── run.sh
    └── start.sh

/.gitignore:
--------------------------------------------------------------------------------
build/
models/*
onnxruntime-linux-cpu-1.17.0/
onnxruntime-linux-cpu-1.15.0/
onnxruntime-linux-gpu-1.17.0/
onnxruntime-linux-gpu-1.15.0/
onnxruntime-win-cpu-1.17.0/
onnxruntime-win-gpu-1.17.0/
.vscode/
images/origin_data
output/
images/
configs/
val_ng_cls_output/
val_ok_cls_output/
cls_output/
det_seg_output/
DZtest/
huachuan/
main_.cpp

--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.5)

set(PROJECT_NAME Yolov8OnnxRuntimeCPPInference)
project(${PROJECT_NAME} VERSION 0.0.1 LANGUAGES CXX)


# -------------- Support C++17 for using filesystem ------------------#
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON)
set(CMAKE_INCLUDE_CURRENT_DIR ON)
message(STATUS "c_CMAKE_CURRENT_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}")


# -------------- OpenCV ------------------#
find_package(OpenCV REQUIRED)

# -------------- Compile CUDA for FP16 inference if needed ------------------#
option(USE_CUDA "Enable CUDA support" ON)
if (NOT APPLE AND USE_CUDA)
    find_package(CUDA REQUIRED)
else ()
    set(USE_CUDA OFF)
endif ()

# -------------- ONNXRUNTIME ------------------#
# By default the prebuilt ONNX Runtime package is expected next to this file as
# onnxruntime-<win|linux>-<cpu|gpu>-<version>/. Pass -DONNXRUNTIME_ROOT=<dir> to
# use a package unpacked somewhere else.
set(ONNXRUNTIME_VERSION 1.17.0)

if (USE_CUDA)
    set(ONNXRUNTIME_FLAVOR gpu)
else ()
    set(ONNXRUNTIME_FLAVOR cpu)
endif ()

if (WIN32)
    set(ONNXRUNTIME_LIB_NAME onnxruntime.lib)
    set(ONNXRUNTIME_DEFAULT_ROOT
        "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-win-${ONNXRUNTIME_FLAVOR}-${ONNXRUNTIME_VERSION}")
elseif (UNIX AND NOT APPLE)
    # Not `if (LINUX)`: that variable is only predefined by CMake >= 3.25, while this
    # project accepts CMake 3.5, where the Linux branch used to be skipped silently.
    set(ONNXRUNTIME_LIB_NAME libonnxruntime.so)
    set(ONNXRUNTIME_DEFAULT_ROOT
        "${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime-linux-${ONNXRUNTIME_FLAVOR}-${ONNXRUNTIME_VERSION}")
else ()
    # No default package location on other platforms; ONNXRUNTIME_ROOT must be given.
    set(ONNXRUNTIME_LIB_NAME libonnxruntime.dylib)
    set(ONNXRUNTIME_DEFAULT_ROOT "")
endif ()

if (NOT DEFINED ONNXRUNTIME_ROOT)
    set(ONNXRUNTIME_ROOT "${ONNXRUNTIME_DEFAULT_ROOT}")
endif ()

if (NOT EXISTS "${ONNXRUNTIME_ROOT}/include/onnxruntime_cxx_api.h")
    message(FATAL_ERROR
        "ONNX Runtime headers not found under '${ONNXRUNTIME_ROOT}'. Unpack the ONNX Runtime "
        "${ONNXRUNTIME_VERSION} package there or pass -DONNXRUNTIME_ROOT=<dir>.")
endif ()

add_executable(${PROJECT_NAME}
    inference.h
    inference.cpp
    main.cpp
)

target_include_directories(${PROJECT_NAME} PRIVATE
    ${OpenCV_INCLUDE_DIRS}
    "${ONNXRUNTIME_ROOT}/include"
)

target_link_libraries(${PROJECT_NAME} PRIVATE
    ${OpenCV_LIBS}
    "${ONNXRUNTIME_ROOT}/lib/${ONNXRUNTIME_LIB_NAME}"
)

if (USE_CUDA)
    target_include_directories(${PROJECT_NAME} PRIVATE ${CUDA_INCLUDE_DIRS})
    target_compile_definitions(${PROJECT_NAME} PRIVATE USE_CUDA)
    target_link_libraries(${PROJECT_NAME} PRIVATE ${CUDA_LIBRARIES})
endif ()

--------------------------------------------------------------------------------
/inference.cpp:
--------------------------------------------------------------------------------
#define _CRT_SECURE_NO_WARNINGS
#include "inference.h"
#include <algorithm>
#include <cctype>
#include <chrono>
#include <fstream>
#include <iostream>
#include <random>
#define benchmark

// Members that are only assigned by CreateSession() start in a safe state, so that
// destroying (or timing) an object whose session was never created is well-defined.
YOLO_V8::YOLO_V8() : cudaEnable(false), session(nullptr) {
}

YOLO_V8::~YOLO_V8() {
    delete session;
}

#ifdef USE_CUDA
// Lets Ort::Value::CreateTensor<half> map CUDA's `half` to the FP16 tensor element type.
namespace Ort
{
    template<>
    struct TypeToTensorType<half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
}
#endif


// Reads the class list from txtPath, loads imagePath and runs the model on it.
// Classification returns only the highest-scoring class (with its name filled in);
// detection/segmentation returns every result produced by RunSession.
// Failures to read the class list, load the image or run the session are logged to
// std::cerr and yield an empty vector.
std::vector<DL_RESULT> YOLO_V8::Inference(const std::string& imagePath, const std::string& txtPath) {

    std::vector<std::string> classNames;
    std::vector<DL_RESULT> results;

    if (ReadClassNames(txtPath, classNames) != 0) {
        std::cerr <<
"[YOLO_V8]: Failed to read class names" << std::endl; 32 | return results; 33 | } 34 | classes = std::move(classNames); 35 | 36 | cv::Mat image = cv::imread(imagePath); 37 | if (image.empty()) { 38 | std::cerr << "[YOLO_V8]: Failed to load image" << std::endl; 39 | return results; 40 | } 41 | 42 | auto starttime_4 = std::chrono::high_resolution_clock::now(); 43 | 44 | std::vector res; 45 | if (RunSession(image, res) != 0) { 46 | std::cerr << "[YOLO_V8]: Failed to run session" << std::endl; 47 | return results; 48 | } 49 | auto starttime_3 = std::chrono::high_resolution_clock::now(); 50 | auto duration_ms3 = std::chrono::duration_cast(starttime_3 - starttime_4).count(); 51 | std::cout << "[YOLO_V8]: RunSession时间: " << duration_ms3 << " ms" << std::endl; 52 | 53 | if (modelType == YOLO_CLS_V8 ) { 54 | float maxConfidence = 0; 55 | int maxIndex = -1; 56 | 57 | for (int i = 0; i < res.size(); i++) 58 | { 59 | auto probs = res.at(i); 60 | if (probs.confidence > maxConfidence) 61 | { 62 | maxConfidence = probs.confidence; 63 | maxIndex = i; 64 | } 65 | } 66 | 67 | if (maxIndex != -1) { 68 | auto max_probs = res.at(maxIndex); 69 | int predict_label = max_probs.classId; 70 | auto predict_name = classes[predict_label]; 71 | float confidence = max_probs.confidence; 72 | max_probs.className = predict_name; 73 | results.push_back(max_probs); 74 | } 75 | } 76 | else { 77 | for (const auto& result : res) { 78 | results.push_back(result); 79 | } 80 | } 81 | 82 | return results; 83 | } 84 | 85 | 86 | 87 | template 88 | char* BlobFromImage(cv::Mat& iImg, T& iBlob) { 89 | int channels = iImg.channels(); 90 | int imgHeight = iImg.rows; 91 | int imgWidth = iImg.cols; 92 | for (int c = 0; c < channels; c++) { 93 | for (int h = 0; h < imgHeight; h++) { 94 | for (int w = 0; w < imgWidth; w++) 95 | { 96 | iBlob[c * imgWidth * imgHeight + h * imgWidth + w] = typename std::remove_pointer::type( 97 | (iImg.at(h, w)[c]) / 255.0f); 98 | } 99 | } 100 | } 101 | return RET_OK; 102 | } 103 | 104 | 
105 | char* YOLO_V8::PreProcess(cv::Mat& iImg, std::vector iImgSize, cv::Mat& oImg) { 106 | if (iImg.channels() == 3) { 107 | oImg = iImg.clone(); 108 | cv::cvtColor(oImg, oImg, cv::COLOR_BGR2RGB); 109 | } 110 | else { 111 | cv::cvtColor(iImg, oImg, cv::COLOR_GRAY2RGB); 112 | } 113 | 114 | 115 | int h = iImg.rows; 116 | int w = iImg.cols; 117 | int m = std::min(h, w); 118 | int top = (h - m) / 2; 119 | int left = (w - m) / 2; 120 | 121 | cv::resize(oImg(cv::Rect(left, top, m, m)), oImg, cv::Size(iImgSize.at(0), iImgSize.at(1))); 122 | return RET_OK; 123 | } 124 | 125 | 126 | void LetterBox(const cv::Mat& image, cv::Mat& outImage, cv::Vec4d& params, 127 | const cv::Size& newShape=cv::Size(640, 640), 128 | bool autoShape=true, 129 | bool scaleFill=false, 130 | bool scaleUp=true, 131 | int stride=32, 132 | const cv::Scalar& color=cv::Scalar(114, 114, 114)) 133 | { 134 | // if (false) { 135 | // int maxLen = MAX(image.rows, image.cols); 136 | // outImage = cv::Mat::zeros(cv::Size(maxLen, maxLen), CV_8UC3); 137 | // image.copyTo(outImage(cv::Rect(0, 0, image.cols, image.rows))); 138 | // params[0] = 1; 139 | // params[1] = 1; 140 | // params[3] = 0; 141 | // params[2] = 0; 142 | // } 143 | 144 | cv::Size shape = image.size(); 145 | float r = std::min((float)newShape.height / (float)shape.height, 146 | (float)newShape.width / (float)shape.width); 147 | if (!scaleUp) 148 | r = std::min(r, 1.0f); 149 | 150 | float ratio[2]{ r, r }; 151 | int new_un_pad[2] = { (int)std::round((float)shape.width * r), 152 | (int)std::round((float)shape.height * r) }; 153 | 154 | auto dw = (float)(newShape.width - new_un_pad[0]); 155 | auto dh = (float)(newShape.height - new_un_pad[1]); 156 | 157 | if (autoShape) { 158 | dw = (float)((int)dw % stride); 159 | dh = (float)((int)dh % stride); 160 | } 161 | else if (scaleFill) { 162 | dw = 0.0f; 163 | dh = 0.0f; 164 | new_un_pad[0] = newShape.width; 165 | new_un_pad[1] = newShape.height; 166 | ratio[0] = (float)newShape.width / 
(float)shape.width; 167 | ratio[1] = (float)newShape.height / (float)shape.height; 168 | } 169 | 170 | dw /= 2.0f; 171 | dh /= 2.0f; 172 | 173 | if (shape.width != new_un_pad[0] && shape.height != new_un_pad[1]) { 174 | cv::resize(image, outImage, cv::Size(new_un_pad[0], new_un_pad[1])); 175 | } 176 | else { 177 | outImage = image.clone(); 178 | } 179 | int top = int(std::round(dh - 0.1f)); 180 | int bottom = int(std::round(dh + 0.1f)); 181 | int left = int(std::round(dw - 0.1f)); 182 | int right = int(std::round(dw + 0.1f)); 183 | params[0] = ratio[0]; 184 | params[1] = ratio[1]; 185 | params[2] = left; 186 | params[3] = top; 187 | cv::copyMakeBorder(outImage, outImage, top, bottom, left, right, cv::BORDER_CONSTANT, color); 188 | } 189 | 190 | 191 | void GetMask( 192 | const int* const _seg_params, 193 | const float& rectConfidenceThreshold, 194 | const cv::Mat& maskProposals, 195 | const cv::Mat& mask_protos, 196 | const cv::Vec4d& params, 197 | const cv::Size& srcImgShape, 198 | std::vector& output) 199 | { 200 | int _segChannels = *_seg_params; 201 | int _segHeight = *(_seg_params + 1); 202 | int _segWidth = *(_seg_params + 2); 203 | int _netHeight = *(_seg_params + 3); 204 | int _netWidth = *(_seg_params + 4); 205 | 206 | cv::Mat protos = mask_protos.reshape(0, { _segChannels,_segWidth * _segHeight }); 207 | cv::Mat matmulRes = (maskProposals * protos).t(); 208 | cv::Mat masks = matmulRes.reshape(output.size(), { _segHeight,_segWidth }); 209 | std::vector maskChannels; 210 | split(masks, maskChannels); 211 | for (int i = 0; i < output.size(); ++i) { 212 | cv::Mat dest, mask; 213 | cv::exp(-maskChannels[i], dest); 214 | dest = 1.0 / (1.0 + dest); 215 | cv::Rect roi( 216 | int(params[2] / _netWidth * _segWidth), 217 | int(params[3] / _netHeight * _segHeight), 218 | int(_segWidth - params[2] / 2), 219 | int(_segHeight - params[3] / 2)); 220 | dest = dest(roi); 221 | cv::resize(dest, mask, srcImgShape, cv::INTER_NEAREST); 222 | cv::Rect temp_rect = output[i].box; 
223 | mask = mask(temp_rect) > 0.5f; // 固定mask阈值,实测0.5效果最好 224 | std::vector> contours; 225 | cv::findContours(mask.clone(), contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE); 226 | double maxArea = -1; 227 | int maxAreaIdx = -1; 228 | for (int j = 0; j < contours.size(); ++j) { 229 | double area = cv::contourArea(contours[j]); 230 | if (area > maxArea) { 231 | maxArea = area; 232 | maxAreaIdx = j; 233 | } 234 | } 235 | if (maxAreaIdx != -1) { 236 | std::vector> filteredContours; 237 | filteredContours.push_back(contours[maxAreaIdx]); 238 | output[i].contours = filteredContours; 239 | } else { 240 | output[i].contours.clear(); 241 | } 242 | } 243 | } 244 | 245 | 246 | 247 | char* YOLO_V8::CreateSession(DL_INIT_PARAM& iParams) { 248 | rectConfidenceThreshold = iParams.rectConfidenceThreshold; 249 | iouThreshold = iParams.iouThreshold; 250 | imgSize = iParams.imgSize; 251 | modelType = iParams.modelType; 252 | env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "YOLOv8ONNXRuntimeInference"); 253 | Ort::SessionOptions sessionOption; 254 | 255 | if (iParams.cudaEnable) { 256 | cudaEnable = iParams.cudaEnable; 257 | auto providers = Ort::GetAvailableProviders(); 258 | auto cudaAvailable = std::find(providers.begin(), providers.end(), "CUDAExecutionProvider"); 259 | OrtCUDAProviderOptions cudaOption; 260 | 261 | if (cudaAvailable != providers.end()) 262 | { 263 | std::cout << "Inference device: GPU" << std::endl; 264 | cudaOption.device_id = 0; 265 | cudaOption.arena_extend_strategy = 0; 266 | cudaOption.do_copy_in_default_stream = 1; 267 | // cudaOption.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchDefault; 268 | // cudaOption.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchExhaustive; 269 | cudaOption.cudnn_conv_algo_search = OrtCudnnConvAlgoSearch::OrtCudnnConvAlgoSearchHeuristic; 270 | sessionOption.AppendExecutionProvider_CUDA(cudaOption); 271 | } 272 | else if (cudaAvailable == providers.end()) 273 | { 274 | std::cout << 
"GPU is not supported. Fallback to CPU." << std::endl; 275 | std::cout << "Inference device: CPU" << std::endl; 276 | } 277 | } 278 | 279 | sessionOption.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL); 280 | sessionOption.SetIntraOpNumThreads(0); 281 | // sessionOption.SetExecutionMode(ExecutionMode::ORT_PARALLEL); 282 | // sessionOption.SetInterOpNumThreads(0); 283 | sessionOption.SetLogSeverityLevel(iParams.logSeverityLevel); 284 | 285 | #ifdef _WIN32 286 | int ModelPathSize = MultiByteToWideChar(CP_UTF8, 0, iParams.modelPath.c_str(), static_cast(iParams.modelPath.length()), nullptr, 0); 287 | wchar_t* wide_cstr = new wchar_t[ModelPathSize + 1]; 288 | MultiByteToWideChar(CP_UTF8, 0, iParams.modelPath.c_str(), static_cast(iParams.modelPath.length()), wide_cstr, ModelPathSize); 289 | wide_cstr[ModelPathSize] = L'\0'; 290 | const wchar_t* modelPath = wide_cstr; 291 | #else 292 | const char* modelPath = iParams.modelPath.c_str(); 293 | #endif 294 | 295 | session = new Ort::Session(env, modelPath, sessionOption); 296 | Ort::AllocatorWithDefaultOptions allocator; 297 | size_t InputNodesNum = session->GetInputCount(); 298 | 299 | for (size_t i = 0; i < InputNodesNum; i++) { 300 | Ort::AllocatedStringPtr input_node_name = session->GetInputNameAllocated(i, allocator); 301 | this->inputNodeNames.push_back(input_node_name.get()); 302 | input_names_ptr.push_back(std::move(input_node_name)); 303 | 304 | Ort::TypeInfo inputTypeInfo = session->GetInputTypeInfo(i); 305 | std::vector inputTensorShape = inputTypeInfo.GetTensorTypeAndShapeInfo().GetShape(); 306 | this->inputShapes.push_back(inputTensorShape); 307 | // this->isDynamicInputShape = false; 308 | // // checking if width and height are dynamic 309 | // if (inputTensorShape[2] == -1 && inputTensorShape[3] == -1) 310 | // { 311 | // std::cout << "Dynamic input shape" << std::endl; 312 | // this->isDynamicInputShape = true; 313 | // } 314 | } 315 | size_t OutputNodesNum = session->GetOutputCount(); 316 | 
if (OutputNodesNum > 1) 317 | { 318 | this->runSegmentation = true; 319 | std::cout << "Instance Segmentation" << std::endl; 320 | } 321 | else 322 | std::cout << "Object Detection" << std::endl; 323 | 324 | for (size_t i = 0; i < OutputNodesNum; i++) { 325 | Ort::AllocatedStringPtr output_node_name = session->GetOutputNameAllocated(i, allocator); 326 | 327 | this->outputNodeNames.push_back(output_node_name.get()); 328 | output_names_ptr.push_back(std::move(output_node_name)); 329 | 330 | Ort::TypeInfo outputTypeInfo = session->GetOutputTypeInfo(i); 331 | std::vector outputTensorShape = outputTypeInfo.GetTensorTypeAndShapeInfo().GetShape(); 332 | this->outputShapes.push_back(outputTensorShape); 333 | } 334 | 335 | // for (const char *x : this->inputNodeNames) 336 | // { 337 | // std::cout << x << std::endl; 338 | // } 339 | // for (const char *x : this->outputNodeNames) 340 | // { 341 | // std::cout << x << std::endl; 342 | // } 343 | 344 | options = Ort::RunOptions{ nullptr }; 345 | WarmUpSession(); 346 | return RET_OK; 347 | } 348 | 349 | 350 | char* YOLO_V8::RunSession(cv::Mat& iImg, std::vector& oResult) { 351 | 352 | #ifdef benchmark 353 | auto starttime_1 = std::chrono::high_resolution_clock::now(); 354 | #endif 355 | char* Ret = RET_OK; 356 | cv::Mat processedImg; 357 | cv::Vec4d params; 358 | //resize图片尺寸,PreProcess是resize+centercrop,LetterBox有padding 359 | switch (modelType) { 360 | case YOLO_DET_SEG_V8: { 361 | LetterBox(iImg, processedImg, params, cv::Size(imgSize.at(1), imgSize.at(0))); 362 | break; 363 | } 364 | case YOLO_CLS_V8: { 365 | PreProcess(iImg, imgSize, processedImg); 366 | break; 367 | } 368 | } 369 | if (modelType < 4) { 370 | float* blob = new float[processedImg.total() * 3]; 371 | BlobFromImage(processedImg, blob); 372 | std::vector inputNodeDims = { 1, 3, imgSize.at(0), imgSize.at(1) }; 373 | TensorProcess(starttime_1, params, iImg, blob, inputNodeDims, oResult); 374 | } 375 | else { 376 | #ifdef USE_CUDA 377 | half* blob = new 
half[processedImg.total() * 3]; 378 | BlobFromImage(processedImg, blob); 379 | std::vector inputNodeDims = { 1, 3, imgSize.at(0), imgSize.at(1) }; 380 | TensorProcess(starttime_1, params, iImg, blob, inputNodeDims, oResult); 381 | #endif 382 | } 383 | return Ret; 384 | } 385 | 386 | 387 | 388 | template 389 | char* YOLO_V8::TensorProcess(std::chrono::_V2::system_clock::time_point& starttime_1, 390 | cv::Vec4d& params, cv::Mat& iImg, N& blob, std::vector& inputNodeDims, std::vector& oResult) { 391 | Ort::Value inputTensor = Ort::Value::CreateTensor::type> ( 392 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), 393 | blob, 394 | 3 * imgSize.at(0) * imgSize.at(1), 395 | inputNodeDims.data(), 396 | inputNodeDims.size() 397 | ); 398 | #ifdef benchmark 399 | auto starttime_2 = std::chrono::high_resolution_clock::now(); 400 | #endif 401 | auto outputTensor = session->Run(options, inputNodeNames.data(), &inputTensor, 1, outputNodeNames.data(),outputNodeNames.size()); 402 | #ifdef benchmark 403 | auto starttime_3 = std::chrono::high_resolution_clock::now(); 404 | #endif 405 | std::vector _outputTensorShape; 406 | _outputTensorShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape(); 407 | auto output = outputTensor[0].GetTensorMutableData::type>(); 408 | delete[] blob; 409 | 410 | switch (modelType) { 411 | case YOLO_DET_SEG_V8: { 412 | // yolov8 has an output of shape (batchSize, 84, 8400) (Num classes + box[x,y,w,h]) 413 | std::cout << "---------------------YOLO_DET_SEG_V8---------------------" << std::endl; 414 | int dimensions = _outputTensorShape[1]; 415 | int rows = _outputTensorShape[2]; 416 | cv::Mat rowData(dimensions, rows, CV_32F, output); 417 | if (rows > dimensions) { 418 | dimensions = _outputTensorShape[2]; 419 | rows = _outputTensorShape[1]; 420 | rowData = rowData.t(); 421 | } 422 | std::vector class_ids; 423 | std::vector confidences; 424 | std::vector boxes; 425 | std::vector> picked_proposals; 426 | 427 | float* data = 
(float*)rowData.data; 428 | 429 | for (int i = 0; i < dimensions; ++i) { 430 | float* classesScores = data + 4; 431 | cv::Mat scores(1, this->classes.size(), CV_32FC1, classesScores); 432 | cv::Point class_id; 433 | double maxClassScore; 434 | cv::minMaxLoc(scores, 0, &maxClassScore, 0, &class_id); 435 | if (maxClassScore > rectConfidenceThreshold) { 436 | if (runSegmentation) { 437 | int _segChannels = outputTensor[1].GetTensorTypeAndShapeInfo().GetShape()[1]; 438 | std::vector temp_proto(data + classes.size() + 4, data + classes.size() + 4 + _segChannels); 439 | picked_proposals.push_back(temp_proto); 440 | } 441 | confidences.push_back(maxClassScore); 442 | class_ids.push_back(class_id.x); 443 | float x = (data[0] - params[2]) / params[0]; 444 | float y = (data[1] - params[3]) / params[1]; 445 | float w = data[2] / params[0]; 446 | float h = data[3] / params[1]; 447 | int left = MAX(round(x - 0.5 * w + 0.5), 0); 448 | int top = MAX(round(y - 0.5 * h + 0.5), 0); 449 | if ((left + w) > iImg.cols) { w = iImg.cols - left; } 450 | if ((top + h) > iImg.rows) { h = iImg.rows - top; } 451 | boxes.emplace_back(cv::Rect(left, top, int(w), int(h))); 452 | } 453 | data += rows; 454 | } 455 | std::vector nmsResult; 456 | cv::dnn::NMSBoxes(boxes, confidences, rectConfidenceThreshold, iouThreshold, nmsResult); 457 | std::vector> temp_mask_proposals; 458 | for (int i = 0; i < nmsResult.size(); ++i) { 459 | int idx = nmsResult[i]; 460 | DL_RESULT result; 461 | result.classId = class_ids[idx]; 462 | result.confidence = confidences[idx]; 463 | result.box = boxes[idx]; 464 | result.className = classes[result.classId]; 465 | std::random_device rd; 466 | std::mt19937 gen(rd()); 467 | std::uniform_int_distribution dis(100, 255); 468 | result.color = cv::Scalar(dis(gen),dis(gen),dis(gen)); 469 | if (result.box.width != 0 && result.box.height != 0) oResult.push_back(result); 470 | if (runSegmentation) temp_mask_proposals.push_back(picked_proposals[idx]); 471 | } 472 | if 
(!boxes.empty()) { 473 | if (runSegmentation) { 474 | cv::Mat mask_proposals; 475 | for (int i = 0; i < temp_mask_proposals.size(); ++i) 476 | mask_proposals.push_back(cv::Mat(temp_mask_proposals[i]).t()); 477 | std::vector _outputMaskTensorShape; 478 | _outputMaskTensorShape = outputTensor[1].GetTensorTypeAndShapeInfo().GetShape(); 479 | int _segChannels = _outputMaskTensorShape[1]; 480 | int _segWidth = _outputMaskTensorShape[2]; 481 | int _segHeight = _outputMaskTensorShape[3]; 482 | float* pdata = outputTensor[1].GetTensorMutableData(); 483 | std::vector mask(pdata, pdata + _segChannels * _segWidth * _segHeight); 484 | int _seg_params[5] = {_segChannels, _segWidth, _segHeight, imgSize.at(0), imgSize.at(1) }; 485 | cv::Mat mask_protos = cv::Mat(mask); 486 | GetMask(_seg_params, rectConfidenceThreshold, mask_proposals, mask_protos, params, iImg.size(), oResult); 487 | } 488 | } 489 | 490 | #ifdef benchmark 491 | auto starttime_4 = std::chrono::high_resolution_clock::now(); 492 | 493 | double pre_process_time = std::chrono::duration_cast(starttime_2 - starttime_1).count(); 494 | double process_time = std::chrono::duration_cast(starttime_3 - starttime_2).count(); 495 | double post_process_time = std::chrono::duration_cast(starttime_4 - starttime_3).count(); 496 | double total_time = pre_process_time + process_time + post_process_time; 497 | if (cudaEnable) { 498 | std::cout << "[YOLO_V8(CUDA)]: 前处理 " << pre_process_time << " ms, 推理 " << process_time 499 | << " ms, 后处理 " << post_process_time << " ms. 总共耗时 " << total_time << " ms." << std::endl; 500 | } 501 | else { 502 | std::cout << "[YOLO_V8(CPU)]: 前处理 " << pre_process_time << " ms, 推理 " << process_time 503 | << " ms, 后处理 " << post_process_time << " ms. 总共耗时 " << total_time << " ms." 
<< std::endl; 504 | } 505 | #endif 506 | break; 507 | } 508 | case YOLO_CLS_V8: 509 | { 510 | cv::Mat rawData; 511 | rawData = cv::Mat(1, this->classes.size(), CV_32F, output); 512 | 513 | float *data = (float *) rawData.data; 514 | 515 | DL_RESULT result; 516 | for (int i = 0; i < this->classes.size(); i++) 517 | { 518 | result.classId = i; 519 | result.confidence = data[i]; 520 | oResult.push_back(result); 521 | } 522 | 523 | #ifdef benchmark 524 | auto starttime_4 = std::chrono::high_resolution_clock::now(); 525 | double pre_process_time = std::chrono::duration_cast(starttime_2 - starttime_1).count(); 526 | double process_time = std::chrono::duration_cast(starttime_3 - starttime_2).count(); 527 | double post_process_time = std::chrono::duration_cast(starttime_4 - starttime_3).count(); 528 | double total_time = pre_process_time + process_time + post_process_time; 529 | if (cudaEnable) { 530 | std::cout << "[YOLO_V8(CUDA)]: 前处理 " << pre_process_time << " ms, 推理 " << process_time 531 | << " ms, 后处理 " << post_process_time << " ms. 总共耗时 " << total_time << " ms." << std::endl; 532 | } 533 | else { 534 | std::cout << "[YOLO_V8(CPU)]: 前处理 " << pre_process_time << " ms, 推理 " << process_time 535 | << " ms, 后处理 " << post_process_time << " ms. 总共耗时 " << total_time << " ms." << std::endl; 536 | } 537 | #endif 538 | break; 539 | } 540 | default: 541 | std::cout << "[YOLO_V8]: " << "不支持的模型类型." 
<< std::endl; 542 | } 543 | return RET_OK; 544 | 545 | } 546 | 547 | 548 | char* YOLO_V8::WarmUpSession() { 549 | cv::Mat iImg = cv::Mat(cv::Size(imgSize.at(0), imgSize.at(1)), CV_8UC3); 550 | cv::Mat processedImg; 551 | cv::Vec4d params; 552 | LetterBox(iImg, processedImg, params, cv::Size(imgSize.at(1), imgSize.at(0))); 553 | if (modelType < 4) { 554 | float* blob = new float[iImg.total() * 3]; 555 | BlobFromImage(processedImg, blob); 556 | std::vector YOLO_input_node_dims = { 1, 3, imgSize.at(0), imgSize.at(1) }; 557 | Ort::Value input_tensor = Ort::Value::CreateTensor( 558 | Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), blob, 3 * imgSize.at(0) * imgSize.at(1), 559 | YOLO_input_node_dims.data(), YOLO_input_node_dims.size()); 560 | auto output_tensors = session->Run(options, inputNodeNames.data(), &input_tensor, 1, outputNodeNames.data(), outputNodeNames.size()); 561 | delete[] blob; 562 | } 563 | else { 564 | #ifdef USE_CUDA 565 | half* blob = new half[iImg.total() * 3]; 566 | BlobFromImage(processedImg, blob); 567 | std::vector YOLO_input_node_dims = { 1,3,imgSize.at(0),imgSize.at(1) }; 568 | Ort::Value input_tensor = Ort::Value::CreateTensor(Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), blob, 3 * imgSize.at(0) * imgSize.at(1), YOLO_input_node_dims.data(), YOLO_input_node_dims.size()); 569 | auto output_tensors = session->Run(options, inputNodeNames.data(), &input_tensor, 1, outputNodeNames.data(), outputNodeNames.size()); 570 | delete[] blob; 571 | #endif 572 | } 573 | return RET_OK; 574 | } 575 | 576 | 577 | int YOLO_V8::ReadClassNames(const std::string& txtPath, std::vector& classNames) { 578 | std::ifstream file(txtPath); 579 | if (!file.is_open()) { 580 | std::cerr << "Failed to open TXT file" << std::endl; 581 | return 1; 582 | } 583 | 584 | std::string line; 585 | while (std::getline(file, line)) { 586 | 587 | line.erase(line.begin(), std::find_if(line.begin(), line.end(), [](unsigned char ch) { 588 | return 
!std::isspace(ch); 589 | })); 590 | line.erase(std::find_if(line.rbegin(), line.rend(), [](unsigned char ch) { 591 | return !std::isspace(ch); 592 | }).base(), line.end()); 593 | 594 | if (!line.empty()) { 595 | classNames.push_back(line); 596 | } 597 | } 598 | 599 | file.close(); 600 | return 0; 601 | } -------------------------------------------------------------------------------- /inference.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #define RET_OK nullptr 3 | #ifdef _WIN32 4 | #include 5 | #include 6 | #include 7 | #endif 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include "onnxruntime_cxx_api.h" 13 | #include 14 | 15 | #ifdef USE_CUDA 16 | #include 17 | #endif 18 | 19 | enum MODEL_TYPE 20 | { 21 | YOLO_DET_SEG_V8 = 1, 22 | YOLO_CLS_V8 = 2, 23 | }; 24 | 25 | typedef struct _DL_INIT_PARAM 26 | { 27 | std::string modelPath; 28 | MODEL_TYPE modelType; 29 | std::vector imgSize; 30 | float rectConfidenceThreshold = 0.6f; 31 | float iouThreshold = 0.5f; 32 | bool cudaEnable = false; 33 | int logSeverityLevel = 3; 34 | int intraOpNumThreads = 1; 35 | } DL_INIT_PARAM; 36 | 37 | typedef struct _DL_RESULT 38 | { 39 | int classId; 40 | std::string className; 41 | float confidence; 42 | cv::Rect box; 43 | std::vector> contours; // 分割的轮廓点 44 | cv::Scalar color; 45 | } DL_RESULT; 46 | 47 | 48 | 49 | class YOLO_V8 50 | { 51 | public: 52 | YOLO_V8(); 53 | ~YOLO_V8(); 54 | public: 55 | char* CreateSession(DL_INIT_PARAM& iParams); 56 | char* RunSession(cv::Mat& iImg, std::vector& oResult); 57 | char* WarmUpSession(); 58 | template 59 | char* TensorProcess(std::chrono::_V2::system_clock::time_point& starttime_1, cv::Vec4d& params, cv::Mat& iImg, N& blob, std::vector& inputNodeDims, 60 | std::vector& oResult); 61 | char* PreProcess(cv::Mat& iImg, std::vector iImgSize, cv::Mat& oImg); 62 | int ReadClassNames(const std::string& txtPath, std::vector& classNames); 63 | std::vector classes{}; 64 | int 
classNums = 80; 65 | MODEL_TYPE modelType; 66 | std::vector imgSize; 67 | std::vector Inference(const std::string& imagePath,const std::string& txtPath); 68 | private: 69 | bool cudaEnable; 70 | Ort::Env env; 71 | Ort::Session* session; 72 | Ort::RunOptions options; 73 | bool runSegmentation = false; 74 | std::vector inputNodeNames; 75 | std::vector outputNodeNames; 76 | std::vector input_names_ptr; 77 | std::vector output_names_ptr; 78 | std::vector> inputShapes; 79 | std::vector> outputShapes; 80 | float rectConfidenceThreshold; 81 | float iouThreshold; 82 | // bool isDynamicInputShape{}; 83 | }; 84 | -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include "inference.h" 7 | #include 8 | namespace fs = std::filesystem; 9 | 10 | 11 | void test(const std::string& directoryPath) { 12 | DL_INIT_PARAM params; 13 | std::filesystem::path projectRoot = std::filesystem::current_path().parent_path(); 14 | std::string labelPath = projectRoot / "huachuan/class_names_list.txt"; 15 | params.modelPath = projectRoot / "huachuan/best.onnx"; 16 | params.modelType = YOLO_DET_SEG_V8; 17 | params.imgSize = { 312, 312 }; // VAlgo需要暴露推理时的图片尺寸 18 | params.rectConfidenceThreshold = 0.7; 19 | params.iouThreshold = 0.0001; 20 | params.cudaEnable = true; 21 | auto starttime_1 = std::chrono::high_resolution_clock::now(); 22 | 23 | std::unique_ptr yolo(new YOLO_V8); 24 | yolo->CreateSession(params); 25 | auto starttime_3 = std::chrono::high_resolution_clock::now(); 26 | auto duration_ms4 = std::chrono::duration_cast(starttime_3 - starttime_1).count(); 27 | std::cout << "[YOLO_V8]: 模型预热时间: " << duration_ms4 << "ms" << std::endl; 28 | 29 | for (const auto& entry : fs::directory_iterator(directoryPath)) { 30 | if (fs::is_regular_file(entry.path()) && entry.path().extension() == ".jpg") { 31 | std::string 
imagePath = entry.path().string(); 32 | std::string imageName = entry.path().filename().stem().string(); 33 | std::cout << "\n[YOLO_V8]: 正在推理图片: " << imageName << ".jpg" << std::endl; 34 | auto starttime_2 = std::chrono::high_resolution_clock::now(); 35 | auto results = yolo->Inference(imagePath, labelPath); 36 | auto starttime_4 = std::chrono::high_resolution_clock::now(); 37 | auto duration_ms3 = std::chrono::duration_cast(starttime_4 - starttime_2).count(); 38 | 39 | cv::Mat image = cv::imread(imagePath); 40 | if (image.empty()) { 41 | std::cerr << "[YOLO_V8]: Failed to load image" << std::endl; 42 | return; 43 | } 44 | 45 | if (params.modelType == YOLO_DET_SEG_V8) { 46 | 47 | for (const auto& result : results) { 48 | std::cout << "[YOLO_V8]: 类别: " << result.className 49 | << " , 置信度: " << result.confidence 50 | << std::endl; 51 | } 52 | int detections = results.size(); 53 | std::cout << "[YOLO_V8]: 检测数量: " << detections << std::endl; 54 | std::cout << "[YOLO_V8]: 总推理时间:" <> contours = detection.contours; 69 | cv::drawContours(image(box), contours, -1, cv::Scalar(0, 255, 0), 2); 70 | } 71 | } 72 | std::string outputDirectory = "/home/yibo/git_dir/yolov8_onnxruntime_cpp/det_seg_output/"; 73 | 74 | if (!fs::exists(outputDirectory)) 75 | fs::create_directory(outputDirectory); 76 | 77 | std::filesystem::path outputImagePath = outputDirectory + imageName + "_result.jpg"; 78 | 79 | cv::imwrite(outputImagePath.string(), image); 80 | } 81 | else { // YOLO_CLS_V8 82 | 83 | for (const auto& result : results) { 84 | std::cout << "[YOLO_V8]: 类别: " << result.className << ", 置信度: " << result.confidence << std::endl; 85 | std::string text = result.className + " " + std::to_string(result.confidence).substr(0, 4); 86 | cv::putText(image, text, cv::Point(10, 30), cv::FONT_HERSHEY_SIMPLEX, 1, cv::Scalar(0, 255, 0), 2); 87 | 88 | } 89 | std::string outputDirectory = "/home/yibo/git_dir/yolov8_onnxruntime_cpp/cls_output/"; 90 | if (!fs::exists(outputDirectory)) 91 | 
fs::create_directory(outputDirectory); 92 | 93 | std::filesystem::create_directory(outputDirectory); 94 | 95 | std::filesystem::path outputImagePath = outputDirectory + imageName + "_result.jpg"; 96 | cv::imwrite(outputImagePath.string(), image); 97 | } 98 | 99 | } 100 | } 101 | } 102 | 103 | 104 | int main() { 105 | std::string dir = "/home/yibo/git_dir/yolov8_onnxruntime_cpp/images/HandianLocation_HC"; 106 | test(dir); 107 | return 0; 108 | } 109 | 110 | 111 | -------------------------------------------------------------------------------- /scripts/build.sh: -------------------------------------------------------------------------------- 1 | rm -r ../build 2 | 3 | mkdir ../build 4 | 5 | cd ../build 6 | 7 | cmake .. 8 | 9 | make 10 | 11 | cd ../scripts 12 | 13 | sh run.sh -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | ../build/Yolov8OnnxRuntimeCPPInference 2 | 3 | chmod -R 777 /home/yibo/git_dir/yolov8_onnxruntime_cpp/cls_output 4 | chmod -R 777 /home/yibo/git_dir/yolov8_onnxruntime_cpp/det_seg_output 5 | chmod -R 777 /home/yibo/git_dir/yolov8_onnxruntime_cpp/output 6 | chmod -R 777 /home/yibo/git_dir/yolov8_onnxruntime_cpp/build 7 | -------------------------------------------------------------------------------- /scripts/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | parent_dir=$(dirname "$(pwd)") 4 | 5 | docker run \ 6 | --privileged \ 7 | --gpus all \ 8 | --rm \ 9 | -it \ 10 | -v $parent_dir:$parent_dir \ 11 | -w $parent_dir/scripts \ 12 | hub.micro-i.com.cn:9443/dad/tritonserver_dev:23.09-py3 bash 13 | --------------------------------------------------------------------------------