634 |
635 | This program is free software: you can redistribute it and/or modify
636 | it under the terms of the GNU Affero General Public License as published
637 | by the Free Software Foundation, either version 3 of the License, or
638 | (at your option) any later version.
639 |
640 | This program is distributed in the hope that it will be useful,
641 | but WITHOUT ANY WARRANTY; without even the implied warranty of
642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643 | GNU Affero General Public License for more details.
644 |
645 | You should have received a copy of the GNU Affero General Public License
646 | along with this program. If not, see <https://www.gnu.org/licenses/>.
647 |
648 | Also add information on how to contact you by electronic and paper mail.
649 |
650 | If your software can interact with users remotely through a computer
651 | network, you should also make sure that it provides a way for users to
652 | get its source. For example, if your program is a web application, its
653 | interface could display a "Source" link that leads users to an archive
654 | of the code. There are many ways you could offer source, and different
655 | solutions will be better for different programs; see section 13 for the
656 | specific requirements.
657 |
658 | You should also get your employer (if you work as a programmer) or school,
659 | if any, to sign a "copyright disclaimer" for the program, if necessary.
660 | For more information on this, and how to apply and follow the GNU AGPL, see
661 | <https://www.gnu.org/licenses/>.
662 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # YOLOv12 TensorRT CPP
2 |
3 | 
4 | 
5 | 
6 | 
7 | 
8 | [Open in Colab](https://colab.research.google.com/drive/1kqo6J6jxTiGFjpCHrv6HA5InOIp07Awt?usp=sharing)
9 | 
10 |
11 | ## Overview
12 |
13 | This project is a high-performance **C++ implementation** for real-time object detection using **YOLOv12**. It leverages **TensorRT** for optimized inference and **CUDA** for accelerated processing, enabling efficient detection on both images and videos. Designed for maximum speed and accuracy, this implementation ensures seamless integration with YOLOv12 models, making it suitable for deployment in research, production, and real-time applications.
14 |
15 | **Google Colab Support**: To make GPU-based inference more accessible, a **fully configured Google Colab notebook** is provided. This notebook allows users to **run the entire project from start to finish** using a **Google Colab T4 GPU**, including compiling and executing C++ code directly in Colab. This is especially helpful for those who **struggle with local GPU availability**.
16 |
17 | **Note:** The Google Colab notebook is not meant for performance testing; inference there will be slow. It is intended for learning how to integrate C++, TensorRT, and CUDA.
18 |
19 | Figure: Comparison with popular methods in terms of latency-accuracy (left) and FLOPs-accuracy (right) trade-offs.
20 |
28 | ## Output
29 |
30 | Below are examples of the output for both image and video inference:
31 |
32 | **Image Output:** [image]
33 |
34 | **Video Output:** [image]
35 |
42 | ## Features
43 |
44 | - **TensorRT Integration**: Optimized deep learning inference using **NVIDIA TensorRT**, ensuring high-speed execution on **GPU**.
45 | - **Efficient Memory Management**: Uses **CUDA buffers** and **TensorRT engine caching** for improved performance.
46 | - **Real-Time Inference**: Supports **image** and **video** processing, allowing smooth detection across frames.
47 | - **Custom Preprocessing & Postprocessing**: Performs **image normalization and tensor conversion** in CUDA, then **decodes results** and applies non-maximum suppression on the CPU with OpenCV.
48 | - **High-Performance Video Processing**: Efficiently processes video streams using OpenCV while maintaining low-latency inference with TensorRT.
49 | - **Google Colab Support**: A **ready-to-use Google Colab Notebook** is provided to **run the project in Google Colab**, enabling easy setup and execution without requiring a local GPU.
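
The result decoding listed above can be sketched on the CPU without OpenCV or TensorRT. This is a minimal sketch, assuming the layout that `postprocess` in `src/YOLOv12.cpp` reads: a row-major `[4 + num_classes, num_detections]` float buffer whose first four rows are `cx, cy, w, h`. The names `Candidate` and `decodeOutput` are hypothetical and do not exist in the repository, and non-maximum suppression is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical result type for this sketch (the repository uses Detection with a cv::Rect).
struct Candidate {
    int x, y, width, height;  // top-left corner and size, in network-input pixels
    int class_id;
    float score;
};

// Decode a channel-major buffer: element (row r, detection i) lives at out[r * num_detections + i].
std::vector<Candidate> decodeOutput(const std::vector<float>& out,
                                    int num_classes, int num_detections,
                                    float conf_threshold) {
    assert(out.size() == static_cast<std::size_t>(4 + num_classes) * num_detections);
    std::vector<Candidate> result;
    for (int i = 0; i < num_detections; ++i) {
        // Find the best-scoring class for detection i.
        int best_class = 0;
        float best_score = out[4 * num_detections + i];
        for (int c = 1; c < num_classes; ++c) {
            const float s = out[(4 + c) * num_detections + i];
            if (s > best_score) { best_score = s; best_class = c; }
        }
        if (best_score <= conf_threshold) continue;  // keep only confident detections

        const float cx = out[0 * num_detections + i];
        const float cy = out[1 * num_detections + i];
        const float w  = out[2 * num_detections + i];
        const float h  = out[3 * num_detections + i];
        // Convert center/size to top-left/size; static_cast<int> truncates toward zero.
        result.push_back({static_cast<int>(cx - 0.5f * w), static_cast<int>(cy - 0.5f * h),
                          static_cast<int>(w), static_cast<int>(h), best_class, best_score});
    }
    return result;
}
```

The boxes returned here are still in network-input coordinates; the repository maps them back to the source image when drawing.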
50 |
51 | ## Requirements
52 |
53 | Before building the project, ensure that the following dependencies are installed on your system:
54 |
55 | - **C++ Compiler**: Compatible with **C++17** or higher.
56 | - **CMake**: Version **3.12** or higher.
57 | - **CUDA**: Version **12.4**.
58 | - **TensorRT**: Tested with **TensorRT 10.8.0** for high-performance inference.
59 | - **OpenCV**: Version **4.5.4** or higher for image and video processing.
60 |
61 | ## Installation And Usage
62 |
63 | ### 1- Generate ONNX models
64 | Generate the ONNX version of the YOLOv12 model. You can follow the export procedure described in the [YOLOv12 ONNX CPP](https://github.com/mohamedsamirx/YOLOv12-ONNX-CPP.git) repository.
65 |
66 | ### 2- Clone Repository
67 | Clone the repository to your local machine:
68 |
69 | ```bash
70 | git clone https://github.com/mohamedsamirx/YOLOv12-TensorRT-CPP.git
71 | cd YOLOv12-TensorRT-CPP
72 | ```
73 |
74 | ### 3- Build the C++ Code
75 | **Ensure that OpenCV and TensorRT are installed. Set the correct paths for these libraries in the** `CMakeLists.txt` **file.**
76 |
77 | ```bash
78 | mkdir build && cd build
79 | cmake ..
80 | cmake --build . --config Release
81 | ```
82 |
83 | ### 4- Create a TensorRT Engine
84 |
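85 | Convert the ONNX model to a TensorRT engine. The second argument is left empty because no inference runs in this step; the engine is written next to the ONNX file with the `.engine` extension.
86 | From the build directory, run: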
87 | ```bash
88 | ./yolov12-tensorrt yolo12n.onnx ""
89 | ```
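
How the program chooses between building and loading can be sketched as follows. In `src/YOLOv12.cpp`, the constructor treats any path containing `.onnx` as a model to build, and `saveEngine` derives the output path by replacing the extension with `.engine`. The helper names `isOnnxModel` and `engineFilePath` below are hypothetical and only restate that logic without TensorRT.

```cpp
#include <cassert>
#include <string>

// True when the path should be treated as an ONNX model (same substring test as the constructor).
bool isOnnxModel(const std::string& model_path) {
    return model_path.find(".onnx") != std::string::npos;
}

// Replace everything after the last '.' with "engine"; returns an empty string if there is no '.'.
std::string engineFilePath(const std::string& onnx_path) {
    assert(!onnx_path.empty());  // precondition of this sketch
    const std::size_t dot = onnx_path.find_last_of('.');
    if (dot == std::string::npos) return {};
    return onnx_path.substr(0, dot) + ".engine";
}
```

With these rules, `yolo12n.onnx` yields `yolo12n.engine`, which is the file name the following steps pass to the executable.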
90 |
91 | ### 5- Run Inference on an Image
92 |
93 | Perform object detection on an image:
94 |
95 | ```bash
96 | ./yolov12-tensorrt yolo12n.engine "zidane.jpg"
97 | ```
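
The second argument is classified by its file suffix in `main.cpp`: `jpg`, `jpeg`, and `png` are treated as images, and `mp4`, `avi`, `m4v`, `mpeg`, `mov`, `mkv`, and `webm` as videos. A minimal sketch of that check follows. `InputKind` and `classifyInput` are hypothetical names, and the lowercase conversion is an addition of this sketch; the repository compares suffixes case-sensitively and also accepts a directory of `.jpg` files, which is not covered here.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <string>

enum class InputKind { Image, Video, Unknown };

// Classify a path by the text after its last '.'.
InputKind classifyInput(const std::string& path) {
    const std::size_t dot = path.find_last_of('.');
    if (dot == std::string::npos) return InputKind::Unknown;

    std::string suffix = path.substr(dot + 1);
    // Lowercase the suffix so "ROAD.MP4" is handled like "road.mp4" (not done in main.cpp).
    std::transform(suffix.begin(), suffix.end(), suffix.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });

    static const char* const kImage[] = {"jpg", "jpeg", "png"};
    static const char* const kVideo[] = {"mp4", "avi", "m4v", "mpeg", "mov", "mkv", "webm"};
    for (const char* s : kImage) if (suffix == s) return InputKind::Image;
    for (const char* s : kVideo) if (suffix == s) return InputKind::Video;
    return InputKind::Unknown;
}
```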
98 |
99 | ### 6- Run Inference on a Video
100 |
101 | Perform object detection on a video:
102 |
103 | ```bash
104 | ./yolov12-tensorrt yolo12n.engine "road.mp4"
105 | ```
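
For each frame, `main.cpp` prints the time spent in `model.infer()` in milliseconds. The sketch below shows one way to take such a measurement. It swaps in `std::chrono::steady_clock` (a monotonic clock) for the `system_clock` used in the repository, and `measureMs` is a hypothetical helper. Since `infer()` only enqueues work on a CUDA stream, a measurement taken this way may not include the full GPU execution time unless the stream is synchronized inside the timed region.

```cpp
#include <cassert>
#include <chrono>
#include <utility>

// Run fn once and return the elapsed wall-clock time in milliseconds.
template <typename F>
double measureMs(F&& fn) {
    const auto start = std::chrono::steady_clock::now();
    std::forward<F>(fn)();
    const auto end = std::chrono::steady_clock::now();
    return std::chrono::duration<double, std::milli>(end - start).count();
}
```

A call such as `measureMs([&] { model.infer(); })` would take the place of the manual start/end bookkeeping.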
106 |
107 | ### License
108 | This project is licensed under the GNU Affero General Public License v3.0 (AGPL-3.0). See the [LICENSE](LICENSE) file for details.
109 |
110 | ### Acknowledgment
111 |
112 | - [https://github.com/spacewalk01/yolov11-tensorrt](https://github.com/spacewalk01/yolov11-tensorrt)
113 | - [https://github.com/sunsmarterjie/yolov12](https://github.com/sunsmarterjie/yolov12)
--------------------------------------------------------------------------------
/asset/000010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mohamedsamirx/YOLOv12-TensorRT-CPP/9013b40929e2b10c34cae3fed9f82c2744ae1ab9/asset/000010.png
--------------------------------------------------------------------------------
/asset/frame_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mohamedsamirx/YOLOv12-TensorRT-CPP/9013b40929e2b10c34cae3fed9f82c2744ae1ab9/asset/frame_0.jpg
--------------------------------------------------------------------------------
/asset/output_gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mohamedsamirx/YOLOv12-TensorRT-CPP/9013b40929e2b10c34cae3fed9f82c2744ae1ab9/asset/output_gif.gif
--------------------------------------------------------------------------------
/main.cpp:
--------------------------------------------------------------------------------
1 | #ifdef _WIN32
2 | #include <windows.h>
3 | #else
4 | #include <sys/stat.h>
5 | #include <unistd.h>
6 | #endif
7 |
8 | #include <iostream>
9 | #include <string>
10 | #include "YOLOv12.h"
11 |
12 |
13 | bool IsPathExist(const string& path) {
14 | #ifdef _WIN32
15 | DWORD fileAttributes = GetFileAttributesA(path.c_str());
16 | return (fileAttributes != INVALID_FILE_ATTRIBUTES);
17 | #else
18 | return (access(path.c_str(), F_OK) == 0);
19 | #endif
20 | }
21 |
22 | bool IsFile(const string& path) {
23 | if (!IsPathExist(path)) {
24 | printf("%s:%d %s not exist\n", __FILE__, __LINE__, path.c_str());
25 | return false;
26 | }
27 |
28 | #ifdef _WIN32
29 | DWORD fileAttributes = GetFileAttributesA(path.c_str());
30 | return ((fileAttributes != INVALID_FILE_ATTRIBUTES) && ((fileAttributes & FILE_ATTRIBUTE_DIRECTORY) == 0));
31 | #else
32 | struct stat buffer;
33 | return (stat(path.c_str(), &buffer) == 0 && S_ISREG(buffer.st_mode));
34 | #endif
35 | }
36 |
37 | /**
38 | * @brief Setting up Tensorrt logger
39 | */
40 | class Logger : public nvinfer1::ILogger {
41 | void log(Severity severity, const char* msg) noexcept override {
42 |         // Only print messages of severity kWARNING or more severe (lower enum values are more severe)
43 | if (severity <= Severity::kWARNING)
44 | std::cout << msg << std::endl;
45 | }
46 | }logger;
47 |
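48 | int main(int argc, char** argv){
49 |     // Validate the argument count before reading argv[1] and argv[2]
50 |     if (argc != 3) { cerr << "Usage: " << argv[0] << " <model.onnx|model.engine> <image|video|folder>" << endl; return 1; }
51 |     const string engine_file_path{ argv[1] };
52 |     const string path{ argv[2] };
53 |     vector<string> imagePathList;
54 |     bool isVideo{ false };
55 |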
56 | if (IsFile(path)){
57 | string suffix = path.substr(path.find_last_of('.') + 1);
58 | if (suffix == "jpg" || suffix == "jpeg" || suffix == "png"){
59 | imagePathList.push_back(path);
60 | }
61 | else if (suffix == "mp4" || suffix == "avi" || suffix == "m4v" || suffix == "mpeg" || suffix == "mov" || suffix == "mkv" || suffix == "webm"){
62 | isVideo = true;
63 | }
64 | else {
65 | printf("suffix %s is wrong !!!\n", suffix.c_str());
66 | abort();
67 | }
68 | }
69 | else if (IsPathExist(path)){
70 | glob(path + "/*.jpg", imagePathList);
71 | }
72 |
73 | YOLOv12 model(engine_file_path, logger);
74 |
75 | if (engine_file_path.find(".onnx") != std::string::npos){
76 | return 0;
77 | }
78 |
79 | if (isVideo) {
80 | cout << "Opening video: " << path << endl;
81 | cv::VideoCapture cap(path);
82 |
83 |         if (!cap.isOpened()) {
84 |             cerr << "Error: Cannot open video file!" << endl;
85 |             return 1; // report failure to the caller
86 |         }
87 |
88 | // Get frame width, height, and FPS
89 |         int frameWidth = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_WIDTH));
90 |         int frameHeight = static_cast<int>(cap.get(cv::CAP_PROP_FRAME_HEIGHT));
91 |         int fps = static_cast<int>(cap.get(cv::CAP_PROP_FPS));
92 |
93 | // Define the codec and create VideoWriter object
94 | cv::VideoWriter videoWriter("output.mp4",
95 | cv::VideoWriter::fourcc('m', 'p', '4', 'v'),
96 | fps,
97 | cv::Size(frameWidth, frameHeight));
98 |
99 |         if (!videoWriter.isOpened()) {
100 |             cerr << "Error: Cannot open VideoWriter!" << endl;
101 |             return 1; // report failure to the caller
102 |         }
103 |
104 | while (true) {
105 | cv::Mat image;
106 | cap >> image;
107 |
108 | if (image.empty()) {
109 | break;
110 | }
111 |
112 |             vector<Detection> objects;
113 |
114 | model.preprocess(image);
115 |
116 | auto start = std::chrono::system_clock::now();
117 | model.infer();
118 | auto end = std::chrono::system_clock::now();
119 |
120 | model.postprocess(objects);
121 | model.draw(image, objects);
122 |
123 |             auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
124 | printf("Cost %2.4lf ms\n", tc);
125 |
126 | // Write processed frame to output video
127 | videoWriter.write(image);
128 |
129 | if (cv::waitKey(1) == 27) { // Press 'ESC' to exit early
130 | break;
131 | }
132 | }
133 |
134 | // Release resources
135 | cap.release();
136 | videoWriter.release();
137 | cv::destroyAllWindows();
138 | }
139 |
140 |
141 | else {
142 | // path to folder saves images
143 | for (const auto& imagePath : imagePathList){
144 | // open image
145 | Mat image = imread(imagePath);
146 | if (image.empty()){
147 | cerr << "Error reading image: " << imagePath << endl;
148 | continue;
149 | }
150 |
151 |             vector<Detection> objects;
152 | model.preprocess(image);
153 |
154 | auto start = std::chrono::system_clock::now();
155 | model.infer();
156 | auto end = std::chrono::system_clock::now();
157 |
158 | model.postprocess(objects);
159 | model.draw(image, objects);
160 |
161 |             auto tc = (double)std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.;
162 |             printf("cost %2.4lf ms\n", tc);
163 |
166 | imshow("Result", image);
167 |
168 | waitKey(0);
169 | }
170 | }
171 | return 0;
172 | }
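
The `Logger` class in `main.cpp` prints a message when `severity <= Severity::kWARNING`. That comparison works because more severe levels have smaller enum values in `nvinfer1::ILogger::Severity`. The sketch below restates the filter with a stand-in enum so that it compiles without TensorRT; `shouldPrint` is a hypothetical helper, not part of the repository.

```cpp
#include <cassert>

// Stand-in for nvinfer1::ILogger::Severity so the sketch compiles without TensorRT.
enum class Severity : int { kINTERNAL_ERROR = 0, kERROR = 1, kWARNING = 2, kINFO = 3, kVERBOSE = 4 };

// Mirrors the test in Logger::log: print when the message is at least as severe as the threshold.
bool shouldPrint(Severity severity, Severity threshold = Severity::kWARNING) {
    return severity <= threshold;
}
```

Raising the threshold to `Severity::kVERBOSE` would print every message, which can help when an engine build fails.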
--------------------------------------------------------------------------------
/src/YOLOv12.cpp:
--------------------------------------------------------------------------------
1 | #include "YOLOv12.h"
2 | #include "logging.h"
3 | #include "cuda_utils.h"
4 | #include "macros.h"
5 | #include "preprocess.h"
6 | #include <NvOnnxParser.h>
7 | #include "common.h"
8 | #include <fstream>
9 | #include <iostream>
10 |
11 |
12 | static Logger logger;
13 | #define isFP16 true
14 | #define warmup true
15 |
16 |
17 | YOLOv12::YOLOv12(string model_path, nvinfer1::ILogger& logger){
18 | // Deserialize an engine
19 | if (model_path.find(".onnx") == std::string::npos){
20 | init(model_path, logger);
21 | }
22 |
23 | // Build an engine from an onnx model
24 | else{
25 | build(model_path, logger);
26 | saveEngine(model_path);
27 | }
28 |
29 | #if NV_TENSORRT_MAJOR < 8
30 | // Define input dimensions
31 | auto input_dims = engine->getBindingDimensions(0);
32 | input_h = input_dims.d[2];
33 | input_w = input_dims.d[3];
34 | #else
35 | auto input_dims = engine->getTensorShape(engine->getIOTensorName(0));
36 | input_h = input_dims.d[2];
37 | input_w = input_dims.d[3];
38 | #endif
39 | }
40 |
41 |
42 | void YOLOv12::init(std::string engine_path, nvinfer1::ILogger& logger){
43 | // Read the engine file
44 | ifstream engineStream(engine_path, ios::binary);
45 | engineStream.seekg(0, ios::end);
46 | const size_t modelSize = engineStream.tellg();
47 | engineStream.seekg(0, ios::beg);
48 |     unique_ptr<char[]> engineData(new char[modelSize]);
49 | engineStream.read(engineData.get(), modelSize);
50 | engineStream.close();
51 |
52 | // Deserialize the tensorrt engine
53 | runtime = createInferRuntime(logger);
54 | engine = runtime->deserializeCudaEngine(engineData.get(), modelSize);
55 | context = engine->createExecutionContext();
56 |
57 |
58 | #if NV_TENSORRT_MAJOR < 8
59 | input_h = engine->getBindingDimensions(0).d[2];
60 | input_w = engine->getBindingDimensions(0).d[3];
61 | detection_attribute_size = engine->getBindingDimensions(1).d[1];
62 | num_detections = engine->getBindingDimensions(1).d[2];
63 | #else
64 | auto input_name = engine->getIOTensorName(0);
65 | auto output_name = engine->getIOTensorName(1);
66 |
67 | auto input_dims = engine->getTensorShape(input_name);
68 | auto output_dims = engine->getTensorShape(output_name);
69 |
70 | input_h = input_dims.d[2];
71 | input_w = input_dims.d[3];
72 | detection_attribute_size = output_dims.d[1];
73 | num_detections = output_dims.d[2];
74 | #endif
75 | num_classes = detection_attribute_size - 4;
76 |
77 |
78 |
79 |
80 | // Initialize input buffers
81 | cpu_output_buffer = new float[detection_attribute_size * num_detections];
82 | CUDA_CHECK(cudaMalloc(&gpu_buffers[0], 3 * input_w * input_h * sizeof(float)));
83 |
84 | // Initialize output buffer
85 | CUDA_CHECK(cudaMalloc(&gpu_buffers[1], detection_attribute_size * num_detections * sizeof(float)));
86 |
87 | cuda_preprocess_init(MAX_IMAGE_SIZE);
88 |
89 | CUDA_CHECK(cudaStreamCreate(&stream));
90 |
91 |
92 | if (warmup) {
93 | for (int i = 0; i < 10; i++) {
94 | this->infer();
95 | }
96 | printf("model warmup 10 times\n");
97 | }
98 | }
99 |
100 | YOLOv12::~YOLOv12(){
101 | // Release stream and buffers
102 | CUDA_CHECK(cudaStreamSynchronize(stream));
103 | CUDA_CHECK(cudaStreamDestroy(stream));
104 | for (int i = 0; i < 2; i++)
105 | CUDA_CHECK(cudaFree(gpu_buffers[i]));
106 | delete[] cpu_output_buffer;
107 |
108 | // Destroy the engine
109 | cuda_preprocess_destroy();
110 | delete context;
111 | delete engine;
112 | delete runtime;
113 | }
114 |
115 | void YOLOv12::preprocess(Mat& image) {
116 | // Preprocessing data on gpu
117 | cuda_preprocess(image.ptr(), image.cols, image.rows, gpu_buffers[0], input_w, input_h, stream);
118 | CUDA_CHECK(cudaStreamSynchronize(stream));
119 | }
120 |
121 | void YOLOv12::infer(){
122 | // Register the input and output buffers
123 | const char* input_name = engine->getIOTensorName(0);
124 | const char* output_name = engine->getIOTensorName(1);
125 |
126 | // Set the input tensor address
127 | context->setTensorAddress(input_name, gpu_buffers[0]);
128 | context->setTensorAddress(output_name, gpu_buffers[1]);
129 |
130 | #if NV_TENSORRT_MAJOR < 10
131 | context->enqueueV2((void**)gpu_buffers, stream, nullptr);
132 | #else
133 | this->context->enqueueV3(this->stream);
134 | #endif
135 | }
136 |
137 | void YOLOv12::postprocess(vector<Detection>& output){
138 |     // Memcpy from device output buffer to host output buffer
139 |     CUDA_CHECK(cudaMemcpyAsync(cpu_output_buffer, gpu_buffers[1], num_detections * detection_attribute_size * sizeof(float), cudaMemcpyDeviceToHost, stream));
140 |     CUDA_CHECK(cudaStreamSynchronize(stream));
141 |
142 |     vector<Rect> boxes;
143 |     vector<int> class_ids;
144 |     vector<float> confidences;
145 |
146 | const Mat det_output(detection_attribute_size, num_detections, CV_32F, cpu_output_buffer);
147 |
148 | for (int i = 0; i < det_output.cols; ++i) {
149 | const Mat classes_scores = det_output.col(i).rowRange(4, 4 + num_classes);
150 | Point class_id_point;
151 | double score;
152 | minMaxLoc(classes_scores, nullptr, &score, nullptr, &class_id_point);
153 |
154 | if (score > conf_threshold) {
155 |             const float cx = det_output.at<float>(0, i);
156 |             const float cy = det_output.at<float>(1, i);
157 |             const float ow = det_output.at<float>(2, i);
158 |             const float oh = det_output.at<float>(3, i);
159 |             Rect box;
160 |             box.x = static_cast<int>((cx - 0.5 * ow));
161 |             box.y = static_cast<int>((cy - 0.5 * oh));
162 |             box.width = static_cast<int>(ow);
163 |             box.height = static_cast<int>(oh);
164 |
165 | boxes.push_back(box);
166 | class_ids.push_back(class_id_point.y);
167 | confidences.push_back(score);
168 | }
169 | }
170 |
171 |     vector<int> nms_result;
172 |     dnn::NMSBoxes(boxes, confidences, conf_threshold, nms_threshold, nms_result);
173 |
174 |     for (size_t i = 0; i < nms_result.size(); i++){
175 | Detection result;
176 | int idx = nms_result[i];
177 | result.class_id = class_ids[idx];
178 | result.conf = confidences[idx];
179 | result.bbox = boxes[idx];
180 | output.push_back(result);
181 | }
182 | }
183 |
184 | void YOLOv12::build(std::string onnxPath, nvinfer1::ILogger& logger){
185 | auto builder = createInferBuilder(logger);
186 |     const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
187 | INetworkDefinition* network = builder->createNetworkV2(explicitBatch);
188 | IBuilderConfig* config = builder->createBuilderConfig();
189 |
190 | if (isFP16){
191 | config->setFlag(BuilderFlag::kFP16);
192 | }
193 |
194 | nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, logger);
195 |     bool parsed = parser->parseFromFile(onnxPath.c_str(), static_cast<int>(nvinfer1::ILogger::Severity::kINFO));
196 | IHostMemory* plan{ builder->buildSerializedNetwork(*network, *config) };
197 |
198 | runtime = createInferRuntime(logger);
199 |
200 | engine = runtime->deserializeCudaEngine(plan->data(), plan->size());
201 |
202 | context = engine->createExecutionContext();
203 |
204 | delete network;
205 | delete config;
206 | delete parser;
207 | delete plan;
208 | }
209 |
210 | bool YOLOv12::saveEngine(const std::string& onnxpath){
211 | // Create an engine path from onnx path
212 | std::string engine_path;
213 | size_t dotIndex = onnxpath.find_last_of(".");
214 | if (dotIndex != std::string::npos){
215 | engine_path = onnxpath.substr(0, dotIndex) + ".engine";
216 | }
217 | else{
218 | return false;
219 | }
220 |
221 | // Save the engine to the path
222 | if (engine){
223 | nvinfer1::IHostMemory* data = engine->serialize();
224 | std::ofstream file;
225 | file.open(engine_path, std::ios::binary | std::ios::out);
226 | if (!file.is_open()){
227 |             std::cout << "Create engine file " << engine_path << " failed" << std::endl;
228 |             return false;
229 | }
230 | file.write((const char*)data->data(), data->size());
231 | file.close();
232 |
233 | delete data;
234 | }
235 | return true;
236 | }
237 |
238 | void YOLOv12::draw(Mat& image, const vector<Detection>& output){
239 |     const float ratio_h = input_h / (float)image.rows;
240 |     const float ratio_w = input_w / (float)image.cols;
241 |
242 |     for (size_t i = 0; i < output.size(); i++){
243 | auto detection = output[i];
244 | auto box = detection.bbox;
245 | auto class_id = detection.class_id;
246 | auto conf = detection.conf;
247 | cv::Scalar color = cv::Scalar(COLORS[class_id][0], COLORS[class_id][1], COLORS[class_id][2]);
248 |
249 | if (ratio_h > ratio_w){
250 | box.x = box.x / ratio_w;
251 | box.y = (box.y - (input_h - ratio_w * image.rows) / 2) / ratio_w;
252 | box.width = box.width / ratio_w;
253 | box.height = box.height / ratio_w;
254 | }
255 | else{
256 | box.x = (box.x - (input_w - ratio_h * image.cols) / 2) / ratio_h;
257 | box.y = box.y / ratio_h;
258 | box.width = box.width / ratio_h;
259 | box.height = box.height / ratio_h;
260 | }
261 |
262 | rectangle(image, Point(box.x, box.y), Point(box.x + box.width, box.y + box.height), color, 3);
263 |
264 | // Detection box text
265 | string class_string = CLASS_NAMES[class_id] + ' ' + to_string(conf).substr(0, 4);
266 | Size text_size = getTextSize(class_string, FONT_HERSHEY_DUPLEX, 1, 2, 0);
267 | Rect text_rect(box.x, box.y - 40, text_size.width + 10, text_size.height + 20);
268 | rectangle(image, text_rect, color, FILLED);
269 | putText(image, class_string, Point(box.x + 5, box.y - 10), FONT_HERSHEY_DUPLEX, 1, Scalar(0, 0, 0), 2, 0);
270 | }
271 | }
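
The coordinate mapping in `YOLOv12::draw` above undoes a letterbox resize: the image is scaled by the smaller of the two ratios and centered in the network input, so the padding is split evenly on the longer axis. The sketch below restates that mapping in floating point. `BoxF` and `mapToImage` are hypothetical names, and unlike the repository code, which stores the results in an integer `Rect`, nothing is truncated here.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical float box for this sketch (the repository uses cv::Rect with int fields).
struct BoxF { float x, y, width, height; };

// Map a box from network-input coordinates back to original-image coordinates,
// assuming the image was scaled uniformly and centered (letterboxed) in the input.
BoxF mapToImage(const BoxF& box, int input_w, int input_h, int image_w, int image_h) {
    assert(image_w > 0 && image_h > 0);
    // Uniform scale used by the letterbox: the smaller of the two axis ratios.
    const float scale = std::min(input_w / static_cast<float>(image_w),
                                 input_h / static_cast<float>(image_h));
    // Padding on each side; zero on the axis that determined the scale.
    const float pad_x = (input_w - scale * image_w) / 2.0f;
    const float pad_y = (input_h - scale * image_h) / 2.0f;
    return { (box.x - pad_x) / scale, (box.y - pad_y) / scale,
             box.width / scale, box.height / scale };
}
```

For a 1280x720 frame and a 640x640 input, the scale is 0.5 and 140 rows of padding sit above and below the image, so a box at `(100, 200)` in input space maps to `(200, 120)` in the frame.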
--------------------------------------------------------------------------------
/src/YOLOv12.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "NvInfer.h" // TensorRT library for high-performance inference
4 | #include <opencv2/opencv.hpp> // OpenCV for image processing
5 |
6 | using namespace nvinfer1; // Namespace for TensorRT
7 | using namespace std; // Use standard library namespace
8 | using namespace cv; // Use OpenCV namespace
9 |
10 | // Struct to store detection results
11 | struct Detection {
12 | float conf; // Confidence score of the detection
13 | int class_id; // Class ID of the detected object (e.g., person, car, etc.)
14 | Rect bbox; // Bounding box coordinates around the detected object
15 | };
16 |
17 | // Main class for the YOLOv12 model
18 | class YOLOv12 {
19 |
20 | public:
21 | // Constructor: Loads the TensorRT engine and initializes the model
22 | YOLOv12(string model_path, nvinfer1::ILogger& logger);
23 |
24 | // Destructor: Cleans up resources used by the model
25 | ~YOLOv12();
26 |
27 | // Preprocess the input image to match the model's input format
28 | void preprocess(Mat& image);
29 |
30 | // Run inference on the preprocessed image
31 | void infer();
32 |
33 | // Postprocess the model's output to extract detection results
34 |     void postprocess(vector<Detection>& output);
35 |
36 |     // Draw bounding boxes and labels on the original image
37 |     void draw(Mat& image, const vector<Detection>& output);
38 |
39 | private:
40 | // Initialize the TensorRT engine from a serialized model file
41 | void init(std::string engine_path, nvinfer1::ILogger& logger);
42 |
43 |     // Device (GPU) buffers for input and output
44 |     float* gpu_buffers[2] = { nullptr, nullptr }; //!< Input and output buffers allocated on the GPU
45 |
46 |     // Host (CPU) buffer for storing inference output
47 |     float* cpu_output_buffer = nullptr;
48 |
49 |     // CUDA stream for asynchronous execution
50 |     cudaStream_t stream = nullptr;
51 |
52 |     // TensorRT runtime for deserializing the engine from file
53 |     IRuntime* runtime = nullptr;
54 |
55 |     // TensorRT engine used to execute the network
56 |     ICudaEngine* engine = nullptr;
57 |
58 |     // Execution context for running inference with the engine
59 |     IExecutionContext* context = nullptr;
60 |
61 | // Model parameters
62 | int input_w; // Input image width expected by the model
63 | int input_h; // Input image height expected by the model
64 | int num_detections; // Number of detections output by the model
65 | int detection_attribute_size; // Attributes (e.g., bbox, class) per detection
66 | int num_classes = 80; // Number of classes (e.g., COCO dataset has 80 classes)
67 |
68 | // Maximum supported image size (used for memory allocation checks)
69 | const int MAX_IMAGE_SIZE = 4096 * 4096;
70 |
71 | // Confidence threshold for filtering detections
72 | float conf_threshold = 0.3f;
73 |
74 | // Non-Maximum Suppression (NMS) threshold to remove duplicate boxes
75 | float nms_threshold = 0.4f;
76 |
77 | // Colors for drawing bounding boxes for each class
78 |     vector<Scalar> colors;
79 |
80 | // Build TensorRT engine from an ONNX model file (if applicable)
81 | void build(std::string onnxPath, nvinfer1::ILogger& logger);
82 |
83 | // Save the built TensorRT engine to a file
84 | bool saveEngine(const std::string& filename);
85 | };
86 |
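
The `conf_threshold` and `nms_threshold` members declared above are passed to `cv::dnn::NMSBoxes` in `postprocess`. The sketch below illustrates greedy non-maximum suppression in general terms so the role of the two thresholds is visible. It is a stand-in written for illustration, not OpenCV's implementation, and details may differ from what `NMSBoxes` does. `Box`, `iou`, and `greedyNms` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

// Hypothetical integer box for this sketch (top-left corner plus size).
struct Box { int x, y, width, height; };

// Intersection over union of two axis-aligned boxes.
float iou(const Box& a, const Box& b) {
    const int x1 = std::max(a.x, b.x);
    const int y1 = std::max(a.y, b.y);
    const int x2 = std::min(a.x + a.width, b.x + b.width);
    const int y2 = std::min(a.y + a.height, b.y + b.height);
    const int inter = std::max(0, x2 - x1) * std::max(0, y2 - y1);
    const int uni = a.width * a.height + b.width * b.height - inter;
    return uni > 0 ? static_cast<float>(inter) / uni : 0.0f;
}

// Keep boxes in descending score order, dropping any box that overlaps an
// already-kept box by more than iou_threshold. Returns indices into boxes.
std::vector<int> greedyNms(const std::vector<Box>& boxes, const std::vector<float>& scores,
                           float score_threshold, float iou_threshold) {
    assert(boxes.size() == scores.size());
    std::vector<int> order(boxes.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), [&](int a, int b) { return scores[a] > scores[b]; });

    std::vector<int> keep;
    for (int idx : order) {
        if (scores[idx] <= score_threshold) continue;  // below the confidence threshold
        bool suppressed = false;
        for (int kept : keep) {
            if (iou(boxes[idx], boxes[kept]) > iou_threshold) { suppressed = true; break; }
        }
        if (!suppressed) keep.push_back(idx);
    }
    return keep;
}
```

With the header's defaults (`0.3f` and `0.4f`), two boxes that overlap by more than 40% IoU are reduced to the higher-scoring one.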
--------------------------------------------------------------------------------
/src/common.h:
--------------------------------------------------------------------------------
1 | const std::vector<std::string> CLASS_NAMES = {
2 | "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
3 | "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
4 | "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
5 | "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie",
6 | "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
7 | "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup",
8 | "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich",
9 | "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake",
10 | "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv",
11 | "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
12 | "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
13 | "teddy bear", "hair drier", "toothbrush" };
14 |
15 | const std::vector<std::vector<unsigned int>> COLORS = {
16 | {0, 114, 189}, {217, 83, 25}, {237, 177, 32}, {126, 47, 142}, {119, 172, 48}, {77, 190, 238},
17 | {162, 20, 47}, {76, 76, 76}, {153, 153, 153}, {255, 0, 0}, {255, 128, 0}, {191, 191, 0},
18 | {0, 255, 0}, {0, 0, 255}, {170, 0, 255}, {85, 85, 0}, {85, 170, 0}, {85, 255, 0},
19 | {170, 85, 0}, {170, 170, 0}, {170, 255, 0}, {255, 85, 0}, {255, 170, 0}, {255, 255, 0},
20 | {0, 85, 128}, {0, 170, 128}, {0, 255, 128}, {85, 0, 128}, {85, 85, 128}, {85, 170, 128},
21 | {85, 255, 128}, {170, 0, 128}, {170, 85, 128}, {170, 170, 128}, {170, 255, 128}, {255, 0, 128},
22 | {255, 85, 128}, {255, 170, 128}, {255, 255, 128}, {0, 85, 255}, {0, 170, 255}, {0, 255, 255},
23 | {85, 0, 255}, {85, 85, 255}, {85, 170, 255}, {85, 255, 255}, {170, 0, 255}, {170, 85, 255},
24 | {170, 170, 255}, {170, 255, 255}, {255, 0, 255}, {255, 85, 255}, {255, 170, 255}, {85, 0, 0},
25 | {128, 0, 0}, {170, 0, 0}, {212, 0, 0}, {255, 0, 0}, {0, 43, 0}, {0, 85, 0},
26 | {0, 128, 0}, {0, 170, 0}, {0, 212, 0}, {0, 255, 0}, {0, 0, 43}, {0, 0, 85},
27 | {0, 0, 128}, {0, 0, 170}, {0, 0, 212}, {0, 0, 255}, {0, 0, 0}, {36, 36, 36},
28 | {73, 73, 73}, {109, 109, 109}, {146, 146, 146}, {182, 182, 182}, {219, 219, 219}, {0, 114, 189},
29 | {80, 183, 189}, {128, 128, 0} };
--------------------------------------------------------------------------------
/src/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef TRTX_CUDA_UTILS_H_
2 | #define TRTX_CUDA_UTILS_H_
3 |
4 | #include <cuda_runtime_api.h>
5 |
6 | #ifndef CUDA_CHECK
7 | #define CUDA_CHECK(callstr)\
8 | {\
9 | cudaError_t error_code = callstr;\
10 | if (error_code != cudaSuccess) {\
11 | std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\
12 | assert(0);\
13 | }\
14 | }
15 | #endif // CUDA_CHECK
16 |
17 | #endif // TRTX_CUDA_UTILS_H_
18 |
--------------------------------------------------------------------------------
/src/logging.h:
--------------------------------------------------------------------------------
1 | #ifndef TENSORRT_LOGGING_H
2 | #define TENSORRT_LOGGING_H
3 |
4 | #include "NvInferRuntimeCommon.h"
5 | #include <cassert>
6 | #include <ctime>
7 | #include <iomanip>
8 | #include <iostream>
9 | #include <ostream>
10 | #include <sstream>
11 | #include <string>
12 | #include "macros.h"
13 |
14 | using Severity = nvinfer1::ILogger::Severity;
15 |
16 | class LogStreamConsumerBuffer : public std::stringbuf
17 | {
18 | public:
19 | LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
20 | : mOutput(stream)
21 | , mPrefix(prefix)
22 | , mShouldLog(shouldLog)
23 | {
24 | }
25 |
26 | LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
27 | : mOutput(other.mOutput)
28 | {
29 | }
30 |
31 | ~LogStreamConsumerBuffer()
32 | {
33 | // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
34 | // std::streambuf::pptr() gives a pointer to the current position of the output sequence
35 | // if the pointer to the beginning is not equal to the pointer to the current position,
36 | // call putOutput() to log the output to the stream
37 | if (pbase() != pptr())
38 | {
39 | putOutput();
40 | }
41 | }
42 |
43 | // synchronizes the stream buffer and returns 0 on success
44 | // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
45 | // resetting the buffer and flushing the stream
46 | virtual int sync()
47 | {
48 | putOutput();
49 | return 0;
50 | }
51 |
52 | void putOutput()
53 | {
54 | if (mShouldLog)
55 | {
56 | // prepend timestamp
57 | std::time_t timestamp = std::time(nullptr);
58 |             tm* tm_local = std::localtime(&timestamp);
59 | std::cout << "[";
60 | std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/";
61 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/";
62 | std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-";
63 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":";
64 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":";
65 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] ";
66 | // std::stringbuf::str() gets the string contents of the buffer
67 | // insert the buffer contents pre-appended by the appropriate prefix into the stream
68 | mOutput << mPrefix << str();
69 | // set the buffer to empty
70 | str("");
71 | // flush the stream
72 | mOutput.flush();
73 | }
74 | }
75 |
76 | void setShouldLog(bool shouldLog)
77 | {
78 | mShouldLog = shouldLog;
79 | }
80 |
81 | private:
82 | std::ostream& mOutput;
83 | std::string mPrefix;
84 | bool mShouldLog;
85 | };
86 |
87 | //!
88 | //! \class LogStreamConsumerBase
89 | //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
90 | //!
91 | class LogStreamConsumerBase
92 | {
93 | public:
94 | LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
95 | : mBuffer(stream, prefix, shouldLog)
96 | {
97 | }
98 |
99 | protected:
100 | LogStreamConsumerBuffer mBuffer;
101 | };
102 |
103 | //!
104 | //! \class LogStreamConsumer
105 | //! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
106 | //! Order of base classes is LogStreamConsumerBase and then std::ostream.
107 | //! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
108 | //! in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
109 | //! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
110 | //! Please do not change the order of the parent classes.
111 | //!
112 | class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream
113 | {
114 | public:
115 | //! \brief Creates a LogStreamConsumer which logs messages with level severity.
116 | //! Reportable severity determines if the messages are severe enough to be logged.
117 | LogStreamConsumer(Severity reportableSeverity, Severity severity)
118 | : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity)
119 | , std::ostream(&mBuffer) // links the stream buffer with the stream
120 | , mShouldLog(severity <= reportableSeverity)
121 | , mSeverity(severity)
122 | {
123 | }
124 |
125 | LogStreamConsumer(LogStreamConsumer&& other)
126 | : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog)
127 | , std::ostream(&mBuffer) // links the stream buffer with the stream
128 | , mShouldLog(other.mShouldLog)
129 | , mSeverity(other.mSeverity)
130 | {
131 | }
132 |
133 | void setReportableSeverity(Severity reportableSeverity)
134 | {
135 | mShouldLog = mSeverity <= reportableSeverity;
136 | mBuffer.setShouldLog(mShouldLog);
137 | }
138 |
139 | private:
140 | static std::ostream& severityOstream(Severity severity)
141 | {
142 | return severity >= Severity::kINFO ? std::cout : std::cerr;
143 | }
144 |
145 | static std::string severityPrefix(Severity severity)
146 | {
147 | switch (severity)
148 | {
149 | case Severity::kINTERNAL_ERROR: return "[F] ";
150 | case Severity::kERROR: return "[E] ";
151 | case Severity::kWARNING: return "[W] ";
152 | case Severity::kINFO: return "[I] ";
153 | case Severity::kVERBOSE: return "[V] ";
154 | default: assert(0); return "";
155 | }
156 | }
157 |
158 | bool mShouldLog;
159 | Severity mSeverity;
160 | };
161 |
162 | //! \class Logger
163 | //!
164 | //! \brief Class which manages logging of TensorRT tools and samples
165 | //!
166 | //! \details This class provides a common interface for TensorRT tools and samples to log information to the console,
167 | //! and supports logging two types of messages:
168 | //!
169 | //! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal)
170 | //! - Test pass/fail messages
171 | //!
172 | //! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is
173 | //! that the logic for controlling the verbosity and formatting of sample output is centralized in one location.
174 | //!
175 | //! In the future, this class could be extended to support dumping test results to a file in some standard format
176 | //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run).
177 | //!
178 | //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger
179 | //! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT
180 | //! library and messages coming from the sample.
181 | //!
182 | //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the
183 | //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger
184 | //! object.
185 |
186 | class Logger : public nvinfer1::ILogger
187 | {
188 | public:
189 | Logger(Severity severity = Severity::kWARNING)
190 | : mReportableSeverity(severity)
191 | {
192 | }
193 |
194 | //!
195 | //! \enum TestResult
196 | //! \brief Represents the state of a given test
197 | //!
198 | enum class TestResult
199 | {
200 | kRUNNING, //!< The test is running
201 | kPASSED, //!< The test passed
202 | kFAILED, //!< The test failed
203 | kWAIVED //!< The test was waived
204 | };
205 |
206 | //!
207 | //! \brief Forward-compatible method for retrieving the nvinfer1::ILogger associated with this Logger
208 | //! \return The nvinfer1::ILogger associated with this Logger
209 | //!
210 | //! TODO Once all samples are updated to use this method to register the logger with TensorRT,
211 | //! we can eliminate the inheritance of Logger from ILogger
212 | //!
213 | nvinfer1::ILogger& getTRTLogger()
214 | {
215 | return *this;
216 | }
217 |
218 | //!
219 | //! \brief Implementation of the nvinfer1::ILogger::log() virtual method
220 | //!
221 | //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
222 | //! inheritance from nvinfer1::ILogger
223 | //!
224 | void log(Severity severity, const char* msg) TRT_NOEXCEPT override
225 | {
226 | LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;
228 | }
229 |
230 | //!
231 | //! \brief Method for controlling the verbosity of logging output
232 | //!
233 | //! \param severity The logger will only emit messages that have severity of this level or higher.
234 | //!
235 | void setReportableSeverity(Severity severity)
236 | {
237 | mReportableSeverity = severity;
238 | }
239 |
240 | //!
241 | //! \brief Opaque handle that holds logging information for a particular test
242 | //!
243 | //! This object is an opaque handle to information used by the Logger to print test results.
244 | //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used
245 | //! with Logger::reportTest{Start,End}().
246 | //!
247 | class TestAtom
248 | {
249 | public:
250 | TestAtom(TestAtom&&) = default;
251 |
252 | private:
253 | friend class Logger;
254 |
255 | TestAtom(bool started, const std::string& name, const std::string& cmdline)
256 | : mStarted(started)
257 | , mName(name)
258 | , mCmdline(cmdline)
259 | {
260 | }
261 |
262 | bool mStarted;
263 | std::string mName;
264 | std::string mCmdline;
265 | };
266 |
267 | //!
268 | //! \brief Define a test for logging
269 | //!
270 | //! \param[in] name The name of the test. This should be a string starting with
271 | //! "TensorRT" and containing dot-separated strings containing
272 | //! the characters [A-Za-z0-9_].
273 | //! For example, "TensorRT.sample_googlenet"
274 | //! \param[in] cmdline The command line used to reproduce the test
275 | //!
276 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
277 | //!
278 | static TestAtom defineTest(const std::string& name, const std::string& cmdline)
279 | {
280 | return TestAtom(false, name, cmdline);
281 | }
282 |
283 | //!
284 | //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments
285 | //! as input
286 | //!
287 | //! \param[in] name The name of the test
288 | //! \param[in] argc The number of command-line arguments
289 | //! \param[in] argv The array of command-line arguments (given as C strings)
290 | //!
291 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
292 | static TestAtom defineTest(const std::string& name, int argc, char const* const* argv)
293 | {
294 | auto cmdline = genCmdlineString(argc, argv);
295 | return defineTest(name, cmdline);
296 | }
297 |
298 | //!
299 | //! \brief Report that a test has started.
300 | //!
301 | //! \pre reportTestStart() has not been called yet for the given testAtom
302 | //!
303 | //! \param[in] testAtom The handle to the test that has started
304 | //!
305 | static void reportTestStart(TestAtom& testAtom)
306 | {
307 | reportTestResult(testAtom, TestResult::kRUNNING);
308 | assert(!testAtom.mStarted);
309 | testAtom.mStarted = true;
310 | }
311 |
312 | //!
313 | //! \brief Report that a test has ended.
314 | //!
315 | //! \pre reportTestStart() has been called for the given testAtom
316 | //!
317 | //! \param[in] testAtom The handle to the test that has ended
318 | //! \param[in] result The result of the test. Should be one of TestResult::kPASSED,
319 | //! TestResult::kFAILED, TestResult::kWAIVED
320 | //!
321 | static void reportTestEnd(const TestAtom& testAtom, TestResult result)
322 | {
323 | assert(result != TestResult::kRUNNING);
324 | assert(testAtom.mStarted);
325 | reportTestResult(testAtom, result);
326 | }
327 |
328 | static int reportPass(const TestAtom& testAtom)
329 | {
330 | reportTestEnd(testAtom, TestResult::kPASSED);
331 | return EXIT_SUCCESS;
332 | }
333 |
334 | static int reportFail(const TestAtom& testAtom)
335 | {
336 | reportTestEnd(testAtom, TestResult::kFAILED);
337 | return EXIT_FAILURE;
338 | }
339 |
340 | static int reportWaive(const TestAtom& testAtom)
341 | {
342 | reportTestEnd(testAtom, TestResult::kWAIVED);
343 | return EXIT_SUCCESS;
344 | }
345 |
346 | static int reportTest(const TestAtom& testAtom, bool pass)
347 | {
348 | return pass ? reportPass(testAtom) : reportFail(testAtom);
349 | }
350 |
351 | Severity getReportableSeverity() const
352 | {
353 | return mReportableSeverity;
354 | }
355 |
356 | private:
357 | //!
358 | //! \brief returns an appropriate string for prefixing a log message with the given severity
359 | //!
360 | static const char* severityPrefix(Severity severity)
361 | {
362 | switch (severity)
363 | {
364 | case Severity::kINTERNAL_ERROR: return "[F] ";
365 | case Severity::kERROR: return "[E] ";
366 | case Severity::kWARNING: return "[W] ";
367 | case Severity::kINFO: return "[I] ";
368 | case Severity::kVERBOSE: return "[V] ";
369 | default: assert(0); return "";
370 | }
371 | }
372 |
373 | //!
374 | //! \brief returns an appropriate string for prefixing a test result message with the given result
375 | //!
376 | static const char* testResultString(TestResult result)
377 | {
378 | switch (result)
379 | {
380 | case TestResult::kRUNNING: return "RUNNING";
381 | case TestResult::kPASSED: return "PASSED";
382 | case TestResult::kFAILED: return "FAILED";
383 | case TestResult::kWAIVED: return "WAIVED";
384 | default: assert(0); return "";
385 | }
386 | }
387 |
388 | //!
389 | //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity
390 | //!
391 | static std::ostream& severityOstream(Severity severity)
392 | {
393 | return severity >= Severity::kINFO ? std::cout : std::cerr;
394 | }
395 |
396 | //!
397 | //! \brief method that implements logging test results
398 | //!
399 | static void reportTestResult(const TestAtom& testAtom, TestResult result)
400 | {
401 | severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # "
402 | << testAtom.mCmdline << std::endl;
403 | }
404 |
405 | //!
406 | //! \brief generate a command line string from the given (argc, argv) values
407 | //!
408 | static std::string genCmdlineString(int argc, char const* const* argv)
409 | {
410 | std::stringstream ss;
411 | for (int i = 0; i < argc; i++)
412 | {
413 | if (i > 0)
414 | ss << " ";
415 | ss << argv[i];
416 | }
417 | return ss.str();
418 | }
419 |
420 | Severity mReportableSeverity;
421 | };
422 |
423 | namespace
424 | {
425 |
426 | //!
427 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE
428 | //!
429 | //! Example usage:
430 | //!
431 | //! LOG_VERBOSE(logger) << "hello world" << std::endl;
432 | //!
433 | inline LogStreamConsumer LOG_VERBOSE(const Logger& logger)
434 | {
435 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE);
436 | }
437 |
438 | //!
439 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO
440 | //!
441 | //! Example usage:
442 | //!
443 | //! LOG_INFO(logger) << "hello world" << std::endl;
444 | //!
445 | inline LogStreamConsumer LOG_INFO(const Logger& logger)
446 | {
447 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO);
448 | }
449 |
450 | //!
451 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING
452 | //!
453 | //! Example usage:
454 | //!
455 | //! LOG_WARN(logger) << "hello world" << std::endl;
456 | //!
457 | inline LogStreamConsumer LOG_WARN(const Logger& logger)
458 | {
459 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING);
460 | }
461 |
462 | //!
463 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR
464 | //!
465 | //! Example usage:
466 | //!
467 | //! LOG_ERROR(logger) << "hello world" << std::endl;
468 | //!
469 | inline LogStreamConsumer LOG_ERROR(const Logger& logger)
470 | {
471 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR);
472 | }
473 |
474 | //!
475 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR
476 | //! ("fatal" severity)
477 | //!
478 | //! Example usage:
479 | //!
480 | //! LOG_FATAL(logger) << "hello world" << std::endl;
481 | //!
482 | inline LogStreamConsumer LOG_FATAL(const Logger& logger)
483 | {
484 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
485 | }
486 |
487 | } // anonymous namespace
488 |
489 | #endif // TENSORRT_LOGGING_H
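
The comment on LogStreamConsumer explains why LogStreamConsumerBase must come before std::ostream in the base list: the buffer has to be fully constructed before its address is handed to the stream. The sketch below reproduces that pattern without TensorRT so it can be compiled on its own. The names PrefixBuffer, BufferHolder and PrefixStream are hypothetical, and the Severity enum is a local stand-in that only mirrors the ordering used above (lower value means more severe); it is not the header's implementation.

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>
#include <utility>

// Local stand-in for the severity ordering used by the logger (assumption of this sketch).
enum class Severity : int { kINTERNAL_ERROR = 0, kERROR = 1, kWARNING = 2, kINFO = 3, kVERBOSE = 4 };

// Collects text and forwards it, with a prefix, to a sink stream on every flush.
class PrefixBuffer : public std::stringbuf
{
public:
    PrefixBuffer(std::ostream& sink, std::string prefix, bool enabled)
        : mSink(sink), mPrefix(std::move(prefix)), mEnabled(enabled)
    {
    }

    ~PrefixBuffer() override
    {
        // Emit anything that was written but never flushed.
        if (pbase() != pptr())
        {
            emit();
        }
    }

    int sync() override
    {
        emit();
        return 0;
    }

private:
    void emit()
    {
        if (mEnabled)
        {
            mSink << mPrefix << str();
        }
        str(""); // reset the collected text
        mSink.flush();
    }

    std::ostream& mSink;
    std::string mPrefix;
    bool mEnabled;
};

// Owns the buffer so that it is constructed before the std::ostream base below.
struct BufferHolder
{
    BufferHolder(std::ostream& sink, const std::string& prefix, bool enabled)
        : mBuffer(sink, prefix, enabled)
    {
    }
    PrefixBuffer mBuffer;
};

// Base order matters: BufferHolder first, std::ostream second.
class PrefixStream : protected BufferHolder, public std::ostream
{
public:
    PrefixStream(std::ostream& sink, Severity reportable, Severity severity)
        : BufferHolder(sink, prefixFor(severity), severity <= reportable)
        , std::ostream(&mBuffer) // mBuffer is already constructed here
    {
    }

    static std::string prefixFor(Severity severity)
    {
        switch (severity)
        {
        case Severity::kINTERNAL_ERROR: return "[F] ";
        case Severity::kERROR: return "[E] ";
        case Severity::kWARNING: return "[W] ";
        case Severity::kINFO: return "[I] ";
        case Severity::kVERBOSE: return "[V] ";
        }
        return "";
    }
};
```

With a std::ostringstream as the sink, a PrefixStream built with reportable severity kWARNING forwards a kERROR message with the "[E] " prefix and silently drops a kINFO message, which is the same `severity <= reportableSeverity` rule the header applies.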
--------------------------------------------------------------------------------
/src/macros.h:
--------------------------------------------------------------------------------
1 | #ifndef __MACROS_H
2 | #define __MACROS_H
3 |
4 | #ifdef API_EXPORTS
5 | #if defined(_MSC_VER)
6 | #define API __declspec(dllexport)
7 | #else
8 | #define API __attribute__((visibility("default")))
9 | #endif
10 | #else
11 |
12 | #if defined(_MSC_VER)
13 | #define API __declspec(dllimport)
14 | #else
15 | #define API
16 | #endif
17 | #endif // API_EXPORTS
18 |
19 | #if NV_TENSORRT_MAJOR >= 8
20 | #define TRT_NOEXCEPT noexcept
21 | #define TRT_CONST_ENQUEUE const
22 | #else
23 | #define TRT_NOEXCEPT
24 | #define TRT_CONST_ENQUEUE
25 | #endif
26 |
27 | #endif // __MACROS_H
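
TRT_NOEXCEPT exists because an overriding function may not have a looser exception specification than the virtual function it overrides. The sketch below shows the pattern in isolation. ILoggerLike and CountingLogger are hypothetical stand-ins, and NV_TENSORRT_MAJOR is defined by hand here only so the example compiles without TensorRT; in a real build that macro is expected to come from the TensorRT version header, which therefore has to be included before macros.h.

```cpp
#include <cassert>
#include <utility>

// Assumption for this sketch only: pretend the TensorRT 8 version header was already included.
#define NV_TENSORRT_MAJOR 8

#if NV_TENSORRT_MAJOR >= 8
#define TRT_NOEXCEPT noexcept
#else
#define TRT_NOEXCEPT
#endif

// Hypothetical stand-in for an interface whose virtual method is declared noexcept.
struct ILoggerLike
{
    virtual void log(const char* msg) noexcept = 0;
    virtual ~ILoggerLike() = default;
};

// The override compiles only if it is noexcept as well, which is what the macro supplies.
struct CountingLogger : ILoggerLike
{
    void log(const char* msg) TRT_NOEXCEPT override
    {
        assert(msg != nullptr);
        ++calls;
    }
    int calls = 0;
};

// Compile-time check that the macro expanded to noexcept in this configuration.
static_assert(noexcept(std::declval<CountingLogger&>().log("")), "log() should be noexcept");
```

If NV_TENSORRT_MAJOR were undefined when the `#if` is evaluated, the preprocessor would treat it as 0 and TRT_NOEXCEPT would expand to nothing, so the override above would be rejected by the compiler.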
--------------------------------------------------------------------------------
/src/preprocess.cu:
--------------------------------------------------------------------------------
1 | #include "preprocess.h"
2 | #include "cuda_utils.h"
3 | #include "device_launch_parameters.h"
4 | #include <iostream>
5 |
6 | // Static buffers
7 | static uint8_t* img_buffer_host = nullptr;
8 | static uint8_t* img_buffer_device = nullptr;
9 |
10 | struct AffineMatrix {
11 | float value[6];
12 | };
13 |
14 | // CUDA error checking macro
15 | #define CUDA_CALL(x) do { \
16 | cudaError_t err = (x); \
17 | if (err != cudaSuccess) { \
18 | std::cerr << "CUDA Error: " << cudaGetErrorString(err) \
19 | << " at " << __FILE__ << ":" << __LINE__ << std::endl; \
20 | std::exit(EXIT_FAILURE); \
21 | } \
22 | } while (0)
23 |
24 | // Warp-affine kernel: one thread per destination pixel (bilinear sample, BGR->RGB, scale to [0,1], planar CHW output)
25 | __global__ void warpaffine_kernel(
26 | uint8_t* src, int src_line_size, int src_width,
27 | int src_height, float* dst, int dst_width,
28 | int dst_height, uint8_t const_value_st,
29 | AffineMatrix d2s, int edge) {
30 |
31 | int position = blockDim.x * blockIdx.x + threadIdx.x;
32 | if (position >= edge) return;
33 |
34 | int dx = position % dst_width;
35 | int dy = position / dst_width;
36 |
37 | // Transform source coordinates
38 | float src_x = d2s.value[0] * dx + d2s.value[1] * dy + d2s.value[2] + 0.5f;
39 | float src_y = d2s.value[3] * dx + d2s.value[4] * dy + d2s.value[5] + 0.5f;
40 |
41 | // printf("Thread %d: (dx, dy) = (%d, %d), (src_x, src_y) = (%.2f, %.2f)\n",
42 | // position, dx, dy, src_x, src_y);
43 |
44 | float c0, c1, c2;
45 |
46 | // Check if source coordinates are out of bounds
47 | if (src_x < 0 || src_x >= src_width || src_y < 0 || src_y >= src_height) {
48 | c0 = c1 = c2 = const_value_st; // Default value for out-of-range
49 | } else {
50 | int x_low = floorf(src_x);
51 | int y_low = floorf(src_y);
52 | int x_high = x_low + 1;
53 | int y_high = y_low + 1;
54 |
55 | // Handle boundary conditions
56 | uint8_t* v1 = src + y_low * src_line_size + x_low * 3;
57 | uint8_t* v2 = (x_high < src_width) ? src + y_low * src_line_size + x_high * 3 : v1;
58 | uint8_t* v3 = (y_high < src_height) ? src + y_high * src_line_size + x_low * 3 : v1;
59 | uint8_t* v4 = (x_high < src_width && y_high < src_height) ? src + y_high * src_line_size + x_high * 3 : v1;
60 |
61 | // Bilinear interpolation weights
62 | float lx = src_x - x_low;
63 | float ly = src_y - y_low;
64 | float hx = 1 - lx;
65 | float hy = 1 - ly;
66 | float w1 = hx * hy, w2 = lx * hy, w3 = hx * ly, w4 = lx * ly;
67 |
68 | // Compute final pixel values
69 | c0 = w1 * v1[0] + w2 * v2[0] + w3 * v3[0] + w4 * v4[0];
70 | c1 = w1 * v1[1] + w2 * v2[1] + w3 * v3[1] + w4 * v4[1];
71 | c2 = w1 * v1[2] + w2 * v2[2] + w3 * v3[2] + w4 * v4[2];
72 | }
73 |
74 | // Convert BGR to RGB
75 | float tmp = c0; c0 = c2; c2 = tmp;
76 |
77 | // Normalize pixel values
78 | c0 /= 255.0f;
79 | c1 /= 255.0f;
80 | c2 /= 255.0f;
81 |
82 | int area = dst_width * dst_height;
83 | float* pdst_c0 = dst + dy * dst_width + dx;
84 | float* pdst_c1 = pdst_c0 + area;
85 | float* pdst_c2 = pdst_c1 + area;
86 |
87 | // Store the normalized values
88 | *pdst_c0 = c0;
89 | *pdst_c1 = c1;
90 | *pdst_c2 = c2;
91 | }
92 |
93 | // Host-side preprocessing function
94 | void cuda_preprocess(
95 | uint8_t* src, int src_width, int src_height,
96 | float* dst, int dst_width, int dst_height,
97 | cudaStream_t stream) {
98 |
99 | int img_size = src_width * src_height * 3;
100 |
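101 | if (img_buffer_host == nullptr) {
102 | std::cerr << "Error: img_buffer_host not allocated!" << std::endl;
103 | return; // nothing to copy into; cuda_preprocess_init() must be called first
104 | }
105 |
106 | if (src == nullptr) {
107 | std::cerr << "Error: Source image pointer is null!" << std::endl;
108 | return; // avoid dereferencing a null source in memcpy below
109 | }
110 |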
111 | size_t free_mem, total_mem;
112 | CUDA_CALL(cudaMemGetInfo(&free_mem, &total_mem));
113 | std::cout << "Free GPU memory: " << free_mem << " bytes, Total GPU memory: " << total_mem << " bytes" << std::endl;
114 |
115 |
116 | CUDA_CALL(cudaDeviceSynchronize());
117 | std::cout << "Synchronized CUDA device." << std::endl;
118 |
119 | // Copy data to pinned memory
120 | std::cout << "Copying data to pinned memory..." << std::endl;
121 | memcpy(img_buffer_host, src, img_size);
122 |
123 | // Copy data to device memory
124 | std::cout << "Copying data to device memory..." << std::endl;
125 | CUDA_CALL(cudaMemcpyAsync(img_buffer_device, img_buffer_host, img_size, cudaMemcpyHostToDevice, stream));
126 | CUDA_CALL(cudaStreamSynchronize(stream));
127 |
128 | // Prepare the affine matrices
129 | AffineMatrix s2d, d2s;
130 | float scale = std::min(dst_height / (float)src_height, dst_width / (float)src_width);
131 |
132 | s2d.value[0] = scale; s2d.value[1] = 0;
133 | s2d.value[2] = -scale * src_width * 0.5 + dst_width * 0.5;
134 | s2d.value[3] = 0; s2d.value[4] = scale;
135 | s2d.value[5] = -scale * src_height * 0.5 + dst_height * 0.5;
136 |
137 | cv::Mat m2x3_s2d(2, 3, CV_32F, s2d.value);
138 | cv::Mat m2x3_d2s(2, 3, CV_32F, d2s.value);
139 | cv::invertAffineTransform(m2x3_s2d, m2x3_d2s);
140 | memcpy(d2s.value, m2x3_d2s.ptr<float>(0), sizeof(d2s.value));
141 |
142 | int jobs = dst_width * dst_height;
143 | int threads = 256;
144 | int blocks = (jobs + threads - 1) / threads;
145 |
146 | // Launch the kernel
147 | std::cout << "Launching kernel..." << std::endl;
148 | warpaffine_kernel<<<blocks, threads, 0, stream>>>(
149 | img_buffer_device, src_width * 3, src_width, src_height,
150 | dst, dst_width, dst_height, 128, d2s, jobs);
151 |
152 | // Synchronize and check for errors
153 | CUDA_CALL(cudaStreamSynchronize(stream));
154 | std::cout << "Kernel execution completed." << std::endl;
155 | }
156 |
157 | void cuda_preprocess_init(int max_image_size) {
158 | std::cout << "I am in the preprocess init" << std::endl;
159 | CUDA_CALL(cudaMallocHost((void**)&img_buffer_host, max_image_size * 3));
160 | CUDA_CALL(cudaMalloc((void**)&img_buffer_device, max_image_size * 3));
161 | }
162 |
163 | void cuda_preprocess_destroy() {
164 | CUDA_CALL(cudaFree(img_buffer_device));
165 | CUDA_CALL(cudaFreeHost(img_buffer_host));
166 | }
167 |
--------------------------------------------------------------------------------
/src/preprocess.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <cuda_runtime.h>
4 | #include <cstdint>
5 | #include <opencv2/opencv.hpp>
6 |
7 | void cuda_preprocess_init(int max_image_size);
8 | void cuda_preprocess_destroy();
9 | void cuda_preprocess(uint8_t* src, int src_width, int src_height,
10 | float* dst, int dst_width, int dst_height,
11 | cudaStream_t stream);
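
The dst buffer passed to cuda_preprocess is filled in planar order: judging by the kernel in preprocess.cu, each destination pixel is written as three floats in separate channel planes, with the channels swapped from BGR to RGB and scaled to [0, 1]. The CPU sketch below shows only that layout conversion, without the resize, so the expected indexing can be verified by hand. bgr_hwc_to_rgb_chw is a hypothetical helper and is not part of this header.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Converts an interleaved 8-bit BGR image (HWC) into planar RGB floats (CHW) in [0, 1].
inline std::vector<float> bgr_hwc_to_rgb_chw(const std::uint8_t* src, int width, int height)
{
    assert(src != nullptr && width > 0 && height > 0);
    const std::size_t area = static_cast<std::size_t>(width) * static_cast<std::size_t>(height);
    std::vector<float> dst(3 * area);
    for (std::size_t i = 0; i < area; ++i)
    {
        const std::uint8_t* px = src + 3 * i; // px[0] = B, px[1] = G, px[2] = R
        dst[0 * area + i] = px[2] / 255.0f;   // R plane
        dst[1 * area + i] = px[1] / 255.0f;   // G plane
        dst[2 * area + i] = px[0] / 255.0f;   // B plane
    }
    return dst;
}
```

A caller sizing the dst buffer for the GPU path would therefore need room for 3 * dst_width * dst_height floats.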
--------------------------------------------------------------------------------