├── CMakeLists.txt
├── LICENSE
├── README.md
├── assets
│   ├── BY HAMDI BOUKAMCHA.JPG
│   ├── Speed_SAM_Results.JPG
│   ├── blue-linkedin-logo.png
│   ├── dog.jpg
│   ├── dogs.jpg
│   └── speed_sam_cpp_tenosrrt.PNG
├── include
│   ├── config.h
│   ├── cuda_utils.h
│   ├── engineTRT.h
│   ├── logging.h
│   ├── macros.h
│   ├── speedSam.h
│   └── utils.h
├── model
│   ├── SAM_encoder.onnx
│   └── SAM_mask_decoder.onnx
└── src
    ├── engineTRT.cpp
    ├── main.cpp
    └── speedSam.cpp
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.18)
2 | project(SpeedSAM LANGUAGES CXX CUDA)
3 |
4 | # Set C++ standard
5 | set(CMAKE_CXX_STANDARD 17)
6 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
7 |
8 | # Suppress ZERO_CHECK generation in Visual Studio
9 | set(CMAKE_SUPPRESS_REGENERATION ON)
10 |
11 | # Find packages
12 | find_package(CUDA REQUIRED)
13 |
14 | # Set OpenCV root directory
15 | set(OpenCV_DIR "F:/Program Files (x86)/opencv-4.10.0-windows/opencv/build")
16 | find_package(OpenCV REQUIRED)
17 |
18 | # Set TensorRT root directory
19 | set(TENSORRT_ROOT "F:/Program Files/TensorRT-8.6.1.6")
20 | find_path(TENSORRT_INCLUDE_DIR NvInfer.h PATHS ${TENSORRT_ROOT}/include)
21 | find_library(NVINFER_LIBRARY nvinfer PATHS ${TENSORRT_ROOT}/lib)
22 | find_library(NVONNXPARSER_LIBRARY nvonnxparser PATHS ${TENSORRT_ROOT}/lib)
23 |
24 | # Include directories
25 | include_directories(
26 | ${CUDA_INCLUDE_DIRS}
27 | ${TENSORRT_INCLUDE_DIR}
28 | ${OpenCV_INCLUDE_DIRS}
29 | ${CMAKE_SOURCE_DIR}/include # Add include directory
30 | )
31 |
32 | # Add source files
33 | file(GLOB SOURCES
34 | ${CMAKE_SOURCE_DIR}/src/*.cpp
35 | )
36 |
37 | # Add header files
38 | file(GLOB HEADERS
39 | ${CMAKE_SOURCE_DIR}/include/*.h
40 | )
41 |
42 | # Organize headers under "Header Files" in Visual Studio
43 | source_group("Header Files" FILES ${HEADERS})
44 |
45 | # CUDA compilation flags
46 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3 -std=c++17")
47 |
48 | # Add main executable without generating ALL_BUILD
49 | add_executable(SpeedSAM ${SOURCES} ${HEADERS})
50 |
51 | # Link libraries
52 | target_link_libraries(SpeedSAM
53 | ${NVINFER_LIBRARY}
54 | ${NVONNXPARSER_LIBRARY}
55 | ${CUDA_LIBRARIES}
56 | ${OpenCV_LIBS}
57 | )
58 |
59 | # Set up custom CUDA runtime library if needed
60 | set_target_properties(SpeedSAM PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
61 |
62 | # TensorRT and CUDA runtime flags
63 | target_compile_definitions(SpeedSAM PRIVATE USE_TENSORRT USE_CUDA)
64 |
65 | # Export paths for flexibility in multi-platform use
66 | message(STATUS "CUDA Libraries: ${CUDA_LIBRARIES}")
67 | message(STATUS "TensorRT Libraries: ${TENSORRT_INCLUDE_DIR}, ${NVINFER_LIBRARY}, ${NVONNXPARSER_LIBRARY}")
68 | message(STATUS "OpenCV Directory: ${OpenCV_DIR}")
69 |
70 | # Set main executable as default startup project (optional)
71 | set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY VS_STARTUP_PROJECT SpeedSAM)
72 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Custom License Agreement
2 |
3 | 1. License Grant: You are hereby granted a non-exclusive, non-transferable license to use, reproduce, and distribute the code (hereinafter referred to as "the Software") under the following conditions:
4 |
5 | 2. Conditions of Use
6 |
7 | Non-Commercial Use: You may use the Software for personal, educational, or non-commercial purposes without any additional permissions.
8 | Commercial Use: Any commercial use of the Software, including but not limited to selling, licensing, or using it in a commercial product, requires prior written permission from the original developer.
9 | 3. Contact Requirement
10 |
11 | If you wish to use the Software for commercial purposes, you must contact the original developer at [https://www.linkedin.com/in/hamdi-boukamcha/] to obtain a commercial license.
12 | The terms of any commercial license will be mutually agreed upon and may involve a licensing fee.
13 | 4. Attribution
14 |
15 | Regardless of whether you are using the Software for commercial or non-commercial purposes, you must provide appropriate credit to the original developer in any distributions or products that use the Software.
16 | 5. Disclaimer of Warranty
17 |
18 | The Software is provided "as is," without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and non-infringement. In no event shall the original developer be liable for any claim, damages, or other liability, whether in an action of contract, tort, or otherwise, arising from, out of, or in connection with the Software or the use or other dealings in the Software.
19 | 6. Governing Law
20 |
21 | This License Agreement shall be governed by and construed in accordance with the laws of France.
22 | By using the Software, you agree to abide by the terms outlined in this License Agreement.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SPEED SAM C++ TENSORRT
2 | ![SPEED SAM C++ TENSORRT](assets/speed_sam_cpp_tenosrrt.PNG)
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | ## 🌐 Overview
13 | A high-performance C++ implementation of SAM (Segment Anything Model) using TensorRT and CUDA, optimized for real-time image segmentation tasks.
14 |
15 | ## 📢 Updates
16 | - **Model Conversion**: Build TensorRT engines from ONNX models for accelerated inference.
17 | - **Segmentation with Points and BBoxes**: Segment images using selected points or bounding boxes.
18 | - **FP16 Precision**: Choose between FP16 and FP32 to balance speed and accuracy.
19 | - **Dynamic Shape Support**: Handle variable input sizes efficiently using optimization profiles.
20 | - **CUDA Optimization**: Leverage CUDA for preprocessing and efficient memory handling.
21 |
22 | ## 📢 Performance
23 | ### Inference Time
24 |
25 | | Component | SpeedSAM |
26 | |----------------------------|-----------|
27 | | **Image Encoder** | |
28 | | Parameters | 5M |
29 | | Speed | 8ms |
30 | | **Mask Decoder** | |
31 | | Parameters | 3.876M |
32 | | Speed | 4ms |
33 | | **Whole Pipeline (Enc+Dec)** | |
34 | | Parameters | 9.66M |
35 | | Speed | 12ms |
36 | ### Results
37 | ![Speed SAM Results](assets/Speed_SAM_Results.JPG)
38 |
39 | ## 📂 Project Structure
40 | SPEED-SAM-CPP-TENSORRT/
41 | ├── include
42 | │ ├── config.h # Model configuration and macros
43 | │ ├── cuda_utils.h # CUDA utility macros
44 | │ ├── engineTRT.h # TensorRT engine management
45 | │ ├── logging.h # Logging utilities
46 | │ ├── macros.h # API export/import macros
47 | │ ├── speedSam.h # SpeedSam class definition
48 | │ └── utils.h # Utility functions for image handling
49 | ├── src
50 | │ ├── engineTRT.cpp # Implementation of the TensorRT engine
51 | │ ├── main.cpp # Main entry point
52 | │ └── speedSam.cpp # Implementation of the SpeedSam class
53 | └── CMakeLists.txt # CMake configuration
54 |
55 | # 🚀 Installation
56 | ## Clone the repository
57 | git clone https://github.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT.git
58 | cd SPEED-SAM-C-TENSORRT
59 |
60 | ## Create a build directory and compile
61 | mkdir build && cd build
62 | cmake ..
63 | make -j$(nproc)
64 | Note: Update the CMakeLists.txt with the correct paths for TensorRT and OpenCV before running cmake.
65 |
66 | ## 📦 Dependencies
67 | - **CUDA**: NVIDIA's parallel computing platform
68 | - **TensorRT**: NVIDIA's high-performance deep learning inference library
69 | - **OpenCV**: Image processing library
70 | - **C++17**: Required language standard for compilation
71 |
72 | # 🔍 Code Overview
73 | ## Main Components
74 | - **SpeedSam Class** (speedSam.h): Manages image encoding and mask decoding.
75 | - **EngineTRT Class** (engineTRT.h): Handles TensorRT engine creation and inference.
76 | - **CUDA Utilities** (cuda_utils.h): Macros for CUDA error handling.
77 | - **Config** (config.h): Defines model parameters and precision settings.
78 | ## Key Functions
79 | - **EngineTRT::build**: Builds the TensorRT engine from an ONNX model.
80 | - **EngineTRT::infer**: Runs inference on the provided input data.
81 | - **SpeedSam::predict**: Segments an image using input points or bounding boxes (see the usage sketch after this README).
82 | ## 📞 Contact
83 |
84 | For advanced inquiries, feel free to contact me on LinkedIn: [Hamdi Boukamcha](https://www.linkedin.com/in/hamdi-boukamcha/)
85 |
86 | ## 📜 Citation
87 |
88 | If you use this code in your research, please cite the repository as follows:
89 |
90 | @misc{boukamcha2024SpeedSam,
91 |     author       = {Hamdi Boukamcha},
92 |     title        = {SPEED-SAM-C-TENSORRT},
93 |     year         = {2024},
94 |     publisher    = {GitHub},
95 |     howpublished = {\url{https://github.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT}},
96 | }
97 |
98 |
99 |
100 |
101 |
102 |
--------------------------------------------------------------------------------
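A minimal usage sketch of the point-prompt flow described in the README's "Key Functions" section, assuming the SpeedSam and utils headers from this repository; the image path, prompt coordinates, and output path are illustrative.

```cpp
// Point-prompt sketch (illustrative paths and coordinates).
#include "speedSam.h"
#include "utils.h"

int main() {
    // Builds TensorRT engines from the ONNX models (or deserializes cached engines).
    SpeedSam sam("model/SAM_encoder.onnx", "model/SAM_mask_decoder.onnx");

    cv::Mat image = cv::imread("assets/dog.jpg");
    std::vector<cv::Point> points = { cv::Point(512, 512) }; // prompt point
    std::vector<float> labels = { 1.0f };                    // 1.0f = foreground

    cv::Mat mask = sam.predict(image, points, labels);
    overlay(image, mask);             // blend the mask onto the image (utils.h)
    cv::imwrite("output.jpg", image);
    return 0;
}
```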
/assets/BY HAMDI BOUKAMCHA.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/assets/BY HAMDI BOUKAMCHA.JPG
--------------------------------------------------------------------------------
/assets/Speed_SAM_Results.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/assets/Speed_SAM_Results.JPG
--------------------------------------------------------------------------------
/assets/blue-linkedin-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/assets/blue-linkedin-logo.png
--------------------------------------------------------------------------------
/assets/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/assets/dog.jpg
--------------------------------------------------------------------------------
/assets/dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/assets/dogs.jpg
--------------------------------------------------------------------------------
/assets/speed_sam_cpp_tenosrrt.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/assets/speed_sam_cpp_tenosrrt.PNG
--------------------------------------------------------------------------------
/include/config.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | /// \file config.h
4 | /// \brief Header file defining model parameters and configuration macros.
5 | ///
6 | /// This header file contains macros for model parameters including input dimensions,
7 | /// feature dimensions, and precision settings.
8 | ///
9 | /// \author Hamdi Boukamcha
10 | /// \date 2024
11 |
12 | #define USE_FP16 ///< Set to use FP16 (float16) precision, or comment to use FP32 (float32) precision.
13 |
14 | #define MAX_NUM_PROMPTS 1 ///< Maximum number of prompts to be processed at once.
15 |
16 | // Model Params
17 | #define MODEL_INPUT_WIDTH 1024.0f ///< Width of the model input in pixels.
18 | #define MODEL_INPUT_HEIGHT 1024.0f ///< Height of the model input in pixels.
19 | #define HIDDEN_DIM 256 ///< Dimension of the hidden layer.
20 | #define NUM_LABELS 4 ///< Number of output labels.
21 | #define FEATURE_WIDTH 64 ///< Width of the feature map.
22 | #define FEATURE_HEIGHT 64 ///< Height of the feature map.
23 |
--------------------------------------------------------------------------------
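To make the macros above concrete: a hedged sizing sketch, assuming the encoder emits a [1, HIDDEN_DIM, FEATURE_HEIGHT, FEATURE_WIDTH] float feature map (this layout is an assumption, not stated in the header).

```cpp
// Feature-buffer sizing implied by config.h (layout assumed, see note above).
#include <cstddef>
#include "config.h"

constexpr std::size_t kFeatureElems =
    static_cast<std::size_t>(HIDDEN_DIM) * FEATURE_HEIGHT * FEATURE_WIDTH; // 256 * 64 * 64 = 1,048,576
constexpr std::size_t kFeatureBytes = kFeatureElems * sizeof(float);       // 4 MiB per image
```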
/include/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef TRTX_CUDA_UTILS_H_
2 | #define TRTX_CUDA_UTILS_H_
3 |
4 | #include <cuda_runtime_api.h>
5 |
6 | /// \file cuda_utils.h
7 | /// \brief Header file providing CUDA utility macros.
8 | ///
9 | /// This header file defines utility macros for error checking
10 | /// in CUDA operations, allowing for easier debugging and error
11 | /// handling in GPU-related code.
12 | ///
13 | /// \author Hamdi Boukamcha
14 | /// \date 2024
15 |
16 | #ifndef CUDA_CHECK
17 | /// \brief Macro for checking CUDA function calls.
18 | ///
19 | /// This macro checks the return status of a CUDA call and prints
20 | /// an error message if the call fails. It asserts to halt execution
21 | /// in case of an error.
22 | #define CUDA_CHECK(callstr) \
23 | { \
24 | cudaError_t error_code = callstr; \
25 | if (error_code != cudaSuccess) { \
26 | std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__ << std::endl; \
27 | assert(0); \
28 | } \
29 | }
30 | #endif // CUDA_CHECK
31 |
32 | /// \brief Macro for checking conditions with a custom error message.
33 | ///
34 | /// This macro checks a condition and, if false, logs an error message
35 | /// and returns a specified value. It includes information about the
36 | /// file, function, and line number where the error occurred.
37 | ///
38 | /// \param status The condition to check.
39 | /// \param val The value to return if the condition is false.
40 | /// \param errMsg The error message to log.
41 | #define CHECK_RETURN_W_MSG(status, val, errMsg) \
42 | do \
43 | { \
44 | if (!(status)) \
45 | { \
46 | sample::gLogError << errMsg << " Error in " << __FILE__ << ", function " << FN_NAME << "(), line " << __LINE__ \
47 | << std::endl; \
48 | return val; \
49 | } \
50 | } while (0)
51 |
52 | #endif // TRTX_CUDA_UTILS_H_
53 |
--------------------------------------------------------------------------------
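A short usage sketch for CUDA_CHECK; the buffer size is arbitrary.

```cpp
// Wrapping CUDA runtime calls with CUDA_CHECK from cuda_utils.h.
#include "cuda_utils.h"
#include <iostream>
#include <cassert>

int main() {
    float* devBuffer = nullptr;
    // On failure, prints the CUDA error code with file/line and halts via assert(0).
    CUDA_CHECK(cudaMalloc(reinterpret_cast<void**>(&devBuffer), 1024 * sizeof(float)));
    CUDA_CHECK(cudaFree(devBuffer));
    return 0;
}
```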
/include/engineTRT.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "NvInfer.h"
4 | #include <opencv2/opencv.hpp>
5 |
6 | using namespace nvinfer1;
7 | using namespace std;
8 | using namespace cv;
9 |
10 | /// \class EngineTRT
11 | /// \brief A class for handling TensorRT model inference.
12 | ///
13 | /// This class manages loading, setting inputs, and executing inference
14 | /// for a TensorRT model. It provides methods for setting input data,
15 | /// retrieving output predictions, and handling both static and dynamic shapes.
16 | ///
17 | /// \author Hamdi Boukamcha
18 | /// \date 2024
19 | class EngineTRT
20 | {
21 |
22 | public:
23 | /// \brief Constructor for the EngineTRT class.
24 | ///
25 | /// \param modelPath Path to the ONNX model file.
26 | /// \param inputNames Names of the input tensors.
27 | /// \param outputNames Names of the output tensors.
28 | /// \param isDynamicShape Indicates if the model uses dynamic shapes.
29 | /// \param isFP16 Indicates if the model should use FP16 precision.
30 | EngineTRT(string modelPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape, bool isFP16);
31 |
32 | /// \brief Performs inference on the input data.
33 | /// \return True if inference was successful, false otherwise.
34 | bool infer();
35 |
36 | /// \brief Sets the input image for inference.
37 | ///
38 | /// \param image The input image to be processed.
39 | void setInput(Mat& image);
40 |
41 | /// \brief Sets multiple inputs for inference from raw data.
42 | ///
43 | /// \param features Pointer to feature data.
44 | /// \param imagePointCoords Pointer to image point coordinates.
45 | /// \param imagePointLabels Pointer to image point labels.
46 | /// \param maskInput Pointer to mask input data.
47 | /// \param hasMaskInput Pointer to existence of mask data.
48 | /// \param numPoints The number of points in the input.
49 | void setInput(float* features, float* imagePointCoords, float* imagePointLabels, float* maskInput, float* hasMaskInput, int numPoints);
50 |
51 | /// \brief Retrieves the output predictions for IoU and low-resolution masks.
52 | ///
53 | /// \param iouPrediction Pointer to store IoU prediction output.
54 | /// \param lowResolutionMasks Pointer to store low-resolution mask output.
55 | void getOutput(float* iouPrediction, float* lowResolutionMasks);
56 |
57 | /// \brief Retrieves the output features from inference.
58 | ///
59 | /// \param features Pointer to store output features.
60 | void getOutput(float* features);
61 |
62 | /// \brief Destructor for the EngineTRT class.
63 | ~EngineTRT();
64 |
65 | private:
66 | /// \brief Builds the TensorRT engine from an ONNX model.
67 | ///
68 | /// \param onnxPath Path to the ONNX model file.
69 | /// \param inputNames Names of the input tensors.
70 | /// \param outputNames Names of the output tensors.
71 | /// \param isDynamicShape Indicates if the model uses dynamic shapes.
72 | /// \param isFP16 Indicates if the model should use FP16 precision.
73 | void build(string onnxPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape = false, bool isFP16 = false);
74 |
75 | void saveEngine(const std::string& engineFilePath);
76 |
77 | /// \brief Deserializes the engine from a file.
78 | ///
79 | /// \param engineName Name of the engine file.
80 | /// \param inputNames Names of the input tensors.
81 | /// \param outputNames Names of the output tensors.
82 | void deserializeEngine(string engineName, vector<string> inputNames, vector<string> outputNames);
83 |
84 | /// \brief Initializes the TensorRT module with input and output names.
85 | ///
86 | /// \param inputNames Names of the input tensors.
87 | /// \param outputNames Names of the output tensors.
88 | void initialize(vector<string> inputNames, vector<string> outputNames);
89 |
90 | /// \brief Gets the size of a buffer based on its dimensions.
91 | ///
92 | /// \param dims The dimensions of the buffer.
93 | /// \return The size in bytes of the buffer.
94 | size_t getSizeByDim(const Dims& dims);
95 |
96 | /// \brief Copies buffers between device and host memory.
97 | ///
98 | /// \param copyInput Indicates whether to copy input data.
99 | /// \param deviceToHost Indicates the direction of the copy.
100 | /// \param async Indicates whether the copy should be asynchronous.
101 | /// \param stream The CUDA stream to use for the copy operation.
102 | void memcpyBuffers(const bool copyInput, const bool deviceToHost, const bool async, const cudaStream_t& stream = 0);
103 |
104 | /// \brief Asynchronously copies input data to the device.
105 | ///
106 | /// \param stream The CUDA stream to use for the copy operation.
107 | void copyInputToDeviceAsync(const cudaStream_t& stream = 0);
108 |
109 | /// \brief Asynchronously copies output data from the device to host.
110 | ///
111 | /// \param stream The CUDA stream to use for the copy operation.
112 | void copyOutputToHostAsync(const cudaStream_t& stream = 0);
113 |
114 | vector<Dims> mInputDims; //!< The dimensions of the inputs to the network.
115 | vector<Dims> mOutputDims; //!< The dimensions of the outputs of the network.
116 | vector<void*> mGpuBuffers; //!< The vector of device buffers needed for engine execution.
117 | vector<float*> mCpuBuffers; //!< The vector of CPU buffers for input/output.
118 | vector<size_t> mBufferBindingBytes; //!< The sizes in bytes of each buffer binding.
119 | vector<size_t> mBufferBindingSizes; //!< The element counts of each buffer binding.
120 | cudaStream_t mCudaStream; //!< The CUDA stream used for asynchronous operations.
121 |
122 | IRuntime* mRuntime; //!< The TensorRT runtime used to deserialize the engine.
123 | ICudaEngine* mEngine; //!< The TensorRT engine used to run the network.
124 | IExecutionContext* mContext; //!< The context for executing inference using an ICudaEngine.
125 | };
126 |
--------------------------------------------------------------------------------
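A hedged construction sketch for EngineTRT. The repository's actual tensor binding names are set where the engines are created (in speedSam.cpp / main.cpp, not reproduced above), so "images" and "features" below are placeholders, and the output size follows the config.h feature dimensions.

```cpp
// Hypothetical EngineTRT usage; tensor names are placeholders, not the real bindings.
#include "engineTRT.h"
#include <vector>

int main() {
    EngineTRT encoder("model/SAM_encoder.onnx",
                      { "images" },              // placeholder input tensor name
                      { "features" },            // placeholder output tensor name
                      /*isDynamicShape=*/false,
                      /*isFP16=*/true);

    cv::Mat frame = cv::imread("assets/dogs.jpg");
    encoder.setInput(frame);                     // stage the preprocessed input
    if (encoder.infer()) {
        std::vector<float> features(256 * 64 * 64); // HIDDEN_DIM x FEATURE_HEIGHT x FEATURE_WIDTH
        encoder.getOutput(features.data());         // copy features back to the host
    }
    return 0;
}
```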
/include/logging.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #ifndef TENSORRT_LOGGING_H
18 | #define TENSORRT_LOGGING_H
19 |
20 | #include "NvInferRuntimeCommon.h"
21 | #include <cassert>
22 | #include <ctime>
23 | #include <iomanip>
24 | #include <iostream>
25 | #include <ostream>
26 | #include <sstream>
27 | #include <string>
28 | #include "macros.h"
29 |
30 | using Severity = nvinfer1::ILogger::Severity;
31 |
32 | class LogStreamConsumerBuffer : public std::stringbuf
33 | {
34 | public:
35 | LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
36 | : mOutput(stream)
37 | , mPrefix(prefix)
38 | , mShouldLog(shouldLog)
39 | {
40 | }
41 |
42 | LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
43 | : mOutput(other.mOutput)
44 | {
45 | }
46 |
47 | ~LogStreamConsumerBuffer()
48 | {
49 | // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
50 | // std::streambuf::pptr() gives a pointer to the current position of the output sequence
51 | // if the pointer to the beginning is not equal to the pointer to the current position,
52 | // call putOutput() to log the output to the stream
53 | if (pbase() != pptr())
54 | {
55 | putOutput();
56 | }
57 | }
58 |
59 | // synchronizes the stream buffer and returns 0 on success
60 | // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
61 | // resetting the buffer and flushing the stream
62 | virtual int sync()
63 | {
64 | putOutput();
65 | return 0;
66 | }
67 |
68 | void putOutput()
69 | {
70 | if (mShouldLog)
71 | {
72 | // prepend timestamp
73 | std::time_t timestamp = std::time(nullptr);
74 | tm* tm_local = std::localtime(&timestamp);
75 | std::cout << "[";
76 | std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/";
77 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/";
78 | std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-";
79 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":";
80 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":";
81 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] ";
82 | // std::stringbuf::str() gets the string contents of the buffer
83 | // insert the buffer contents pre-appended by the appropriate prefix into the stream
84 | mOutput << mPrefix << str();
85 | // set the buffer to empty
86 | str("");
87 | // flush the stream
88 | mOutput.flush();
89 | }
90 | }
91 |
92 | void setShouldLog(bool shouldLog)
93 | {
94 | mShouldLog = shouldLog;
95 | }
96 |
97 | private:
98 | std::ostream& mOutput;
99 | std::string mPrefix;
100 | bool mShouldLog;
101 | };
102 |
103 | //!
104 | //! \class LogStreamConsumerBase
105 | //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
106 | //!
107 | class LogStreamConsumerBase
108 | {
109 | public:
110 | LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
111 | : mBuffer(stream, prefix, shouldLog)
112 | {
113 | }
114 |
115 | protected:
116 | LogStreamConsumerBuffer mBuffer;
117 | };
118 |
119 | //!
120 | //! \class LogStreamConsumer
121 | //! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
122 | //! Order of base classes is LogStreamConsumerBase and then std::ostream.
123 | //! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
124 | //! in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
125 | //! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
126 | //! Please do not change the order of the parent classes.
127 | //!
128 | class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream
129 | {
130 | public:
131 | //! \brief Creates a LogStreamConsumer which logs messages with level severity.
132 | //! Reportable severity determines if the messages are severe enough to be logged.
133 | LogStreamConsumer(Severity reportableSeverity, Severity severity)
134 | : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity)
135 | , std::ostream(&mBuffer) // links the stream buffer with the stream
136 | , mShouldLog(severity <= reportableSeverity)
137 | , mSeverity(severity)
138 | {
139 | }
140 |
141 | LogStreamConsumer(LogStreamConsumer&& other)
142 | : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog)
143 | , std::ostream(&mBuffer) // links the stream buffer with the stream
144 | , mShouldLog(other.mShouldLog)
145 | , mSeverity(other.mSeverity)
146 | {
147 | }
148 |
149 | void setReportableSeverity(Severity reportableSeverity)
150 | {
151 | mShouldLog = mSeverity <= reportableSeverity;
152 | mBuffer.setShouldLog(mShouldLog);
153 | }
154 |
155 | private:
156 | static std::ostream& severityOstream(Severity severity)
157 | {
158 | return severity >= Severity::kINFO ? std::cout : std::cerr;
159 | }
160 |
161 | static std::string severityPrefix(Severity severity)
162 | {
163 | switch (severity)
164 | {
165 | case Severity::kINTERNAL_ERROR: return "[F] ";
166 | case Severity::kERROR: return "[E] ";
167 | case Severity::kWARNING: return "[W] ";
168 | case Severity::kINFO: return "[I] ";
169 | case Severity::kVERBOSE: return "[V] ";
170 | default: assert(0); return "";
171 | }
172 | }
173 |
174 | bool mShouldLog;
175 | Severity mSeverity;
176 | };
177 |
178 | //! \class Logger
179 | //!
180 | //! \brief Class which manages logging of TensorRT tools and samples
181 | //!
182 | //! \details This class provides a common interface for TensorRT tools and samples to log information to the console,
183 | //! and supports logging two types of messages:
184 | //!
185 | //! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal)
186 | //! - Test pass/fail messages
187 | //!
188 | //! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is
189 | //! that the logic for controlling the verbosity and formatting of sample output is centralized in one location.
190 | //!
191 | //! In the future, this class could be extended to support dumping test results to a file in some standard format
192 | //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run).
193 | //!
194 | //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger
195 | //! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT
196 | //! library and messages coming from the sample.
197 | //!
198 | //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the
199 | //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger
200 | //! object.
201 |
202 | class Logger : public nvinfer1::ILogger
203 | {
204 | public:
205 | Logger(Severity severity = Severity::kWARNING)
206 | : mReportableSeverity(severity)
207 | {
208 | }
209 |
210 | //!
211 | //! \enum TestResult
212 | //! \brief Represents the state of a given test
213 | //!
214 | enum class TestResult
215 | {
216 | kRUNNING, //!< The test is running
217 | kPASSED, //!< The test passed
218 | kFAILED, //!< The test failed
219 | kWAIVED //!< The test was waived
220 | };
221 |
222 | //!
223 | //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger
224 | //! \return The nvinfer1::ILogger associated with this Logger
225 | //!
226 | //! TODO Once all samples are updated to use this method to register the logger with TensorRT,
227 | //! we can eliminate the inheritance of Logger from ILogger
228 | //!
229 | nvinfer1::ILogger& getTRTLogger()
230 | {
231 | return *this;
232 | }
233 |
234 | //!
235 | //! \brief Implementation of the nvinfer1::ILogger::log() virtual method
236 | //!
237 | //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
238 | //! inheritance from nvinfer1::ILogger
239 | //!
240 | void log(Severity severity, const char* msg) TRT_NOEXCEPT override
241 | {
242 | LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;
243 | }
244 |
245 | //!
246 | //! \brief Method for controlling the verbosity of logging output
247 | //!
248 | //! \param severity The logger will only emit messages that have severity of this level or higher.
249 | //!
250 | void setReportableSeverity(Severity severity)
251 | {
252 | mReportableSeverity = severity;
253 | }
254 |
255 | //!
256 | //! \brief Opaque handle that holds logging information for a particular test
257 | //!
258 | //! This object is an opaque handle to information used by the Logger to print test results.
259 | //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used
260 | //! with Logger::reportTest{Start,End}().
261 | //!
262 | class TestAtom
263 | {
264 | public:
265 | TestAtom(TestAtom&&) = default;
266 |
267 | private:
268 | friend class Logger;
269 |
270 | TestAtom(bool started, const std::string& name, const std::string& cmdline)
271 | : mStarted(started)
272 | , mName(name)
273 | , mCmdline(cmdline)
274 | {
275 | }
276 |
277 | bool mStarted;
278 | std::string mName;
279 | std::string mCmdline;
280 | };
281 |
282 | //!
283 | //! \brief Define a test for logging
284 | //!
285 | //! \param[in] name The name of the test. This should be a string starting with
286 | //! "TensorRT" and containing dot-separated strings containing
287 | //! the characters [A-Za-z0-9_].
288 | //! For example, "TensorRT.sample_googlenet"
289 | //! \param[in] cmdline The command line used to reproduce the test
290 | //
291 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
292 | //!
293 | static TestAtom defineTest(const std::string& name, const std::string& cmdline)
294 | {
295 | return TestAtom(false, name, cmdline);
296 | }
297 |
298 | //!
299 | //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments
300 | //! as input
301 | //!
302 | //! \param[in] name The name of the test
303 | //! \param[in] argc The number of command-line arguments
304 | //! \param[in] argv The array of command-line arguments (given as C strings)
305 | //!
306 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
307 | static TestAtom defineTest(const std::string& name, int argc, char const* const* argv)
308 | {
309 | auto cmdline = genCmdlineString(argc, argv);
310 | return defineTest(name, cmdline);
311 | }
312 |
313 | //!
314 | //! \brief Report that a test has started.
315 | //!
316 | //! \pre reportTestStart() has not been called yet for the given testAtom
317 | //!
318 | //! \param[in] testAtom The handle to the test that has started
319 | //!
320 | static void reportTestStart(TestAtom& testAtom)
321 | {
322 | reportTestResult(testAtom, TestResult::kRUNNING);
323 | assert(!testAtom.mStarted);
324 | testAtom.mStarted = true;
325 | }
326 |
327 | //!
328 | //! \brief Report that a test has ended.
329 | //!
330 | //! \pre reportTestStart() has been called for the given testAtom
331 | //!
332 | //! \param[in] testAtom The handle to the test that has ended
333 | //! \param[in] result The result of the test. Should be one of TestResult::kPASSED,
334 | //! TestResult::kFAILED, TestResult::kWAIVED
335 | //!
336 | static void reportTestEnd(const TestAtom& testAtom, TestResult result)
337 | {
338 | assert(result != TestResult::kRUNNING);
339 | assert(testAtom.mStarted);
340 | reportTestResult(testAtom, result);
341 | }
342 |
343 | static int reportPass(const TestAtom& testAtom)
344 | {
345 | reportTestEnd(testAtom, TestResult::kPASSED);
346 | return EXIT_SUCCESS;
347 | }
348 |
349 | static int reportFail(const TestAtom& testAtom)
350 | {
351 | reportTestEnd(testAtom, TestResult::kFAILED);
352 | return EXIT_FAILURE;
353 | }
354 |
355 | static int reportWaive(const TestAtom& testAtom)
356 | {
357 | reportTestEnd(testAtom, TestResult::kWAIVED);
358 | return EXIT_SUCCESS;
359 | }
360 |
361 | static int reportTest(const TestAtom& testAtom, bool pass)
362 | {
363 | return pass ? reportPass(testAtom) : reportFail(testAtom);
364 | }
365 |
366 | Severity getReportableSeverity() const
367 | {
368 | return mReportableSeverity;
369 | }
370 |
371 | private:
372 | //!
373 | //! \brief returns an appropriate string for prefixing a log message with the given severity
374 | //!
375 | static const char* severityPrefix(Severity severity)
376 | {
377 | switch (severity)
378 | {
379 | case Severity::kINTERNAL_ERROR: return "[F] ";
380 | case Severity::kERROR: return "[E] ";
381 | case Severity::kWARNING: return "[W] ";
382 | case Severity::kINFO: return "[I] ";
383 | case Severity::kVERBOSE: return "[V] ";
384 | default: assert(0); return "";
385 | }
386 | }
387 |
388 | //!
389 | //! \brief returns an appropriate string for prefixing a test result message with the given result
390 | //!
391 | static const char* testResultString(TestResult result)
392 | {
393 | switch (result)
394 | {
395 | case TestResult::kRUNNING: return "RUNNING";
396 | case TestResult::kPASSED: return "PASSED";
397 | case TestResult::kFAILED: return "FAILED";
398 | case TestResult::kWAIVED: return "WAIVED";
399 | default: assert(0); return "";
400 | }
401 | }
402 |
403 | //!
404 | //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity
405 | //!
406 | static std::ostream& severityOstream(Severity severity)
407 | {
408 | return severity >= Severity::kINFO ? std::cout : std::cerr;
409 | }
410 |
411 | //!
412 | //! \brief method that implements logging test results
413 | //!
414 | static void reportTestResult(const TestAtom& testAtom, TestResult result)
415 | {
416 | severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # "
417 | << testAtom.mCmdline << std::endl;
418 | }
419 |
420 | //!
421 | //! \brief generate a command line string from the given (argc, argv) values
422 | //!
423 | static std::string genCmdlineString(int argc, char const* const* argv)
424 | {
425 | std::stringstream ss;
426 | for (int i = 0; i < argc; i++)
427 | {
428 | if (i > 0)
429 | ss << " ";
430 | ss << argv[i];
431 | }
432 | return ss.str();
433 | }
434 |
435 | Severity mReportableSeverity;
436 | };
437 |
438 | namespace
439 | {
440 |
441 | //!
442 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE
443 | //!
444 | //! Example usage:
445 | //!
446 | //! LOG_VERBOSE(logger) << "hello world" << std::endl;
447 | //!
448 | inline LogStreamConsumer LOG_VERBOSE(const Logger& logger)
449 | {
450 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE);
451 | }
452 |
453 | //!
454 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO
455 | //!
456 | //! Example usage:
457 | //!
458 | //! LOG_INFO(logger) << "hello world" << std::endl;
459 | //!
460 | inline LogStreamConsumer LOG_INFO(const Logger& logger)
461 | {
462 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO);
463 | }
464 |
465 | //!
466 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING
467 | //!
468 | //! Example usage:
469 | //!
470 | //! LOG_WARN(logger) << "hello world" << std::endl;
471 | //!
472 | inline LogStreamConsumer LOG_WARN(const Logger& logger)
473 | {
474 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING);
475 | }
476 |
477 | //!
478 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR
479 | //!
480 | //! Example usage:
481 | //!
482 | //! LOG_ERROR(logger) << "hello world" << std::endl;
483 | //!
484 | inline LogStreamConsumer LOG_ERROR(const Logger& logger)
485 | {
486 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR);
487 | }
488 |
489 | //!
490 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR
491 | // ("fatal" severity)
492 | //!
493 | //! Example usage:
494 | //!
495 | //! LOG_FATAL(logger) << "hello world" << std::endl;
496 | //!
497 | inline LogStreamConsumer LOG_FATAL(const Logger& logger)
498 | {
499 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
500 | }
501 |
502 | } // anonymous namespace
503 |
504 | #endif // TENSORRT_LOGGING_H
505 |
--------------------------------------------------------------------------------
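A brief usage sketch for the stream helpers defined above.

```cpp
// Using the Logger and LOG_* helpers from logging.h.
#include "logging.h"

int main() {
    Logger logger(Severity::kINFO);  // report kINFO and more severe messages
    LOG_INFO(logger) << "engine ready" << std::endl;
    LOG_ERROR(logger) << "failed to open engine file" << std::endl;
    return 0;
}
```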
/include/macros.h:
--------------------------------------------------------------------------------
1 | #ifndef __MACROS_H
2 | #define __MACROS_H
3 |
4 | /// \file macros.h
5 | /// \brief Header file defining macros for API export and compatibility.
6 | ///
7 | /// This header file contains macros that facilitate
8 | /// the export and import of functions for shared libraries,
9 | /// as well as compatibility settings based on the NV TensorRT version.
10 | ///
11 | /// \author Hamdi Boukamcha
12 | /// \date 2024
13 |
14 | #ifdef API_EXPORTS
15 | #if defined(_MSC_VER)
16 | #define API __declspec(dllexport) ///< Macro for exporting functions in Windows.
17 | #else
18 | #define API __attribute__((visibility("default"))) ///< Macro for exporting functions in non-Windows environments.
19 | #endif
20 | #else
21 | #if defined(_MSC_VER)
22 | #define API __declspec(dllimport) ///< Macro for importing functions in Windows.
23 | #else
24 | #define API ///< No import/export in non-Windows environments.
25 | #endif
26 | #endif // API_EXPORTS
27 |
28 | #if NV_TENSORRT_MAJOR >= 8
29 | #define TRT_NOEXCEPT noexcept ///< Macro for noexcept specification based on TensorRT version.
30 | #define TRT_CONST_ENQUEUE const ///< Macro to define const enqueue for TensorRT version >= 8.
31 | #else
32 | #define TRT_NOEXCEPT ///< No exception specification for TensorRT version < 8.
33 | #define TRT_CONST_ENQUEUE ///< No const enqueue definition for TensorRT version < 8.
34 | #endif
35 |
36 | #endif // __MACROS_H
37 |
--------------------------------------------------------------------------------
/include/speedSam.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <opencv2/opencv.hpp>
4 | #include "engineTRT.h"
5 |
6 | /// \class SpeedSam
7 | /// \brief A class for handling image predictions with an encoder and decoder model.
8 | ///
9 | /// This class manages the process of encoding and decoding images,
10 | /// allowing for the prediction of masks based on input images and points.
11 | ///
12 | /// \author Hamdi Boukamcha
13 | /// \date 2024
14 | class SpeedSam
15 | {
16 |
17 | public:
18 | /// \brief Constructor for the SpeedSam class.
19 | ///
20 | /// \param encoderPath Path to the encoder model.
21 | /// \param decoderPath Path to the decoder model.
22 | SpeedSam(std::string encoderPath, std::string decoderPath);
23 |
24 | /// \brief Destructor for the SpeedSam class.
25 | ~SpeedSam();
26 |
27 | /// \brief Predicts masks based on the input image and points.
28 | ///
29 | /// \param image The input image for prediction.
30 | /// \param points The points used for mask prediction.
31 | /// \param labels The labels associated with the points.
32 | /// \return A matrix containing the predicted masks.
33 | Mat predict(Mat& image, std::vector<Point> points, std::vector<float> labels);
34 |
35 | private:
36 | // Variables
37 | float* mFeatures; ///< Pointer to the feature data.
38 | float* mMaskInput; ///< Pointer to the mask input data.
39 | float* mHasMaskInput; ///< Pointer to the mask existence input data.
40 | float* mIouPrediction; ///< Pointer to the IoU prediction data.
41 | float* mLowResMasks; ///< Pointer to the low-resolution masks.
42 |
43 | EngineTRT* mImageEncoder; ///< Pointer to the image encoder module.
44 | EngineTRT* mMaskDecoder; ///< Pointer to the mask decoder module.
45 |
46 | /// \brief Upscales the given mask to the target width and height.
47 | ///
48 | /// \param mask The mask to upscale.
49 | /// \param targetWidth The target width for upscaling.
50 | /// \param targetHeight The target height for upscaling.
51 | /// \param size The size of the mask (default is 256).
52 | void upscaleMask(Mat& mask, int targetWidth, int targetHeight, int size = 256);
53 |
54 | /// \brief Resizes the input image to match the model's dimensions.
55 | ///
56 | /// \param img The image to resize.
57 | /// \param modelWidth The width required by the model.
58 | /// \param modelHeight The height required by the model.
59 | /// \return A resized image matrix.
60 | Mat resizeImage(Mat& img, int modelWidth, int modelHeight);
61 |
62 | /// \brief Prepares the decoder input from the provided points.
63 | ///
64 | /// \param points The points to be converted into input data.
65 | /// \param pointData Pointer to the data array for points.
66 | /// \param numPoints The number of points.
67 | /// \param imageWidth The width of the input image.
68 | /// \param imageHeight The height of the input image.
69 | void prepareDecoderInput(std::vector<Point>& points, float* pointData, int numPoints, int imageWidth, int imageHeight);
70 | };
71 |
--------------------------------------------------------------------------------
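Complementing the point-prompt sketch after the README, a hedged box-prompt sketch, assuming the label convention documented in utils.h (1 = foreground point, 2 = box top-left, 3 = box bottom-right); coordinates are illustrative.

```cpp
// Box-prompt sketch for SpeedSam::predict (illustrative coordinates).
#include "speedSam.h"

int main() {
    SpeedSam sam("model/SAM_encoder.onnx", "model/SAM_mask_decoder.onnx");
    cv::Mat image = cv::imread("assets/dogs.jpg");

    std::vector<cv::Point> box = { cv::Point(100, 80),    // top-left corner (label 2)
                                   cv::Point(400, 360) }; // bottom-right corner (label 3)
    std::vector<float> labels = { 2.0f, 3.0f };

    cv::Mat mask = sam.predict(image, box, labels);
    cv::imwrite("mask.png", mask > 0); // binarize the mask for saving
    return 0;
}
```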
/include/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | /// \file utils.h
3 | /// \brief Header file defining color constants and utility functions.
4 | ///
5 | /// This header file contains a set of predefined colors for the Cityscapes dataset,
6 | /// a structure to hold clicked point data, and utility functions for overlaying masks
7 | /// on images and handling mouse events.
8 | ///
9 | /// \author Hamdi Boukamcha
10 | /// \date 2024
11 | #include "speedSam.h" // also brings in <opencv2/opencv.hpp> via engineTRT.h
12 | using namespace cv;
13 | using namespace std;
14 |
15 | // Global variables to store the selected point and image
16 | Point selectedPoint;
17 | bool pointSelected = false;
18 | float resizeScale = 1.0f; // Scale factor for resizing
19 |
20 | // Colors
21 | const std::vector<cv::Scalar> CITYSCAPES_COLORS = {
22 | cv::Scalar(128, 64, 128),
23 | cv::Scalar(232, 35, 244),
24 | cv::Scalar(70, 70, 70),
25 | cv::Scalar(156, 102, 102),
26 | cv::Scalar(153, 153, 190),
27 | cv::Scalar(153, 153, 153),
28 | cv::Scalar(30, 170, 250),
29 | cv::Scalar(0, 220, 220),
30 | cv::Scalar(35, 142, 107),
31 | cv::Scalar(152, 251, 152),
32 | cv::Scalar(180, 130, 70),
33 | cv::Scalar(60, 20, 220),
34 | cv::Scalar(0, 0, 255),
35 | cv::Scalar(142, 0, 0),
36 | cv::Scalar(70, 0, 0),
37 | cv::Scalar(100, 60, 0),
38 | cv::Scalar(90, 0, 0),
39 | cv::Scalar(230, 0, 0),
40 | cv::Scalar(32, 11, 119),
41 | cv::Scalar(0, 74, 111),
42 | cv::Scalar(81, 0, 81)
43 | };
44 |
45 | /// \struct PointData
46 | /// \brief Structure to hold clicked point coordinates.
47 | ///
48 | /// This structure stores the coordinates of a clicked point
49 | /// and a flag indicating whether the point has been clicked.
50 | struct PointData {
51 | cv::Point point; ///< The coordinates of the clicked point.
52 | bool clicked; ///< Flag indicating if the point was clicked.
53 | };
54 |
55 | // Mouse callback function to capture the clicked point
56 | void mouseCallback(int event, int x, int y, int flags, void* param) {
57 | if (event == EVENT_LBUTTONDOWN) {
58 | selectedPoint = Point(x, y);
59 | pointSelected = true;
60 | cout << "Point selected (in resized image): " << selectedPoint << endl;
61 | }
62 | }
63 |
64 | // Example overlay function
65 | void overlay(Mat& image, const Mat& mask) {
66 | // Placeholder for the overlay logic
67 | // This function should blend the mask with the original image
68 | addWeighted(image, 0.5, mask, 0.5, 0, image);
69 | }
70 |
71 | /// \brief Overlays a mask on the given image.
72 | ///
73 | /// This function overlays a colored mask on the input image
74 | /// using a specified transparency (alpha) and optionally shows
75 | /// the contour edges of the mask.
76 | ///
77 | /// \param image The image on which to overlay the mask.
78 | /// \param mask The mask to overlay on the image.
79 | /// \param color The color of the overlay mask (default is CITYSCAPES_COLORS[0]).
80 | /// \param alpha The transparency level of the overlay (default is 0.8).
81 | /// \param showEdge Whether to show the contour edges of the mask (default is true).
82 | void overlay(Mat& image, Mat& mask, cv::Scalar color = cv::Scalar(128, 64, 128), float alpha = 0.8f, bool showEdge = true)
83 | {
84 | // Draw mask
85 | Mat ucharMask(image.rows, image.cols, CV_8UC3, color);
86 | image.copyTo(ucharMask, mask <= 0);
87 | addWeighted(ucharMask, alpha, image, 1.0 - alpha, 0.0f, image);
88 |
89 | // Draw contour edge
90 | if (showEdge)
91 | {
92 | vector<vector<Point>> contours;
93 | vector<Vec4i> hierarchy;
94 | findContours(mask <= 0, contours, hierarchy, RETR_TREE, CHAIN_APPROX_NONE);
95 | drawContours(image, contours, -1, Scalar(255, 255, 255), 2);
96 | }
97 | }
98 |
99 | /// \brief Handles mouse events for clicking on the image.
100 | ///
101 | /// This function processes mouse events, storing the clicked
102 | /// point coordinates in the provided PointData structure.
103 | ///
104 | /// \param event The type of mouse event.
105 | /// \param x The x-coordinate of the mouse event.
106 | /// \param y The y-coordinate of the mouse event.
107 | /// \param flags Any relevant flags associated with the mouse event.
108 | /// \param userdata User data pointer to store the PointData structure.
109 | void onMouse(int event, int x, int y, int flags, void* userdata) {
110 | PointData* pd = (PointData*)userdata;
111 | if (event == cv::EVENT_LBUTTONDOWN) {
112 | // Save the clicked coordinates
113 | pd->point = cv::Point(x, y);
114 | pd->clicked = true;
115 | }
116 | }
117 |
118 | /// \brief Segments the image based on a user-selected point.
119 | ///
120 | /// This function allows the user to click on a point within the image, which is then
121 | /// used to perform segmentation using the SpeedSam model. The segmented result is
122 | /// overlaid on the original image and displayed in the same window. The final image
123 | /// is saved to the specified output path.
124 | ///
125 | /// \param nanosam Reference to the SpeedSam model used for segmentation.
126 | /// \param imagePath Path to the input image.
127 | /// \param outputPath Path to save the segmented output image.
128 | void segmentWithPoint(SpeedSam& nanosam, const string& imagePath, const string& outputPath) {
129 | // Load the image from the specified path
130 | Mat image = imread(imagePath);
131 | if (image.empty()) {
132 | cerr << "Error: Unable to load image from " << imagePath << endl;
133 | return;
134 | }
135 |
136 | // Resize the image for easier viewing
137 | Mat resizedImage;
138 | const int maxWidth = 800; // Maximum width for display
139 | if (image.cols > maxWidth) {
140 | resizeScale = static_cast<float>(maxWidth) / image.cols;
141 | resize(image, resizedImage, Size(), resizeScale, resizeScale);
142 | }
143 | else {
144 | resizedImage = image;
145 | }
146 |
147 | // Display the image in a window
148 | namedWindow("Select Point", WINDOW_AUTOSIZE);
149 | setMouseCallback("Select Point", mouseCallback, nullptr);
150 | imshow("Select Point", resizedImage);
151 |
152 | // Wait indefinitely until the user clicks a point
153 | while (!pointSelected) {
154 | waitKey(10); // Small delay to prevent high CPU usage
155 | }
156 |
157 | // Scale the selected point back to the original image size
158 | Point originalPoint(static_cast<int>(selectedPoint.x / resizeScale),
159 | static_cast<int>(selectedPoint.y / resizeScale));
160 | cout << "Point mapped to original image: " << originalPoint << endl;
161 |
162 | // Label indicating that the prompt point corresponds to the foreground class
163 | vector<float> labels = { 1.0f };
164 |
165 | // Perform prediction using the SpeedSam model at the specified prompt point
166 | auto mask = nanosam.predict(image, { originalPoint }, labels);
167 |
168 | // Overlay the segmentation mask on the original image
169 | overlay(image, mask);
170 |
171 | // Save the resulting image with the overlay to the specified output path
172 | imwrite(outputPath, image);
173 |
174 | if (image.cols > maxWidth) {
175 | resizeScale = static_cast<float>(maxWidth) / image.cols;
176 | resize(image, resizedImage, Size(), resizeScale, resizeScale);
177 | }
178 | else {
179 | resizedImage = image;
180 | }
181 |
182 | // Update the same window with the segmented image
183 | imshow("Select Point", resizedImage);
184 | waitKey(0); // Wait for another key press to close the window
185 | destroyAllWindows(); // Close all OpenCV windows
186 | }
187 |
188 | /// \brief Segments an image using an interactively selected bounding box.
189 | ///
190 | /// This function loads an image from the specified path,
191 | /// lets the user draw a bounding box with the mouse,
192 | /// performs segmentation using the box corners as prompts,
193 | /// and saves the segmented image with an overlay to the specified output path.
194 | ///
195 | /// \param nanosam The SpeedSam model used for segmentation.
196 | /// \param imagePath The path to the input image.
197 | /// \param outputPath The path where the output image will be saved.
198 | ///
199 | void segmentBbox(SpeedSam& nanosam, string imagePath, string outputPath) {
200 | // Load the image from the specified path
201 | auto image = imread(imagePath);
202 |
203 | // Check if the image was loaded successfully
204 | if (image.empty()) {
205 | cerr << "Error: Unable to load image from " << imagePath << endl;
206 | return;
207 | }
208 |
209 | // Create a window for user interaction
210 | namedWindow("Select and View Result", cv::WINDOW_AUTOSIZE);
211 |
212 | // Let the user select the bounding box
213 | cv::Rect bbox = selectROI("Select and View Result", image, false, false);
214 |
215 | // Check if a valid bounding box was selected
216 | if (bbox.width == 0 || bbox.height == 0) {
217 | cerr << "No valid bounding box selected." << endl;
218 | return;
219 | }
220 |
221 | // Convert the selected bounding box to a vector of points
222 | vector<Point> bboxPoints = {
223 | Point(bbox.x, bbox.y), // Top-left point
224 | Point(bbox.x + bbox.width, bbox.y + bbox.height) // Bottom-right point
225 | };
226 |
227 | // Labels corresponding to the bounding box classes
228 | // 2 : Bounding box top-left, 3 : Bounding box bottom-right
229 | vector<float> labels = { 2.0f, 3.0f };
230 |
231 | // Perform prediction using the SpeedSam model with the box corner points and labels
232 | auto mask = nanosam.predict(image, bboxPoints, labels);
233 |
234 | // Overlay the segmentation mask on the original image
235 | overlay(image, mask);
236 |
237 | // Draw the bounding box on the image
238 | rectangle(image, bboxPoints[0], bboxPoints[1], cv::Scalar(255, 255, 0), 3);
239 |
240 | // Display the updated image in the same window
241 | imshow("Select and View Result", image);
242 | waitKey(0);
243 |
244 | // Save the resulting image to the specified output path
245 | imwrite(outputPath, image);
246 | }
247 |
--------------------------------------------------------------------------------
/model/SAM_encoder.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/model/SAM_encoder.onnx
--------------------------------------------------------------------------------
/model/SAM_mask_decoder.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/model/SAM_mask_decoder.onnx
--------------------------------------------------------------------------------
/src/engineTRT.cpp:
--------------------------------------------------------------------------------
1 | #include "engineTRT.h"
2 | #include "logging.h"
3 | #include "cuda_utils.h"
4 | #include "config.h"
5 | #include "macros.h"
6 | #include "NvOnnxParser.h"
7 | #include <filesystem>
8 | #include <fstream>
9 | #include <iostream>
10 | #include <sstream>
11 | #include <cassert>
12 | #include <vector>
13 |
14 | static Logger gLogger;
15 |
16 | std::string getFileExtension(const std::string& filePath) {
17 | // Find the position of the last dot in the file path
18 | size_t dotPos = filePath.find_last_of(".");
19 | // If a dot is found, extract and return the substring after the dot as the file extension
20 | if (dotPos != std::string::npos) {
21 | return filePath.substr(dotPos + 1);
22 | }
23 | // Return an empty string if no extension is found
24 | return "";
25 | }
26 |
27 | EngineTRT::EngineTRT(string modelPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape, bool isFP16) {
28 | // Check if the model file has an ".onnx" extension
29 | if (getFileExtension(modelPath) == "onnx") {
30 | // If the file is an ONNX model, build the engine using the provided parameters
31 | cout << "Building Engine from " << modelPath << endl;
32 | build(modelPath, inputNames, outputNames, isDynamicShape, isFP16);
33 | }
34 | else {
35 | // If the file is not an ONNX model, deserialize an existing engine
36 | cout << "Deserializing Engine." << endl;
37 | deserializeEngine(modelPath, inputNames, outputNames);
38 | }
39 | }
40 |
41 | EngineTRT::~EngineTRT() {
42 | // Release the CUDA stream
43 | cudaStreamDestroy(mCudaStream);
44 | // Free GPU buffers allocated for inference
45 | for (int i = 0; i < mGpuBuffers.size(); i++)
46 | CUDA_CHECK(cudaFree(mGpuBuffers[i]));
47 | // Free CPU buffers
48 | for (int i = 0; i < mCpuBuffers.size(); i++)
49 | delete[] mCpuBuffers[i];
50 |
51 | // Clean up and destroy the TensorRT engine components
52 | delete mContext; // Destroy the execution context
53 | delete mEngine; // Destroy the engine
54 | delete mRuntime; // Destroy the runtime
55 | }
56 |
57 | void EngineTRT::build(string onnxPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape, bool isFP16)
58 | {
59 | // Check if the ONNX file exists. If not, print an error message and return.
60 | if (!std::filesystem::exists(onnxPath)) {
61 | std::cerr << "ONNX file not found: " << onnxPath << std::endl;
62 | return; // Early exit if the ONNX file is missing
63 | }
64 |
65 | // Create an inference builder for building the TensorRT engine.
66 | auto builder = createInferBuilder(gLogger);
67 | assert(builder != nullptr); // Ensure the builder is created successfully
68 |
69 | // Use explicit batch size, which is needed for ONNX models.
70 | const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
71 | INetworkDefinition* network = builder->createNetworkV2(explicitBatch);
72 | assert(network != nullptr); // Ensure the network is created successfully
73 |
74 | // Create a builder configuration object to set options like FP16 precision.
75 | IBuilderConfig* config = builder->createBuilderConfig();
76 | assert(config != nullptr); // Ensure the config is created successfully
77 |
78 | // If dynamic shape support is needed, configure the optimization profile.
79 | if (isDynamicShape) // Only used for the SAM mask decoder
80 | {
81 | // Create an optimization profile for dynamic input shapes.
82 | auto profile = builder->createOptimizationProfile();
83 |
84 | // Set the minimum, optimal, and maximum dimensions for the point-coordinates input.
85 | profile->setDimensions(inputNames[1].c_str(), OptProfileSelector::kMIN, Dims3{ 1, 1, 2 });
86 | profile->setDimensions(inputNames[1].c_str(), OptProfileSelector::kOPT, Dims3{ 1, 1, 2 });
87 | profile->setDimensions(inputNames[1].c_str(), OptProfileSelector::kMAX, Dims3{ 1, 10, 2 });
88 |
89 | // Set the minimum, optimal, and maximum dimensions for the point-labels input.
90 | profile->setDimensions(inputNames[2].c_str(), OptProfileSelector::kMIN, Dims2{ 1, 1 });
91 | profile->setDimensions(inputNames[2].c_str(), OptProfileSelector::kOPT, Dims2{ 1, 1 });
92 | profile->setDimensions(inputNames[2].c_str(), OptProfileSelector::kMAX, Dims2{ 1, 10 });
93 |
94 | // Add the optimization profile to the builder configuration.
95 | config->addOptimizationProfile(profile);
96 | }
97 |
98 | // Enable FP16 mode if specified.
99 | if (isFP16)
100 | {
101 | config->setFlag(BuilderFlag::kFP16); // Use mixed precision for faster inference
102 | }
103 |
104 | // Create a parser to convert the ONNX model to a TensorRT network.
105 | nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, gLogger);
106 | assert(parser != nullptr); // Ensure the parser is created successfully
107 |
108 | // Parse the ONNX model from the specified file.
109 | bool parsed = parser->parseFromFile(onnxPath.c_str(), static_cast<int>(gLogger.getReportableSeverity()));
110 |
111 | // Verify that the ONNX model was parsed successfully.
112 | assert(parsed); // Abort on parse failure
113 |
114 | // Serialize the built network into a binary plan for execution.
115 | IHostMemory* plan{ builder->buildSerializedNetwork(*network, *config) };
116 | assert(plan != nullptr); // Ensure the network was serialized successfully
117 |
118 | // Create a runtime object for deserializing the engine.
119 | mRuntime = createInferRuntime(gLogger);
120 | assert(mRuntime != nullptr); // Ensure the runtime is created successfully
121 |
122 | // Deserialize the serialized plan to create an execution engine.
123 | mEngine = mRuntime->deserializeCudaEngine(plan->data(), plan->size());
124 | assert(mEngine != nullptr); // Ensure the engine was deserialized successfully
125 |
126 | // Create an execution context for running inference.
127 | mContext = mEngine->createExecutionContext();
128 | assert(mContext != nullptr); // Ensure the context is created successfully
129 |
130 | // Clean up resources.
131 | delete network;
132 | delete config;
133 | delete parser;
134 | delete plan;
135 |
136 | // Initialize the engine with the input and output names.
137 | initialize(inputNames, outputNames);
138 | }
139 |
140 | void EngineTRT::saveEngine(const std::string& engineFilePath) {
141 | if (mEngine) {
142 | // Serialize the engine to a binary format.
143 | IHostMemory* serializedEngine = mEngine->serialize();
144 | std::ofstream engineFile(engineFilePath, std::ios::binary);
145 | if (engineFile) {
146 | // Write the serialized engine data to the specified file.
147 | engineFile.write(reinterpret_cast<const char*>(serializedEngine->data()), serializedEngine->size());
148 | std::cout << "Serialized engine saved to " << engineFilePath << std::endl;
149 | } else { std::cerr << "Failed to open " << engineFilePath << " for writing." << std::endl; }
150 | delete serializedEngine; // Free the serialized engine memory
151 | }
152 | }
153 |
154 | void EngineTRT::deserializeEngine(string engine_name, vector<string> inputNames, vector<string> outputNames)
155 | {
156 | // Open the engine file in binary mode.
157 | std::ifstream file(engine_name, std::ios::binary);
158 | if (!file.good()) {
159 | std::cerr << "Error reading engine file: " << engine_name << std::endl;
160 | assert(false); // Trigger an assertion failure if the file cannot be opened
161 | }
162 |
163 | // Determine the size of the file and read the serialized engine data.
164 | size_t size = 0;
165 | file.seekg(0, file.end);
166 | size = file.tellg();
167 | file.seekg(0, file.beg);
168 | char* serializedEngine = new char[size];
169 | assert(serializedEngine); // Ensure memory allocation was successful
170 | file.read(serializedEngine, size);
171 | file.close();
172 |
173 | // Create a runtime object and deserialize the engine.
174 | mRuntime = createInferRuntime(gLogger);
175 | assert(mRuntime); // Ensure the runtime is created successfully
176 | mEngine = mRuntime->deserializeCudaEngine(serializedEngine, size);
177 | mContext = mEngine->createExecutionContext();
178 | delete[] serializedEngine; // Free the serialized engine memory
179 |
180 | // Ensure the number of bindings matches the expected number of inputs and outputs.
181 | assert(mEngine->getNbBindings() == static_cast<int>(inputNames.size() + outputNames.size()));
182 |
183 | // Initialize the engine with the input and output names.
184 | initialize(inputNames, outputNames);
185 | }
186 |
187 | void EngineTRT::initialize(vector<string> inputNames, vector<string> outputNames)
188 | {
189 | // Verify that each input name resolves to a valid binding index in the TensorRT engine
190 | for (size_t i = 0; i < inputNames.size(); i++)
191 | {
192 | assert(mEngine->getBindingIndex(inputNames[i].c_str()) >= 0); // -1 means the name was not found
193 | }
194 |
195 | // Verify that each output name resolves to a valid binding index in the TensorRT engine
196 | for (size_t i = 0; i < outputNames.size(); i++)
197 | {
198 | assert(mEngine->getBindingIndex(outputNames[i].c_str()) >= 0); // -1 means the name was not found
199 | }
200 |
201 | // Resize the GPU and CPU buffer vectors to accommodate all the engine bindings
202 | mGpuBuffers.resize(mEngine->getNbBindings());
203 | mCpuBuffers.resize(mEngine->getNbBindings());
204 |
205 | // Loop through all bindings to allocate memory and store dimension information
206 | for (int i = 0; i < mEngine->getNbBindings(); ++i)
207 | {
208 | // Calculate the size required for the binding based on its dimensions
209 | size_t binding_size = getSizeByDim(mEngine->getBindingDimensions(i));
210 | mBufferBindingSizes.push_back(binding_size); // Store the size of the binding
211 | mBufferBindingBytes.push_back(binding_size * sizeof(float)); // Calculate the size in bytes
212 |
213 | // Allocate host memory for the CPU buffer
214 | mCpuBuffers[i] = new float[binding_size];
215 |
216 | // Allocate device memory for the GPU buffer
217 | CUDA_CHECK(cudaMalloc(&mGpuBuffers[i], mBufferBindingBytes[i]));
218 |
219 | // Store input and output dimensions separately based on whether the binding is an input or output
220 | if (mEngine->bindingIsInput(i))
221 | {
222 | mInputDims.push_back(mEngine->getBindingDimensions(i));
223 | }
224 | else
225 | {
226 | mOutputDims.push_back(mEngine->getBindingDimensions(i));
227 | }
228 | }
229 |
230 | // Create a CUDA stream for asynchronous operations
231 | CUDA_CHECK(cudaStreamCreate(&mCudaStream));
232 | }
233 |
234 | bool EngineTRT::infer()
235 | {
236 | // Copy data from host (CPU) input buffers to device (GPU) input buffers on the stream
237 | copyInputToDeviceAsync(mCudaStream);
238 | CUDA_CHECK(cudaStreamSynchronize(mCudaStream)); // Ensure inputs are on the GPU before executeV2 runs
239 | // Perform inference using TensorRT, passing the GPU buffers
240 | bool status = mContext->executeV2(mGpuBuffers.data());
241 |
242 | if (!status)
243 | {
244 | // If inference fails, print an error message and return false
245 | cout << "TensorRT inference failed!" << endl;
246 | return false;
247 | }
248 |
249 | // Copy the results from device (GPU) output buffers to host (CPU) output buffers on the stream
250 | copyOutputToHostAsync(mCudaStream);
251 | CUDA_CHECK(cudaStreamSynchronize(mCudaStream)); // Ensure CPU buffers are valid before getOutput reads them
252 | // Return true if inference was successful
253 | return true;
254 | }
255 |
256 | void EngineTRT::copyInputToDeviceAsync(const cudaStream_t& stream)
257 | {
258 | // Perform asynchronous memory copy from CPU to GPU for the input buffers
259 | memcpyBuffers(true, false, true, stream);
260 | }
261 |
262 | void EngineTRT::copyOutputToHostAsync(const cudaStream_t& stream)
263 | {
264 | // Calls memcpyBuffers to handle the copying of data from GPU to CPU memory.
265 | // Arguments: false (do not copy input buffers), true (copy data from device to host),
266 | // true (perform the copy asynchronously), and the given CUDA stream.
267 | memcpyBuffers(false, true, true, stream);
268 | }
269 |
270 | void EngineTRT::memcpyBuffers(const bool copyInput, const bool deviceToHost, const bool async, const cudaStream_t& stream)
271 | {
272 | // Loop through all bindings (inputs and outputs) in the TensorRT engine.
273 | for (int i = 0; i < mEngine->getNbBindings(); i++)
274 | {
275 | // Determine the destination and source pointers based on the copy direction.
276 | void* dstPtr = deviceToHost ? mCpuBuffers[i] : mGpuBuffers[i];
277 | const void* srcPtr = deviceToHost ? mGpuBuffers[i] : mCpuBuffers[i];
278 | // Get the size of the buffer in bytes.
279 | const size_t byteSize = mBufferBindingBytes[i];
280 | // Set the type of memory copy operation based on the direction.
281 | const cudaMemcpyKind memcpyType = deviceToHost ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice;
282 |
283 | // Check if the current binding is an input or output and copy accordingly.
284 | if ((copyInput && mEngine->bindingIsInput(i)) || (!copyInput && !mEngine->bindingIsInput(i)))
285 | {
286 | if (async)
287 | {
288 | // Perform asynchronous memory copy using the CUDA stream.
289 | CUDA_CHECK(cudaMemcpyAsync(dstPtr, srcPtr, byteSize, memcpyType, stream));
290 | }
291 | else
292 | {
293 | // Perform synchronous memory copy.
294 | CUDA_CHECK(cudaMemcpy(dstPtr, srcPtr, byteSize, memcpyType));
295 | }
296 | }
297 | }
298 | }
299 |
300 | size_t EngineTRT::getSizeByDim(const Dims& dims)
301 | {
302 | size_t size = 1;
303 |
304 | // Loop through each dimension and multiply to calculate the total size.
305 | for (int i = 0; i < dims.nbDims; ++i)
306 | {
307 | // If the dimension is -1 (dynamic), use a predefined maximum size.
308 | if (dims.d[i] == -1)
309 | size *= MAX_NUM_PROMPTS;
310 | else
311 | size *= dims.d[i];
312 | }
313 |
314 | return size;
315 | }
316 |
317 | void EngineTRT::setInput(Mat& image)
318 | {
319 | // Extract input dimensions (height and width) from the model's input shape
320 | const int inputH = mInputDims[0].d[2];
321 | const int inputW = mInputDims[0].d[3];
322 | assert(image.rows == inputH && image.cols == inputW); // Image must already match the model input size
323 | int i = 0; // Index counter for buffer placement
324 |
325 | // Iterate over each pixel in the input image
326 | for (int row = 0; row < image.rows; ++row)
327 | {
328 | // Pointer to the start of the row in the image data
329 | uchar* uc_pixel = image.data + row * image.step;
330 |
331 | for (int col = 0; col < image.cols; ++col)
332 | {
333 | // Convert BGR to planar RGB and normalize each channel with ImageNet statistics:
334 | // (pixel / 255 - mean) / std, mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)
335 | mCpuBuffers[0][i] = ((float)uc_pixel[2] / 255.0f - 0.485f) / 0.229f; // Red channel
336 | mCpuBuffers[0][i + image.rows * image.cols] = ((float)uc_pixel[1] / 255.0f - 0.456f) / 0.224f; // Green channel
337 | mCpuBuffers[0][i + 2 * image.rows * image.cols] = ((float)uc_pixel[0] / 255.0f - 0.406f) / 0.225f; // Blue channel
338 |
339 | uc_pixel += 3; // Move to the next pixel
340 | ++i; // Increment index
341 | }
342 | }
343 | }
344 |
345 | void EngineTRT::setInput(float* features, float* imagePointCoords, float* imagePointLabels, float* maskInput, float* hasMaskInput, int numPoints)
346 | {
347 | // Release the old CPU and GPU buffers before allocating for the new prompt size
348 | delete[] mCpuBuffers[1]; CUDA_CHECK(cudaFree(mGpuBuffers[1]));
349 | delete[] mCpuBuffers[2]; CUDA_CHECK(cudaFree(mGpuBuffers[2]));
350 | mCpuBuffers[1] = new float[numPoints * 2]; // Buffer for point coordinates
351 | mCpuBuffers[2] = new float[numPoints]; // Buffer for point labels
352 |
353 | // Allocate device memory sized for the new number of points
354 | CUDA_CHECK(cudaMalloc(&mGpuBuffers[1], sizeof(float) * numPoints * 2)); // Coordinates
355 | CUDA_CHECK(cudaMalloc(&mGpuBuffers[2], sizeof(float) * numPoints)); // Labels
356 |
357 | // Set the size of the data binding in bytes for TensorRT
358 | mBufferBindingBytes[1] = sizeof(float) * numPoints * 2;
359 | mBufferBindingBytes[2] = sizeof(float) * numPoints;
360 |
361 | // Copy input data into CPU buffers
362 | memcpy(mCpuBuffers[0], features, mBufferBindingBytes[0]);
363 | memcpy(mCpuBuffers[1], imagePointCoords, sizeof(float) * numPoints * 2);
364 | memcpy(mCpuBuffers[2], imagePointLabels, sizeof(float) * numPoints);
365 | memcpy(mCpuBuffers[3], maskInput, mBufferBindingBytes[3]);
366 | memcpy(mCpuBuffers[4], hasMaskInput, mBufferBindingBytes[4]);
367 |
368 | // Configure TensorRT to use a dynamic input shape
369 | mContext->setOptimizationProfileAsync(0, mCudaStream); // Set the optimization profile
370 | mContext->setBindingDimensions(1, Dims3{ 1, numPoints, 2 }); // Set input dimensions for coordinates
371 | mContext->setBindingDimensions(2, Dims2{ 1, numPoints }); // Set input dimensions for labels
372 | }
373 |
374 | void EngineTRT::getOutput(float* features)
375 | {
376 | // Copy the output features from the CPU buffer to the provided memory
377 | memcpy(features, mCpuBuffers[1], mBufferBindingBytes[1]);
378 | }
379 |
380 | void EngineTRT::getOutput(float* iouPrediction, float* lowResolutionMasks)
381 | {
382 | // Copy the low-resolution masks and IOU predictions from the CPU buffers
383 | memcpy(lowResolutionMasks, mCpuBuffers[5], mBufferBindingBytes[5]);
384 | memcpy(iouPrediction, mCpuBuffers[6], mBufferBindingBytes[6]);
385 | }
386 |
387 |
--------------------------------------------------------------------------------
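Taken together, the constructor, build, saveEngine, and deserializeEngine paths above mean an engine only has to be built from ONNX once. A minimal sketch of that caching pattern, assuming the EngineTRT interface shown in this file; the cache path "model/SAM_encoder.engine" is a hypothetical name:

#include "engineTRT.h"

int main()
{
    // ".onnx" extension -> build from ONNX; any other extension -> deserialize.
    EngineTRT encoder("model/SAM_encoder.onnx",
        { "image" },            // input binding names
        { "image_embeddings" }, // output binding names
        false,                  // static input shape
        true);                  // FP16

    // Cache the serialized plan so later runs can skip the slow ONNX build.
    encoder.saveEngine("model/SAM_encoder.engine"); // assumed cache path

    // On subsequent runs, constructing from the ".engine" file deserializes it:
    // EngineTRT cached("model/SAM_encoder.engine", { "image" }, { "image_embeddings" }, false, true);
    return 0;
}
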
/src/main.cpp:
--------------------------------------------------------------------------------
1 | #include "speedSam.h"
2 | #include "utils.h"
3 |
4 | int main()
5 | {
6 | // Optional: set a prefix path for the model and asset files
7 | std::string path = "";
8 |
9 | // Build the engines from onnx files
10 | SpeedSam speedSam(path + "model/SAM_encoder.onnx", path + "model/SAM_mask_decoder.onnx");
11 |
12 | /* Segmentation examples */
13 |
14 | // Demo 1: Segment using a point
15 | segmentWithPoint(speedSam, path + "assets/dog.jpg", path + "assets/dog_mask.jpg");
16 |
17 | // Demo 2: Segment using a bounding box
18 | segmentBbox(speedSam, path + "assets/dogs.jpg", path + "assets/dogs_mask.jpg");
19 |
20 | return 0;
21 | }
22 |
--------------------------------------------------------------------------------
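The segmentWithPoint and segmentBbox helpers come from utils.h (not shown here). Below is a minimal sketch of driving SpeedSam::predict directly, using its signature from speedSam.cpp; the center-point prompt, the label value 1.0f for a foreground point (the usual SAM convention), and the threshold-at-zero binarization are assumptions:

#include "speedSam.h"
#include <opencv2/opencv.hpp>
#include <vector>

int runPointDemo()
{
    // Build (or deserialize) both engines, as in main() above.
    SpeedSam sam("model/SAM_encoder.onnx", "model/SAM_mask_decoder.onnx");

    cv::Mat image = cv::imread("assets/dog.jpg"); // assumes the asset exists
    std::vector<cv::Point> points{ { image.cols / 2, image.rows / 2 } }; // assumed center-point prompt
    std::vector<float> labels{ 1.0f }; // 1 = foreground point

    // predict returns an image-sized CV_32FC1 mask of logits.
    cv::Mat mask = sam.predict(image, points, labels);

    // Threshold at zero to binarize before saving (assumed post-processing).
    cv::Mat binary = mask > 0.0f;
    cv::imwrite("assets/dog_mask.jpg", binary);
    return 0;
}
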
/src/speedSam.cpp:
--------------------------------------------------------------------------------
1 | #include "speedSam.h"
2 | #include "config.h"
3 |
4 | using namespace std;
5 |
6 | SpeedSam::SpeedSam(string encoderPath, string decoderPath)
7 | {
8 | // Initialize the image encoder and mask decoder using the provided model paths
9 | mImageEncoder = new EngineTRT(encoderPath,
10 | { "image" }, // Input names for the encoder
11 | { "image_embeddings" }, // Output names for the encoder
12 | false, // Not using dynamic shape
13 | true); // Using FP16 precision
14 |
15 | mMaskDecoder = new EngineTRT(decoderPath,
16 | { "image_embeddings", "point_coords", "point_labels", "mask_input", "has_mask_input" }, // Input names for the decoder
17 | { "iou_predictions", "low_res_masks" }, // Output names for the decoder
18 | true, // Using dynamic shape
19 | false); // Not using FP16 precision
20 |
21 | // Allocate memory for model features and inputs
22 | mFeatures = new float[HIDDEN_DIM * FEATURE_WIDTH * FEATURE_HEIGHT];
23 | mMaskInput = new float[HIDDEN_DIM * HIDDEN_DIM];
24 | mHasMaskInput = new float(0.0f); // Flag (0 or 1) indicating whether a previous mask is supplied
25 | mIouPrediction = new float[NUM_LABELS]; // IOU prediction output
26 | mLowResMasks = new float[NUM_LABELS * HIDDEN_DIM * HIDDEN_DIM]; // Low-resolution masks output
27 | }
28 |
29 | SpeedSam::~SpeedSam()
30 | {
31 | // Clean up dynamically allocated memory
32 | if (mFeatures) delete[] mFeatures;
33 | if (mMaskInput) delete[] mMaskInput;
34 | if (mIouPrediction) delete[] mIouPrediction;
35 | if (mLowResMasks) delete[] mLowResMasks;
36 | if (mHasMaskInput) delete mHasMaskInput;
37 | if (mImageEncoder) delete mImageEncoder;
38 | if (mMaskDecoder) delete mMaskDecoder;
39 | }
40 |
41 | Mat SpeedSam::predict(Mat& image, vector<Point> points, vector<float> labels)
42 | {
43 | // If no points are provided, return an all-zero mask
44 | if (points.size() == 0) return cv::Mat::zeros(image.rows, image.cols, CV_32FC1);
45 |
46 | // Preprocess the input image for the encoder
47 | auto resizedImage = resizeImage(image, MODEL_INPUT_WIDTH, MODEL_INPUT_HEIGHT);
48 |
49 | // Perform inference with the image encoder
50 | mImageEncoder->setInput(resizedImage);
51 | mImageEncoder->infer();
52 | mImageEncoder->getOutput(mFeatures);
53 |
54 | // Prepare decoder input data for the specified points
55 | auto pointData = new float[2 * points.size()]; // Array to hold scaled point coordinates
56 | prepareDecoderInput(points, pointData, points.size(), image.cols, image.rows);
57 |
58 | // Perform inference with the mask decoder
59 | mMaskDecoder->setInput(mFeatures, pointData, labels.data(), mMaskInput, mHasMaskInput, points.size());
60 | mMaskDecoder->infer();
61 | mMaskDecoder->getOutput(mIouPrediction, mLowResMasks);
62 |
63 | // Post-process the output mask
64 | Mat imgMask(HIDDEN_DIM, HIDDEN_DIM, CV_32FC1, mLowResMasks);
65 | upscaleMask(imgMask, image.cols, image.rows); // Upscale to original image size
66 |
67 | delete[] pointData; // Clean up dynamically allocated memory for point data
68 |
69 | return imgMask; // Return the segmented mask
70 | }
71 |
72 | void SpeedSam::prepareDecoderInput(vector<Point>& points, float* pointData, int numPoints, int imageWidth, int imageHeight)
73 | {
74 | float scale = (float)MODEL_INPUT_WIDTH / max(imageWidth, imageHeight); // Scaling factor (cast avoids integer division)
75 |
76 | // Scale point coordinates
77 | for (int i = 0; i < numPoints; i++)
78 | {
79 | pointData[i * 2] = (float)points[i].x * scale; // X coordinate
80 | pointData[i * 2 + 1] = (float)points[i].y * scale; // Y coordinate
81 | }
82 |
83 | // Initialize mask input data
84 | for (int i = 0; i < HIDDEN_DIM * HIDDEN_DIM; i++)
85 | {
86 | mMaskInput[i] = 0; // Set mask input to zero
87 | }
88 | *mHasMaskInput = 0; // Set has mask input to false
89 | }
90 |
91 | Mat SpeedSam::resizeImage(Mat& img, int inputWidth, int inputHeight)
92 | {
93 | int w, h;
94 | float aspectRatio = (float)img.cols / (float)img.rows; // Calculate aspect ratio
95 |
96 | // Determine new dimensions while maintaining aspect ratio
97 | if (aspectRatio >= 1)
98 | {
99 | w = inputWidth;
100 | h = int(inputHeight / aspectRatio);
101 | }
102 | else
103 | {
104 | w = int(inputWidth * aspectRatio);
105 | h = inputHeight;
106 | }
107 |
108 | // Create a new image with the new size
109 | Mat re(h, w, CV_8UC3);
110 | cv::resize(img, re, re.size(), 0, 0, INTER_LINEAR); // Resize the original image
111 | Mat out(inputHeight, inputWidth, CV_8UC3, cv::Scalar(0, 0, 0)); // Initialize output image to black
112 | re.copyTo(out(Rect(0, 0, re.cols, re.rows))); // Copy resized image to output
113 |
114 | return out; // Return the resized image
115 | }
116 |
117 | void SpeedSam::upscaleMask(Mat& mask, int targetWidth, int targetHeight, int size)
118 | {
119 | int limX, limY;
120 | // Calculate limits for upscaling based on target dimensions
121 | if (targetWidth > targetHeight)
122 | {
123 | limX = size;
124 | limY = size * targetHeight / targetWidth;
125 | }
126 | else
127 | {
128 | limX = size * targetWidth / targetHeight;
129 | limY = size;
130 | }
131 |
132 | // Resize the mask to the target dimensions
133 | cv::resize(mask(Rect(0, 0, limX, limY)), mask, Size(targetWidth, targetHeight));
134 | }
135 |
--------------------------------------------------------------------------------
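resizeImage and prepareDecoderInput must agree on geometry: the letterbox resize maps the longer image side onto the model edge, and the prompt coordinates are multiplied by the same factor MODEL_INPUT_WIDTH / max(imageWidth, imageHeight). A small self-checking sketch of that relationship, assuming a square 1024x1024 model input (the actual values live in config.h, which is not shown):

#include <algorithm>
#include <cassert>

int demoScale()
{
    const int MODEL_W = 1024, MODEL_H = 1024; // assumed square model input
    int imgW = 1920, imgH = 1080;             // example source image

    // resizeImage: the long side maps to the model edge, aspect ratio preserved.
    float aspect = (float)imgW / (float)imgH;
    int w = (aspect >= 1) ? MODEL_W : int(MODEL_W * aspect);
    int h = (aspect >= 1) ? int(MODEL_H / aspect) : MODEL_H;

    // prepareDecoderInput: the same factor scales the prompt coordinates.
    float scale = (float)MODEL_W / std::max(imgW, imgH);

    // Both resized extents equal the original extents times `scale`,
    // so scaled prompt points land on the correct pixels of the letterboxed image.
    assert(w == int(imgW * scale)); // 1024 == 1024
    assert(h == int(imgH * scale)); // 576 == 576
    return 0;
}
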