├── .gitignore
├── CMakeLists.txt
├── LICENSE
├── README.md
├── assets
│   ├── Bench_YOLO_V10.jpg
│   ├── Yolo_v10_cpp_tenosrrt.PNG
│   └── blue-linkedin-logo.png
├── include
│   └── YOLOv10.hpp
└── src
    ├── YOLOv10.cpp
    └── main.cpp


/.gitignore:
--------------------------------------------------------------------------------
Build/
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.10)

project(YOLOv10TRT)

set(CMAKE_CXX_STANDARD 14)

# Find OpenCV
find_package(OpenCV REQUIRED)
include_directories(${OpenCV_INCLUDE_DIRS})

# Find CUDA
find_package(CUDA REQUIRED)
include_directories(${CUDA_INCLUDE_DIRS})

# Set the path to TensorRT installation
set(TENSORRT_PATH "F:/Program Files/TensorRT-8.6.1.6") # Update this to the actual path

# Include TensorRT directories
include_directories(${TENSORRT_PATH}/include)

# Link TensorRT libraries
link_directories(${TENSORRT_PATH}/lib)

# Include directory for your project
include_directories(${CMAKE_SOURCE_DIR}/include)

# Define the source files
set(SOURCES
    src/main.cpp
    src/YOLOv10.cpp
)

# Add the executable target
add_executable(YOLOv10Project ${SOURCES})

# Link libraries to the target
target_link_libraries(YOLOv10Project
    ${OpenCV_LIBS}
    ${CUDA_LIBRARIES}
    nvinfer
    nvonnxparser
)

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Custom License Agreement

1. License Grant

You are hereby granted a non-exclusive, non-transferable license to use, reproduce, and distribute the code (hereinafter referred to as "the Software") under the following conditions:

2. Conditions of Use

Non-Commercial Use: You may use the Software for personal, educational, or non-commercial purposes without any additional permissions.
Commercial Use: Any commercial use of the Software, including but not limited to selling, licensing, or using it in a commercial product, requires prior written permission from the original developer.
3. Contact Requirement

If you wish to use the Software for commercial purposes, you must contact the original developer at [https://www.linkedin.com/in/hamdi-boukamcha/] to obtain a commercial license.
The terms of any commercial license will be mutually agreed upon and may involve a licensing fee.
4. Attribution

Regardless of whether you are using the Software for commercial or non-commercial purposes, you must provide appropriate credit to the original developer in any distributions or products that use the Software.
5. Disclaimer of Warranty

The Software is provided "as is," without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and non-infringement. In no event shall the original developer be liable for any claim, damages, or other liability, whether in an action of contract, tort, or otherwise, arising from, out of, or in connection with the Software or the use or other dealings in the Software.
6. Governing Law

This License Agreement shall be governed by and construed in accordance with the laws of France.
By using the Software, you agree to abide by the terms outlined in this License Agreement.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# YOLO V10 C++ TensorRT

![Inference Time of Yolo V10 Models](assets/Yolo_v10_cpp_tenosrrt.PNG)

## 🌐 Overview

The **YOLOv10 C++ TensorRT Project** is a high-performance object detection solution implemented in **C++** and accelerated with **NVIDIA TensorRT**. It converts YOLOv10 ONNX models into TensorRT engines (built with FP16 enabled) and uses those engines to run fast object detection on images and videos.

![Inference Time of Yolo V10 Models](assets/Bench_YOLO_V10.jpg)


## 📢 Updates

### Key Features:

- **Model Conversion**: Convert ONNX models to TensorRT engine files to accelerate inference.
- **Inference on Videos**: Efficiently perform object detection on video files.
- **Inference on Images**: Execute object detection on individual image files.

By combining YOLOv10 with TensorRT’s optimizations, this project provides a fast C++ pipeline for real-time object detection.
## 📑 Table of Contents

- [Project Structure](#project-structure)
- [Dependencies](#dependencies)
- [Installation](#installation)
- [Usage](#usage)
  - [Convert ONNX Model to TensorRT Engine](#convert-onnx-model-to-tensorrt-engine)
  - [Run Inference on Video](#run-inference-on-video)
  - [Run Inference on Image](#run-inference-on-image)
- [Configuration](#configuration)
- [Troubleshooting](#troubleshooting)

## 🏗️ Project Structure
YOLOv10-TensorRT/

│── include/

│   └── YOLOv10.hpp

│── src/

│   ├── main.cpp

│   └── YOLOv10.cpp

│── CMakeLists.txt

└── README.md

## 📦 Dependencies

- **OpenCV V4.10.0**: For image and video processing.
- **CUDA V11.7**: For GPU acceleration.
- **TensorRT V8.6.1.6**: For optimized inference with YOLOv10.
- **cuDNN V9.2.1**: For accelerating deep learning training and inference on NVIDIA GPUs.

## 💾 Installation

### 1. Install Dependencies

- **OpenCV**: Follow the instructions on the [OpenCV official website](https://opencv.org/) to install OpenCV.
- **CUDA & cuDNN**: Install CUDA & cuDNN from the [NVIDIA website](https://developer.nvidia.com/cuda-toolkit).
- **TensorRT**: Download and install TensorRT from the [NVIDIA Developer website](https://developer.nvidia.com/tensorrt).

### 2. Clone and Build the Repository


git clone https://github.com/hamdiboukamcha/yolov10-tensorrt.git

cd yolov10-tensorrt/Yolov10-TensorRT

mkdir build
cd build
cmake ..
cmake --build .

## 🚀 Usage

### Convert ONNX Model to TensorRT Engine

To convert an ONNX model to a TensorRT engine file, use the following command:

./YOLOv10Project convert path_to_your_model.onnx path_to_your_engine.engine

path_to_your_model.onnx: Path to the ONNX model file.

path_to_your_engine.engine: Required by the command-line parser but not used in this mode. The engine is saved next to the ONNX file, with the same base name and a `.engine` extension (for example, `yolov10n.onnx` produces `yolov10n.engine`).

### Run Inference on Video
To run inference on a video, use the following command:

./YOLOv10Project infer_video path_to_your_video.mp4 path_to_your_engine.engine

path_to_your_video.mp4: Path to the input video file.

path_to_your_engine.engine: Path to the TensorRT engine file.

Each annotated frame is displayed in an OpenCV window titled `RESULT`.

### Run Inference on Image
To run inference on an image, use the following command:

./YOLOv10Project infer_image path_to_your_image.jpg path_to_your_engine.engine

path_to_your_image.jpg: Path to the input image file.

path_to_your_engine.engine: Path to the TensorRT engine file.

The annotated image is displayed in an OpenCV window titled `RESULT`. Press any key to close it.

## ⚙️ Configuration

### CMake Configuration
In CMakeLists.txt, set the TensorRT installation path. OpenCV and CUDA are located with `find_package`. If OpenCV is installed in a non-default location, point CMake at it, for example with `-DOpenCV_DIR=<directory containing OpenCVConfig.cmake>`.

#### Set the path to TensorRT installation

set(TENSORRT_PATH "path/to/TensorRT") # Update this to the actual path

Ensure that the path points to the directory where TensorRT is installed (the directory that contains its `include` and `lib` folders).

### Troubleshooting
**Cannot find nvinfer.lib**: Ensure that TensorRT is correctly installed and that nvinfer.lib is in `${TENSORRT_PATH}/lib`. Update `TENSORRT_PATH` in CMakeLists.txt if necessary.

**Linker errors**: Verify that all dependencies (OpenCV, CUDA, TensorRT) are correctly installed and that their paths are correctly set in CMakeLists.txt.

**Run-time errors**: Ensure that your system has an NVIDIA driver that supports your CUDA version. The TensorRT runtime libraries must also be found at run time. Add the directory that contains them (the TensorRT DLLs on Windows) to your system PATH, or to LD_LIBRARY_PATH on Linux.

**Missing or wrong detections**: The inference code assumes an engine with a single 3 x 640 x 640 float input (binding 0) and a 300 x 6 float output (binding 1). It keeps only detections above a fixed confidence threshold of 0.2 (see `postProcess` in src/YOLOv10.cpp).
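The `convert` mode ignores the engine path given on the command line. Instead, `YOLOv10::convertOnnxToEngine` builds the output path from the ONNX path. The following standalone sketch restates that rule so you can check where the engine will be written. `enginePathFor` is a hypothetical helper and is not part of this project. It splits on whichever of `/` or `\` comes last, which matches the project's logic for paths that use a single separator style.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper (not part of this project): restates the output-path rule in
// YOLOv10::convertOnnxToEngine. Keep the directory, drop the extension, append ".engine".
std::string enginePathFor(const std::string& onnxPath) {
    assert(!onnxPath.empty());  // An empty path has no file name to derive from

    // Index just past the last '/' or '\\' (0 when the path has no directory part)
    const std::string::size_type sep = onnxPath.find_last_of("/\\");
    const std::string::size_type nameStart = (sep == std::string::npos) ? 0 : sep + 1;

    const std::string dir = onnxPath.substr(0, nameStart);
    const std::string name = onnxPath.substr(nameStart);

    // Cut the file name at its last '.'; rfind returns npos when there is no '.',
    // in which case substr keeps the whole name
    const std::string stem = name.substr(0, name.rfind('.'));
    return dir + stem + ".engine";
}
```

For example, `enginePathFor("models/yolov10n.onnx")` returns `"models/yolov10n.engine"`.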

## 📞 Contact

For advanced inquiries, feel free to contact me on LinkedIn: https://www.linkedin.com/in/hamdi-boukamcha/

## 📜 Citation

If you use this code in your research, please cite the repository as follows:

@misc{boukamcha2024yolov10,
  author = {Hamdi Boukamcha},
  title = {Yolo-V10-cpp-TensorRT},
  year = {2024},
  publisher = {GitHub},
  howpublished = {\url{https://github.com/hamdiboukamcha/Yolo-V10-cpp-TensorRT}},
}

--------------------------------------------------------------------------------
/assets/Bench_YOLO_V10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/Yolo-V10-cpp-TensorRT/00c974e684a442b540c4d0491fcd29b65ba4c079/assets/Bench_YOLO_V10.jpg
--------------------------------------------------------------------------------
/assets/Yolo_v10_cpp_tenosrrt.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/Yolo-V10-cpp-TensorRT/00c974e684a442b540c4d0491fcd29b65ba4c079/assets/Yolo_v10_cpp_tenosrrt.PNG
--------------------------------------------------------------------------------
/assets/blue-linkedin-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/Yolo-V10-cpp-TensorRT/00c974e684a442b540c4d0491fcd29b65ba4c079/assets/blue-linkedin-logo.png
--------------------------------------------------------------------------------
/include/YOLOv10.hpp:
--------------------------------------------------------------------------------
#ifndef YOLOV10_HPP
#define YOLOV10_HPP

#include "opencv2/opencv.hpp"
#include "cuda_runtime_api.h"  // CUDA runtime API: cudaMalloc, cudaStream_t, cudaMemcpyAsync, ...
#include "NvInfer.h"
#include "NvOnnxParser.h"
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

/**
 * @brief Logger class for TensorRT
 *
 * This class inherits from nvinfer1::ILogger and provides a logging mechanism for TensorRT.
 */
class Logger : public nvinfer1::ILogger {
public:
    /**
     * @brief Prints messages of severity kWARNING or more severe (kERROR, kINTERNAL_ERROR).
     *
     * @param severity The severity level of the message.
     * @param msg The message to log.
     */
    void log(Severity severity, const char* msg) noexcept override {
        if (severity <= Severity::kWARNING)
            std::cout << msg << std::endl;
    }
};

/**
 * @brief Structure to hold detection results.
 *
 * This structure holds the bounding box, confidence score, and label of a detection result.
 */
struct DetResult {
    cv::Rect bbox;  ///< Bounding box of the detected object, in original-image pixels.
    float conf;     ///< Confidence score of the detection.
    int label;      ///< Class ID of the detected object.

    /**
     * @brief Constructs a DetResult object.
     *
     * @param bbox Bounding box of the detected object.
     * @param conf Confidence score of the detection.
     * @param label Class ID of the detected object.
     */
    DetResult(cv::Rect bbox, float conf, int label) : bbox(bbox), conf(conf), label(label) {}
};

/**
 * @brief Class for YOLOv10 object detection.
 *
 * This class provides methods for preprocessing input images, postprocessing detection results,
 * drawing bounding boxes, creating execution contexts, running inference on videos and images,
 * and converting ONNX models to TensorRT engines.
 */
class YOLOv10 {
public:
    /**
     * @brief Constructs a YOLOv10 object.
     */
    YOLOv10();

    /**
     * @brief Destructs a YOLOv10 object.
     */
    ~YOLOv10();

    /**
     * @brief Preprocesses an input image into the network input layout.
     *
     * The image is converted to RGB, pasted at the top-left of a black square canvas,
     * resized to length x length, scaled to [0, 1], and written in planar (CHW) order.
     *
     * @param img Pointer to the input BGR image.
     * @param length Side length of the square network input (640 in this project).
     * @param factor Output: ratio of the longest image side to length, used to map boxes back.
     * @param data Output buffer for 3 * length * length floats.
     */
    void preProcess(cv::Mat* img, int length, float* factor, std::vector<float>& data);

    /**
     * @brief Postprocesses the raw network output.
     *
     * @param result Pointer to the raw output: outputLength rows of [x1, y1, x2, y2, score, class_id],
     *               with box corners in network-input pixels.
     * @param factor Scaling factor returned by preProcess.
     * @param outputLength Number of detection rows (300 in this project).
     * @return std::vector<DetResult> Detections with a score above 0.2, in original-image pixels.
     */
    std::vector<DetResult> postProcess(float* result, float factor, int outputLength);

    /**
     * @brief Draws bounding boxes and "label-confidence" text on an image.
     *
     * @param img Reference to the image.
     * @param res Vector of detection results.
     */
    void drawBbox(cv::Mat& img, std::vector<DetResult>& res);

    /**
     * @brief Deserializes a TensorRT engine file and creates an execution context.
     *
     * @param modelPath Path to the serialized TensorRT engine file.
     * @return std::shared_ptr<nvinfer1::IExecutionContext> Shared pointer to the execution context,
     *         or nullptr on failure.
     */
    std::shared_ptr<nvinfer1::IExecutionContext> createExecutionContext(const std::string& modelPath);

    /**
     * @brief Performs inference on a video file and displays the annotated frames.
     *
     * @param videoPath Path to the video file.
     * @param enginePath Path to the TensorRT engine file.
     */
    void inferVideo(const std::string& videoPath, const std::string& enginePath);

    /**
     * @brief Performs inference on an image and displays the annotated result.
     *
     * @param imagePath Path to the image file.
     * @param enginePath Path to the TensorRT engine file.
     */
    void inferImage(const std::string& imagePath, const std::string& enginePath);

    /**
     * @brief Converts an ONNX model to a TensorRT engine with FP16 enabled.
     *
     * The engine is written next to the ONNX file, with the same base name and a ".engine" extension.
     *
     * @param onnxFile Path to the ONNX file.
     * @param memorySize Maximum builder workspace size, in MiB.
     */
    void convertOnnxToEngine(const std::string& onnxFile, int memorySize);

private:
    Logger logger; ///< Logger instance for TensorRT.
};

#endif // YOLOV10_HPP

--------------------------------------------------------------------------------
/src/YOLOv10.cpp:
--------------------------------------------------------------------------------
#include "YOLOv10.hpp"

YOLOv10::YOLOv10() {}

YOLOv10::~YOLOv10() {}

void YOLOv10::preProcess(cv::Mat* img, int length, float* factor, std::vector<float>& data) {
    // Create a new cv::Mat object for storing the image after conversion
    cv::Mat mat;

    // Get the dimensions of the input image
    int rh = img->rows;  // Height of the input image
    int rw = img->cols;  // Width of the input image

    // Convert the input image from BGR to RGB color space
    cv::cvtColor(*img, mat, cv::COLOR_BGR2RGB);

    // Determine the size of the new square image (largest dimension of the input image)
    int maxImageLength = rw > rh ? rw : rh;

    // Create a black (zero-filled) square canvas of size maxImageLength x maxImageLength
    cv::Mat maxImage = cv::Mat::zeros(maxImageLength, maxImageLength, CV_8UC3);

    // Define a Region of Interest (ROI) at the top-left corner that covers the original image
    cv::Rect roi(0, 0, rw, rh);

    // Copy the original image into the ROI. Padding only appears on the right or bottom,
    // so detections can later be mapped back with a single scale factor and no offset.
    mat.copyTo(cv::Mat(maxImage, roi));

    // Resize the square image to the network input size (length x length)
    cv::Mat resizeImg;
    cv::resize(maxImage, resizeImg, cv::Size(length, length), 0.0, 0.0, cv::INTER_LINEAR);

    // Scaling factor from network-input pixels back to original-image pixels
    *factor = static_cast<float>(maxImageLength) / static_cast<float>(length);

    // Convert the resized image to floating-point format with values in range [0, 1]
    resizeImg.convertTo(resizeImg, CV_32FC3, 1.0 / 255.0);

    // Get the height, width, and number of channels of the resized image
    rh = resizeImg.rows;
    rw = resizeImg.cols;
    int rc = resizeImg.channels();

    // Make sure the output buffer can hold the whole planar (CHW) image
    data.resize(static_cast<size_t>(rc) * rh * rw);

    // Write each channel into its own plane of 'data' (interleaved HWC -> planar CHW)
    for (int i = 0; i < rc; ++i) {
        cv::extractChannel(resizeImg, cv::Mat(rh, rw, CV_32FC1, data.data() + i * rh * rw), i);
    }
}

std::vector<DetResult> YOLOv10::postProcess(float* result, float factor, int outputLength) {
    // Vector to store the final detection results
    std::vector<DetResult> re;

    // Each output row holds 6 values: [x1, y1, x2, y2, score, class_id],
    // with box corners in network-input pixels
    for (int i = 0; i < outputLength; i++) {
        // Starting index of the current row in the 'result' array
        const int s = 6 * i;
        const float score = result[s + 4];

        // Keep detections above the fixed confidence threshold of 0.2.
        // YOLOv10 is NMS-free, so no non-maximum suppression step follows.
        if (score > 0.2f) {
            const float x1 = result[s + 0];  // Top-left x
            const float y1 = result[s + 1];  // Top-left y
            const float x2 = result[s + 2];  // Bottom-right x
            const float y2 = result[s + 3];  // Bottom-right y

            // Map the box back to original-image pixels using the preprocessing scale factor
            const int x = static_cast<int>(x1 * factor);
            const int y = static_cast<int>(y1 * factor);
            const int width = static_cast<int>((x2 - x1) * factor);
            const int height = static_cast<int>((y2 - y1) * factor);

            // The class ID is stored at position s + 5
            re.emplace_back(cv::Rect(x, y, width, height), score, static_cast<int>(result[s + 5]));
        }
    }

    return re;
}

void YOLOv10::drawBbox(cv::Mat& img, std::vector<DetResult>& res) {
    // Iterate through each result in the 'res' vector
    for (size_t j = 0; j < res.size(); j++) {
        // Draw a magenta rectangle around the detected object
        cv::rectangle(img, res[j].bbox, cv::Scalar(255, 0, 255), 2);
        // Add the class ID and confidence score just above the top-left corner of the box
        cv::putText(
            img,                                                               // The image on which to draw
            std::to_string(res[j].label) + "-" + std::to_string(res[j].conf), // Text: label and confidence score
            cv::Point(res[j].bbox.x, res[j].bbox.y - 1),                       // Position of the text
            cv::FONT_HERSHEY_PLAIN,                                            // Font type
            1.2,                                                               // Font scale
            cv::Scalar(0, 0, 255),                                             // Text color (red in BGR)
            2                                                                  // Thickness of the text
        );
    }
}

std::shared_ptr<nvinfer1::IExecutionContext> YOLOv10::createExecutionContext(const std::string& modelPath) {
    // Open the engine file in binary mode
    std::ifstream filePtr(modelPath, std::ios::binary);

    // Check if the file was opened successfully
    if (!filePtr.good()) {
        std::cerr << "File cannot be opened, please check the file: " << modelPath << std::endl;
        return nullptr;
    }

    // Determine the size of the file
    filePtr.seekg(0, filePtr.end);
    size_t size = static_cast<size_t>(filePtr.tellg());
    filePtr.seekg(0, filePtr.beg);

    // Read the serialized engine into a host buffer that frees itself
    std::vector<char> modelStream(size);
    filePtr.read(modelStream.data(), size);
    filePtr.close();

    // Create an instance of nvinfer1::IRuntime
    nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
    if (!runtime) {
        std::cerr << "Failed to create runtime" << std::endl;
        return nullptr;
    }

    // Deserialize the engine
    nvinfer1::ICudaEngine* engine = runtime->deserializeCudaEngine(modelStream.data(), size);
    if (!engine) {
        std::cerr << "Failed to create engine" << std::endl;
        runtime->destroy();
        return nullptr;
    }

    // Create an execution context from the engine
    nvinfer1::IExecutionContext* context = engine->createExecutionContext();
    if (!context) {
        std::cerr << "Failed to create execution context" << std::endl;
        engine->destroy();
        runtime->destroy();
        return nullptr;
    }

    // The engine must outlive its execution context, and the runtime is kept alive until
    // the engine is gone, so the deleter releases all three in reverse order of creation.
    return std::shared_ptr<nvinfer1::IExecutionContext>(context,
        [runtime, engine](nvinfer1::IExecutionContext* ctx) {
            ctx->destroy();
            engine->destroy();
            runtime->destroy();
        });
}

void YOLOv10::inferVideo(const std::string& videoPath, const std::string& enginePath) {
    // Create the execution context from the engine file
    std::shared_ptr<nvinfer1::IExecutionContext> context = createExecutionContext(enginePath);
    if (!context) {
        std::cerr << "ERROR: Could not create the TensorRT execution context." << std::endl;
        return;
    }

    // Open the video file using OpenCV's VideoCapture
    cv::VideoCapture capture(videoPath);

    // Check if the video file was opened successfully
    if (!capture.isOpened()) {
        std::cerr << "ERROR: Could not open video file." << std::endl;
        return;
    }

    // Create a CUDA stream for asynchronous operations
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Allocate device memory for input (3 x 640 x 640) and output (300 detections x 6 values)
    void* inputSrcDevice;
    void* outputSrcDevice;
    cudaMalloc(&inputSrcDevice, 3 * 640 * 640 * sizeof(float));
    cudaMalloc(&outputSrcDevice, 1 * 300 * 6 * sizeof(float));

    // Host (CPU) buffers for input and output data
    std::vector<float> output_data(300 * 6);
    std::vector<float> inputData(640 * 640 * 3);

    // Process video frames in a loop
    while (true) {
        cv::Mat frame;

        // Read the next frame; stop when no more frames are available
        if (!capture.read(frame)) {
            break;
        }

        // Preprocess the frame (letterbox, resize, normalize) into inputData
        float factor = 0;
        preProcess(&frame, 640, &factor, inputData);

        // Copy the preprocessed input data from host to device memory
        cudaMemcpyAsync(inputSrcDevice, inputData.data(), 3 * 640 * 640 * sizeof(float),
                        cudaMemcpyHostToDevice, stream);

        // Bindings for TensorRT inference: binding 0 is the input, binding 1 the output
        void* bindings[] = { inputSrcDevice, outputSrcDevice };

        // Perform inference using the TensorRT execution context
        context->enqueueV2(bindings, stream, nullptr);

        // Copy the output data from device to host memory
        cudaMemcpyAsync(output_data.data(), outputSrcDevice, 300 * 6 * sizeof(float),
                        cudaMemcpyDeviceToHost, stream);

        // Wait for the CUDA stream operations to complete
        cudaStreamSynchronize(stream);

        // Post-process the output data to extract detection results
        std::vector<DetResult> result = postProcess(output_data.data(), factor, 300);

        // Draw bounding boxes and annotations on the frame, then display it
        drawBbox(frame, result);
        cv::imshow("RESULT", frame);

        // Wait for 10 milliseconds or until a key is pressed
        cv::waitKey(10);
    }

    // Destroy all OpenCV windows after processing is complete
    cv::destroyAllWindows();

    // Free the allocated CUDA memory and destroy the CUDA stream
    cudaFree(inputSrcDevice);
    cudaFree(outputSrcDevice);
    cudaStreamDestroy(stream);
}

void YOLOv10::inferImage(const std::string& imagePath, const std::string& enginePath) {
    // Create an execution context from the TensorRT engine file
    std::shared_ptr<nvinfer1::IExecutionContext> context = createExecutionContext(enginePath);
    if (!context) {
        std::cerr << "ERROR: Could not create the TensorRT execution context." << std::endl;
        return;
    }

    // Load the image from the specified path using OpenCV
    cv::Mat img = cv::imread(imagePath);

    // Check if the image was loaded successfully
    if (img.empty()) {
        std::cerr << "ERROR: Could not open or find the image." << std::endl;
        return;
    }

    // Preprocess the image (letterbox, resize, normalize) into inputData
    float factor = 0;                              // Set by preProcess
    std::vector<float> inputData(640 * 640 * 3);   // Buffer for the input data
    preProcess(&img, 640, &factor, inputData);

    // Create a CUDA stream for asynchronous operations
    cudaStream_t stream;
    cudaStreamCreate(&stream);

    // Allocate device memory for input and output data
    void* inputSrcDevice;
    void* outputSrcDevice;
    cudaMalloc(&inputSrcDevice, 3 * 640 * 640 * sizeof(float));  // Input buffer
    cudaMalloc(&outputSrcDevice, 1 * 300 * 6 * sizeof(float));   // Output buffer

    // Copy the preprocessed input data from host to device
    cudaMemcpyAsync(inputSrcDevice, inputData.data(), 3 * 640 * 640 * sizeof(float),
                    cudaMemcpyHostToDevice, stream);

    // Bindings for TensorRT inference: binding 0 is the input, binding 1 the output
    void* bindings[] = { inputSrcDevice, outputSrcDevice };

    // Execute the TensorRT inference
    context->enqueueV2(bindings, stream, nullptr);

    // Copy the output data back to the host
    std::vector<float> output_data(300 * 6);
    cudaMemcpyAsync(output_data.data(), outputSrcDevice, 300 * 6 * sizeof(float),
                    cudaMemcpyDeviceToHost, stream);

    // Synchronize the stream to ensure all operations are complete
    cudaStreamSynchronize(stream);

    // Post-process the output data to extract detection results
    std::vector<DetResult> result = postProcess(output_data.data(), factor, 300);

    // Draw bounding boxes on the image and display it
    drawBbox(img, result);
    cv::imshow("RESULT", img);
    cv::waitKey(0);  // Wait indefinitely for a key press

    // Free device memory and destroy the CUDA stream
    cudaFree(inputSrcDevice);
    cudaFree(outputSrcDevice);
    cudaStreamDestroy(stream);
}

void YOLOv10::convertOnnxToEngine(const std::string& onnxFile, int memorySize) {
    // Check if the ONNX file exists
    std::ifstream ifile(onnxFile);
    if (!ifile) {
        std::cerr << "Error: Could not open file " << onnxFile << std::endl;
        return;
    }

    // Split the ONNX path into directory and file name (last '\\' if present, otherwise last '/')
    std::string path(onnxFile);
    std::string::size_type iPos = (path.find_last_of('\\') + 1) == 0 ? path.find_last_of('/') + 1 : path.find_last_of('\\') + 1;
    std::string modelPath = path.substr(0, iPos);                        // Directory path
    std::string modelName = path.substr(iPos, path.length() - iPos);     // File name with extension
    std::string modelName_ = modelName.substr(0, modelName.rfind("."));  // File name without extension
    std::string engineFile = modelPath + modelName_ + ".engine";         // Output path for the engine

    // Create a TensorRT builder and a network with an explicit batch dimension
    nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
    const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    nvinfer1::INetworkDefinition* network = builder->createNetworkV2(explicitBatch);

    // Create an ONNX parser for the network
    nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, logger);

    // Parse the ONNX file; verbosity 2 corresponds to ILogger::Severity::kWARNING
    if (!parser->parseFromFile(onnxFile.c_str(), 2)) {
        std::cerr << "Error: Failed to parse ONNX model from file: " << onnxFile << std::endl;
        for (int i = 0; i < parser->getNbErrors(); ++i) {
            std::cerr << "Parser error: " << parser->getError(i)->desc() << std::endl;
        }
        parser->destroy();
        network->destroy();
        builder->destroy();
        return;
    }
    std::cout << "TensorRT loaded ONNX model successfully." << std::endl;

    // Create a builder configuration: workspace limit (memorySize is in MiB) and FP16 precision
    nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
    config->setMaxWorkspaceSize(static_cast<size_t>(memorySize) * 1024 * 1024);
    config->setFlag(nvinfer1::BuilderFlag::kFP16);

    // Build the TensorRT engine with the network and configuration
    nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
    if (!engine) {
        std::cerr << "Error: Failed to build the TensorRT engine." << std::endl;
    } else {
        // Serialize the engine and write it to the output file
        std::cout << "Trying to save engine file now..." << std::endl;
        std::ofstream filePtr(engineFile, std::ios::binary);
        if (!filePtr) {
            std::cerr << "Error: Could not open plan output file: " << engineFile << std::endl;
        } else {
            nvinfer1::IHostMemory* modelStream = engine->serialize();
            filePtr.write(reinterpret_cast<const char*>(modelStream->data()), modelStream->size());
            modelStream->destroy();
            std::cout << "Converted ONNX model to TensorRT engine model successfully: " << engineFile << std::endl;
        }
        engine->destroy();
    }

    // Clean up: objects created from the builder are released before the builder itself
    config->destroy();
    parser->destroy();
    network->destroy();
    builder->destroy();
}
--------------------------------------------------------------------------------
/src/main.cpp:
--------------------------------------------------------------------------------
#include "YOLOv10.hpp"
#include <iostream>
#include <string>

int main(int argc, char* argv[]) {
    // ANSI color codes for the usage message
    const std::string RED_COLOR = "\033[31m";
    const std::string YELLOW_COLOR = "\033[33m";
    const std::string RESET_COLOR = "\033[0m";

    // Ensure that the correct number of arguments are provided
    if (argc != 4) {
        std::cerr << RED_COLOR << "Usage: " << RESET_COLOR << argv[0] << " <mode> <input_path> <engine_path>" << std::endl;
        std::cerr << YELLOW_COLOR << "  <mode>        - Mode of operation: 'convert', 'infer_video', or 'infer_image'" << RESET_COLOR << std::endl;
        std::cerr << YELLOW_COLOR << "  <input_path>  - Path to the input video/image, or to the ONNX model in 'convert' mode" << RESET_COLOR << std::endl;
        std::cerr << YELLOW_COLOR << "  <engine_path> - Path to the TensorRT engine file (not used in 'convert' mode: the engine is written next to the ONNX model)" << RESET_COLOR << std::endl;
        return 1;
    }

    const std::string mode = argv[1];
    const std::string inputPath = argv[2];
    const std::string enginePath = argv[3];

    YOLOv10 yolov10;

    if (mode == "convert") {
        yolov10.convertOnnxToEngine(inputPath, 50);  // inputPath is the ONNX model; 50 MiB workspace
    } else if (mode == "infer_video") {
        yolov10.inferVideo(inputPath, enginePath);   // Perform inference on video
    } else if (mode == "infer_image") {
        yolov10.inferImage(inputPath, enginePath);   // Perform inference on image
    } else {
        std::cerr << "Invalid mode. Use 'convert' to convert ONNX model, 'infer_video' to perform inference on video, or 'infer_image' to perform inference on image." << std::endl;
        return 1;
    }

    return 0;
}
--------------------------------------------------------------------------------
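Below is a standalone sketch of how src/YOLOv10.cpp maps raw network output back to image coordinates. It assumes the layout that `postProcess` reads: each output row is `[x1, y1, x2, y2, score, class_id]`, measured in 640 x 640 input pixels. `preProcess` pastes the image at the top-left of a square canvas, so a single factor `max(width, height) / 640` undoes the resize and no offset is needed. `Box`, `Detection`, `letterboxFactor`, and `decodeDetections` are hypothetical names that stand in for `cv::Rect`, `DetResult`, and the project's methods. This lets the sketch build without OpenCV or TensorRT.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for cv::Rect and DetResult, so the sketch needs no OpenCV.
struct Box {
    int x, y, width, height;
};

struct Detection {
    Box box;
    float conf;
    int label;
};

// Same factor as YOLOv10::preProcess: longest image side divided by the network input size.
float letterboxFactor(int imageWidth, int imageHeight, int inputSize) {
    assert(inputSize > 0);
    const int longest = imageWidth > imageHeight ? imageWidth : imageHeight;
    return static_cast<float>(longest) / static_cast<float>(inputSize);
}

// Decodes rows of [x1, y1, x2, y2, score, class_id] given in input-pixel coordinates,
// keeping rows whose score exceeds the threshold, as YOLOv10::postProcess does.
std::vector<Detection> decodeDetections(const float* output, std::size_t numRows,
                                        float factor, float threshold) {
    std::vector<Detection> detections;
    for (std::size_t i = 0; i < numRows; ++i) {
        const float* row = output + 6 * i;
        if (row[4] <= threshold) {
            continue;  // Low-confidence row; YOLOv10 needs no NMS afterwards
        }
        // Scale the top-left corner and the corner-to-corner extent back to image pixels
        Box box{static_cast<int>(row[0] * factor),
                static_cast<int>(row[1] * factor),
                static_cast<int>((row[2] - row[0]) * factor),
                static_cast<int>((row[3] - row[1]) * factor)};
        detections.push_back(Detection{box, row[4], static_cast<int>(row[5])});
    }
    return detections;
}
```

With a 1280 x 720 frame the factor is 2. A row `[100, 50, 200, 150, 0.9, 3]` therefore becomes a box at (200, 100) of size 200 x 200 with label 3.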