├── CMakeLists.txt
├── LICENSE
├── README.md
├── assets
│   ├── BY HAMDI BOUKAMCHA.JPG
│   ├── Speed_SAM_Results.JPG
│   ├── blue-linkedin-logo.png
│   ├── dog.jpg
│   ├── dogs.jpg
│   └── speed_sam_cpp_tenosrrt.PNG
├── include
│   ├── config.h
│   ├── cuda_utils.h
│   ├── engineTRT.h
│   ├── logging.h
│   ├── macros.h
│   ├── speedSam.h
│   └── utils.h
├── model
│   ├── SAM_encoder.onnx
│   └── SAM_mask_decoder.onnx
└── src
    ├── engineTRT.cpp
    ├── main.cpp
    └── speedSam.cpp

/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.18)
project(nanosam LANGUAGES CXX CUDA)

# Set C++ standard
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Suppress ZERO_CHECK generation in Visual Studio
set(CMAKE_SUPPRESS_REGENERATION ON)

# Find packages
find_package(CUDA REQUIRED)

# Set OpenCV root directory
set(OpenCV_DIR "F:/Program Files (x86)/opencv-4.10.0-windows/opencv/build")
find_package(OpenCV REQUIRED)

# Set TensorRT root directory
set(TENSORRT_ROOT "F:/Program Files/TensorRT-8.6.1.6")
find_path(TENSORRT_INCLUDE_DIR NvInfer.h PATHS ${TENSORRT_ROOT}/include)
find_library(NVINFER_LIBRARY nvinfer PATHS ${TENSORRT_ROOT}/lib)
find_library(NVONNXPARSER_LIBRARY nvonnxparser PATHS ${TENSORRT_ROOT}/lib)

# Include directories
include_directories(
    ${CUDA_INCLUDE_DIRS}
    ${TENSORRT_INCLUDE_DIR}
    ${OpenCV_INCLUDE_DIRS}
    ${CMAKE_SOURCE_DIR}/include # Add include directory
)

# Add source files
file(GLOB SOURCES
    ${CMAKE_SOURCE_DIR}/src/*.cpp
)

# Add header files
file(GLOB HEADERS
    ${CMAKE_SOURCE_DIR}/include/*.h
)

# Organize headers under "Header Files" in Visual Studio
source_group("Header Files" FILES ${HEADERS})

# CUDA compilation flags
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3 -std=c++17")

# Add main executable without generating ALL_BUILD
add_executable(SpeedSAM ${SOURCES} ${HEADERS})

# Link libraries
target_link_libraries(SpeedSAM
    ${NVINFER_LIBRARY}
    ${NVONNXPARSER_LIBRARY}
    ${CUDA_LIBRARIES}
    ${OpenCV_LIBS}
)

# Set up custom CUDA runtime library if needed
set_target_properties(SpeedSAM PROPERTIES CUDA_SEPARABLE_COMPILATION ON)

# TensorRT and CUDA runtime flags
target_compile_definitions(SpeedSAM PRIVATE USE_TENSORRT USE_CUDA)

# Export paths for flexibility in multi-platform use
message(STATUS "CUDA Libraries: ${CUDA_LIBRARIES}")
message(STATUS "TensorRT Libraries: ${TENSORRT_INCLUDE_DIR}, ${NVINFER_LIBRARY}, ${NVONNXPARSER_LIBRARY}")
message(STATUS "OpenCV Directory: ${OpenCV_DIR}")

# Set main executable as default startup project (optional)
set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY VS_STARTUP_PROJECT SpeedSAM)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Custom License Agreement

1. License Grant
You are hereby granted a non-exclusive, non-transferable license to use, reproduce, and distribute the code (hereinafter referred to as "the Software") under the following conditions:

2. Conditions of Use

Non-Commercial Use: You may use the Software for personal, educational, or non-commercial purposes without any additional permissions.
Commercial Use: Any commercial use of the Software, including but not limited to selling, licensing, or using it in a commercial product, requires prior written permission from the original developer.

3. Contact Requirement

If you wish to use the Software for commercial purposes, you must contact the original developer at [https://www.linkedin.com/in/hamdi-boukamcha/] to obtain a commercial license.
The terms of any commercial license will be mutually agreed upon and may involve a licensing fee.

4. Attribution

Regardless of whether you are using the Software for commercial or non-commercial purposes, you must provide appropriate credit to the original developer in any distributions or products that use the Software.

5. Disclaimer of Warranty

The Software is provided "as is," without warranty of any kind, express or implied, including but not limited to the warranties of merchantability, fitness for a particular purpose, and non-infringement. In no event shall the original developer be liable for any claim, damages, or other liability, whether in an action of contract, tort, or otherwise, arising from, out of, or in connection with the Software or the use or other dealings in the Software.

6. Governing Law

This License Agreement shall be governed by and construed in accordance with the laws of France.
By using the Software, you agree to abide by the terms outlined in this License Agreement.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SPEED SAM C++ TENSORRT
![SAM C++ TENSORRT](assets/speed_sam_cpp_tenosrrt.PNG)

## 🌐 Overview
A high-performance C++ implementation of SAM (Segment Anything Model) using TensorRT and CUDA, optimized for real-time image segmentation tasks.

## 📢 Updates
- **Model Conversion**: Build TensorRT engines from ONNX models for accelerated inference.
- **Segmentation with Points and BBoxes**: Easily segment images using selected points or bounding boxes.
- **FP16 Precision**: Choose between FP16 and FP32 for a speed/precision balance.
- **Dynamic Shape Support**: Efficient handling of variable input sizes using optimization profiles.
- **CUDA Optimization**: Leverage CUDA for preprocessing and efficient memory handling.
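To make the FP16/FP32 choice concrete, here is a minimal sketch of building an encoder engine directly with the `EngineTRT` class from `include/engineTRT.h`. The tensor names are placeholders, not verified against the shipped ONNX files -- check your exported graph (e.g. in Netron) for the real ones:

```cpp
#include "engineTRT.h"

int main()
{
    // Building from a .onnx path creates a TensorRT engine; the last
    // argument selects FP16 (true) or FP32 (false) precision.
    // "image" and "image_embeddings" are assumed tensor names.
    EngineTRT encoder("model/SAM_encoder.onnx",
                      { "image" },            // input tensor names (assumed)
                      { "image_embeddings" }, // output tensor names (assumed)
                      false,                  // the encoder input shape is static
                      true);                  // build with FP16

    // setInput(...), infer() and getOutput(...) then run the actual encoding.
    return 0;
}
```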
## 📢 Performance
### Inference Time

| Component                    | SpeedSAM  |
|------------------------------|-----------|
| **Image Encoder**            |           |
| Parameters                   | 5M        |
| Speed                        | 8ms       |
| **Mask Decoder**             |           |
| Parameters                   | 3.876M    |
| Speed                        | 4ms       |
| **Whole Pipeline (Enc+Dec)** |           |
| Parameters                   | 9.66M     |
| Speed                        | 12ms      |

### Results
![SPEED-SAM-C-TENSORRT RESULT](assets/Speed_SAM_Results.JPG)

## 📂 Project Structure

    SPEED-SAM-CPP-TENSORRT/
    ├── include
    │   ├── config.h      # Model configuration and macros
    │   ├── cuda_utils.h  # CUDA utility macros
    │   ├── engineTRT.h   # TensorRT engine management
    │   ├── logging.h     # Logging utilities
    │   ├── macros.h      # API export/import macros
    │   ├── speedSam.h    # SpeedSam class definition
    │   └── utils.h       # Utility functions for image handling
    ├── src
    │   ├── engineTRT.cpp # Implementation of the TensorRT engine
    │   ├── main.cpp      # Main entry point
    │   └── speedSam.cpp  # Implementation of the SpeedSam class
    └── CMakeLists.txt    # CMake configuration

# 🚀 Installation
## Build from Source

    git clone https://github.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT.git
    cd SPEED-SAM-CPP-TENSORRT

    # Create a build directory and compile
    mkdir build && cd build
    cmake ..
    make -j$(nproc)

Note: Update the CMakeLists.txt with the correct paths for TensorRT and OpenCV.

## 📦 Dependencies
- CUDA: NVIDIA's parallel computing platform
- TensorRT: High-performance deep learning inference
- OpenCV: Image processing library
- C++17: Required standard for compilation

# 🔍 Code Overview
## Main Components
- SpeedSam Class (speedSam.h): Manages image encoding and mask decoding.
- EngineTRT Class (engineTRT.h): TensorRT engine creation and inference.
- CUDA Utilities (cuda_utils.h): Macros for CUDA error handling.
- Config (config.h): Defines model parameters and precision settings.

## Key Functions
- EngineTRT::build: Builds the TensorRT engine from an ONNX model.
- EngineTRT::infer: Runs inference on the provided input data.
- SpeedSam::predict: Segments an image using input points or bounding boxes.
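For orientation, a minimal end-to-end sketch of the point-prompt flow these components implement (the prompt coordinates below are arbitrary; `main.cpp` drives the same flow interactively via `utils.h`):

```cpp
#include "speedSam.h"

int main()
{
    // Builds (or deserializes) the encoder and decoder TensorRT engines.
    SpeedSam sam("model/SAM_encoder.onnx", "model/SAM_mask_decoder.onnx");

    cv::Mat image = cv::imread("assets/dog.jpg");

    // A single foreground prompt point; label 1.0f marks the foreground class.
    std::vector<cv::Point> points = { cv::Point(500, 375) };
    std::vector<float> labels = { 1.0f };

    // predict() returns the segmentation mask; positive values mark the object
    // (utils.h overlays the region where the mask is positive).
    cv::Mat mask = sam.predict(image, points, labels);
    cv::imwrite("mask.png", mask > 0); // write it as a binary image
    return 0;
}
```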
## 📞 Contact

For advanced inquiries, feel free to contact me on LinkedIn.

## 📜 Citation

If you use this code in your research, please cite the repository as follows:

    @misc{boukamcha2024SpeedSam,
        author = {Hamdi Boukamcha},
        title = {SPEED-SAM-C-TENSORRT},
        year = {2024},
        publisher = {GitHub},
        howpublished = {\url{https://github.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT//}},
    }
--------------------------------------------------------------------------------
/assets/BY HAMDI BOUKAMCHA.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/assets/BY HAMDI BOUKAMCHA.JPG
--------------------------------------------------------------------------------
/assets/Speed_SAM_Results.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/assets/Speed_SAM_Results.JPG
--------------------------------------------------------------------------------
/assets/blue-linkedin-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/assets/blue-linkedin-logo.png
--------------------------------------------------------------------------------
/assets/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/assets/dog.jpg
--------------------------------------------------------------------------------
/assets/dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/assets/dogs.jpg
--------------------------------------------------------------------------------
/assets/speed_sam_cpp_tenosrrt.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/assets/speed_sam_cpp_tenosrrt.PNG
--------------------------------------------------------------------------------
/include/config.h:
--------------------------------------------------------------------------------
#pragma once

/// \file config.h
/// \brief Header file defining model parameters and configuration macros.
///
/// This header file contains macros for model parameters including input dimensions,
/// feature dimensions, and precision settings.
///
/// \author Hamdi Boukamcha
/// \date 2024

#define USE_FP16 ///< Set to use FP16 (float16) precision, or comment to use FP32 (float32) precision.

#define MAX_NUM_PROMPTS 1 ///< Maximum number of prompts to be processed at once.

// Model Params
#define MODEL_INPUT_WIDTH 1024.0f  ///< Width of the model input in pixels.
#define MODEL_INPUT_HEIGHT 1024.0f ///< Height of the model input in pixels.
#define HIDDEN_DIM 256             ///< Dimension of the hidden layer.
#define NUM_LABELS 4               ///< Number of output labels.
#define FEATURE_WIDTH 64           ///< Width of the feature map.
#define FEATURE_HEIGHT 64          ///< Height of the feature map.
--------------------------------------------------------------------------------
/include/cuda_utils.h:
--------------------------------------------------------------------------------
#ifndef TRTX_CUDA_UTILS_H_
#define TRTX_CUDA_UTILS_H_

#include <cuda_runtime_api.h>
#include <cassert>  // assert(), used by CUDA_CHECK
#include <iostream> // std::cerr, used by CUDA_CHECK

/// \file cuda_utils.h
/// \brief Header file providing CUDA utility macros.
///
/// This header file defines utility macros for error checking
/// in CUDA operations, allowing for easier debugging and error
/// handling in GPU-related code.
///
/// \author Hamdi Boukamcha
/// \date 2024

#ifndef CUDA_CHECK
/// \brief Macro for checking CUDA function calls.
///
/// This macro checks the return status of a CUDA call and prints
/// an error message if the call fails. It asserts to halt execution
/// in case of an error.
#define CUDA_CHECK(callstr) \
    { \
        cudaError_t error_code = callstr; \
        if (error_code != cudaSuccess) { \
            std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
            assert(0); \
        } \
    }
#endif // CUDA_CHECK

/// \brief Macro for checking conditions with a custom error message.
///
/// This macro checks a condition and, if false, logs an error message
/// and returns a specified value. It includes information about the
/// file, function, and line number where the error occurred.
///
/// \param status The condition to check.
/// \param val The value to return if the condition is false.
/// \param errMsg The error message to log.
#define CHECK_RETURN_W_MSG(status, val, errMsg) \
    do \
    { \
        if (!(status)) \
        { \
            sample::gLogError << errMsg << " Error in " << __FILE__ << ", function " << FN_NAME << "(), line " << __LINE__ \
                << std::endl; \
            return val; \
        } \
    } while (0)

#endif // TRTX_CUDA_UTILS_H_
--------------------------------------------------------------------------------
/include/engineTRT.h:
--------------------------------------------------------------------------------
#pragma once

#include "NvInfer.h"
#include <opencv2/opencv.hpp>

using namespace nvinfer1;
using namespace std;
using namespace cv;

/// \class EngineTRT
/// \brief A class for handling TensorRT model inference.
///
/// This class manages loading, setting inputs, and executing inference
/// for a TensorRT model. It provides methods for setting input data,
/// retrieving output predictions, and handling both static and dynamic shapes.
///
/// \author Hamdi Boukamcha
/// \date 2024
class EngineTRT
{

public:
    /// \brief Constructor for the EngineTRT class.
    ///
    /// \param modelPath Path to the ONNX model file.
    /// \param inputNames Names of the input tensors.
    /// \param outputNames Names of the output tensors.
    /// \param isDynamicShape Indicates if the model uses dynamic shapes.
    /// \param isFP16 Indicates if the model should use FP16 precision.
    EngineTRT(string modelPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape, bool isFP16);

    /// \brief Performs inference on the input data.
    /// \return True if inference was successful, false otherwise.
    bool infer();

    /// \brief Sets the input image for inference.
    ///
    /// \param image The input image to be processed.
    void setInput(Mat& image);

    /// \brief Sets multiple inputs for inference from raw data.
    ///
    /// \param features Pointer to feature data.
    /// \param imagePointCoords Pointer to image point coordinates.
    /// \param imagePointLabels Pointer to image point labels.
    /// \param maskInput Pointer to mask input data.
    /// \param hasMaskInput Pointer to existence of mask data.
    /// \param numPoints The number of points in the input.
    void setInput(float* features, float* imagePointCoords, float* imagePointLabels, float* maskInput, float* hasMaskInput, int numPoints);

    /// \brief Retrieves the output predictions for IoU and low-resolution masks.
    ///
    /// \param iouPrediction Pointer to store IoU prediction output.
    /// \param lowResolutionMasks Pointer to store low-resolution mask output.
    void getOutput(float* iouPrediction, float* lowResolutionMasks);

    /// \brief Retrieves the output features from inference.
    ///
    /// \param features Pointer to store output features.
    void getOutput(float* features);

    /// \brief Destructor for the EngineTRT class.
    ~EngineTRT();

private:
    /// \brief Builds the TensorRT engine from an ONNX model.
    ///
    /// \param onnxPath Path to the ONNX model file.
    /// \param inputNames Names of the input tensors.
    /// \param outputNames Names of the output tensors.
    /// \param isDynamicShape Indicates if the model uses dynamic shapes.
    /// \param isFP16 Indicates if the model should use FP16 precision.
    void build(string onnxPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape = false, bool isFP16 = false);

    void saveEngine(const std::string& engineFilePath);

    /// \brief Deserializes the engine from a file.
    ///
    /// \param engineName Name of the engine file.
    /// \param inputNames Names of the input tensors.
    /// \param outputNames Names of the output tensors.
    void deserializeEngine(string engineName, vector<string> inputNames, vector<string> outputNames);

    /// \brief Initializes the TensorRT module with input and output names.
    ///
    /// \param inputNames Names of the input tensors.
    /// \param outputNames Names of the output tensors.
    void initialize(vector<string> inputNames, vector<string> outputNames);

    /// \brief Gets the size of a buffer based on its dimensions.
    ///
    /// \param dims The dimensions of the buffer.
    /// \return The size in bytes of the buffer.
    size_t getSizeByDim(const Dims& dims);

    /// \brief Copies buffers between device and host memory.
    ///
    /// \param copyInput Indicates whether to copy input data.
    /// \param deviceToHost Indicates the direction of the copy.
    /// \param async Indicates whether the copy should be asynchronous.
    /// \param stream The CUDA stream to use for the copy operation.
    void memcpyBuffers(const bool copyInput, const bool deviceToHost, const bool async, const cudaStream_t& stream = 0);

    /// \brief Asynchronously copies input data to the device.
    ///
    /// \param stream The CUDA stream to use for the copy operation.
    void copyInputToDeviceAsync(const cudaStream_t& stream = 0);

    /// \brief Asynchronously copies output data from the device to host.
    ///
    /// \param stream The CUDA stream to use for the copy operation.
    void copyOutputToHostAsync(const cudaStream_t& stream = 0);

    vector<Dims> mInputDims;            //!< The dimensions of the input to the network.
    vector<Dims> mOutputDims;           //!< The dimensions of the output to the network.
    vector<void*> mGpuBuffers;          //!< The vector of device buffers needed for engine execution.
    vector<float*> mCpuBuffers;         //!< The vector of CPU buffers for input/output.
    vector<size_t> mBufferBindingBytes; //!< The sizes in bytes of each buffer binding.
    vector<size_t> mBufferBindingSizes; //!< The sizes of the buffer bindings.
    cudaStream_t mCudaStream;           //!< The CUDA stream used for asynchronous operations.

    IRuntime* mRuntime;          //!< The TensorRT runtime used to deserialize the engine.
    ICudaEngine* mEngine;        //!< The TensorRT engine used to run the network.
    IExecutionContext* mContext; //!< The context for executing inference using an ICudaEngine.
};
--------------------------------------------------------------------------------
/include/logging.h:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef TENSORRT_LOGGING_H
#define TENSORRT_LOGGING_H

#include "NvInferRuntimeCommon.h"
#include <cassert>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <sstream>
#include <string>
#include "macros.h"

using Severity = nvinfer1::ILogger::Severity;

class LogStreamConsumerBuffer : public std::stringbuf
{
public:
    LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
        : mOutput(stream)
        , mPrefix(prefix)
        , mShouldLog(shouldLog)
    {
    }

    LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
        : mOutput(other.mOutput)
    {
    }

    ~LogStreamConsumerBuffer()
    {
        // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
        // std::streambuf::pptr() gives a pointer to the current position of the output sequence
        // if the pointer to the beginning is not equal to the pointer to the current position,
        // call putOutput() to log the output to the stream
        if (pbase() != pptr())
        {
            putOutput();
        }
    }

    // synchronizes the stream buffer and returns 0 on success
    // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
    // resetting the buffer and flushing the stream
    virtual int sync()
    {
        putOutput();
        return 0;
    }

    void putOutput()
    {
        if (mShouldLog)
        {
            // prepend timestamp
            std::time_t timestamp = std::time(nullptr);
            tm* tm_local = std::localtime(&timestamp);
            std::cout << "[";
            std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/";
            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday
<< "/"; 78 | std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-"; 79 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":"; 80 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":"; 81 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] "; 82 | // std::stringbuf::str() gets the string contents of the buffer 83 | // insert the buffer contents pre-appended by the appropriate prefix into the stream 84 | mOutput << mPrefix << str(); 85 | // set the buffer to empty 86 | str(""); 87 | // flush the stream 88 | mOutput.flush(); 89 | } 90 | } 91 | 92 | void setShouldLog(bool shouldLog) 93 | { 94 | mShouldLog = shouldLog; 95 | } 96 | 97 | private: 98 | std::ostream& mOutput; 99 | std::string mPrefix; 100 | bool mShouldLog; 101 | }; 102 | 103 | //! 104 | //! \class LogStreamConsumerBase 105 | //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer 106 | //! 107 | class LogStreamConsumerBase 108 | { 109 | public: 110 | LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) 111 | : mBuffer(stream, prefix, shouldLog) 112 | { 113 | } 114 | 115 | protected: 116 | LogStreamConsumerBuffer mBuffer; 117 | }; 118 | 119 | //! 120 | //! \class LogStreamConsumer 121 | //! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages. 122 | //! Order of base classes is LogStreamConsumerBase and then std::ostream. 123 | //! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field 124 | //! in LogStreamConsumer and then the address of the buffer is passed to std::ostream. 125 | //! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. 126 | //! Please do not change the order of the parent classes. 127 | //! 128 | class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream 129 | { 130 | public: 131 | //! \brief Creates a LogStreamConsumer which logs messages with level severity. 132 | //! Reportable severity determines if the messages are severe enough to be logged. 133 | LogStreamConsumer(Severity reportableSeverity, Severity severity) 134 | : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) 135 | , std::ostream(&mBuffer) // links the stream buffer with the stream 136 | , mShouldLog(severity <= reportableSeverity) 137 | , mSeverity(severity) 138 | { 139 | } 140 | 141 | LogStreamConsumer(LogStreamConsumer&& other) 142 | : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) 143 | , std::ostream(&mBuffer) // links the stream buffer with the stream 144 | , mShouldLog(other.mShouldLog) 145 | , mSeverity(other.mSeverity) 146 | { 147 | } 148 | 149 | void setReportableSeverity(Severity reportableSeverity) 150 | { 151 | mShouldLog = mSeverity <= reportableSeverity; 152 | mBuffer.setShouldLog(mShouldLog); 153 | } 154 | 155 | private: 156 | static std::ostream& severityOstream(Severity severity) 157 | { 158 | return severity >= Severity::kINFO ? 
std::cout : std::cerr; 159 | } 160 | 161 | static std::string severityPrefix(Severity severity) 162 | { 163 | switch (severity) 164 | { 165 | case Severity::kINTERNAL_ERROR: return "[F] "; 166 | case Severity::kERROR: return "[E] "; 167 | case Severity::kWARNING: return "[W] "; 168 | case Severity::kINFO: return "[I] "; 169 | case Severity::kVERBOSE: return "[V] "; 170 | default: assert(0); return ""; 171 | } 172 | } 173 | 174 | bool mShouldLog; 175 | Severity mSeverity; 176 | }; 177 | 178 | //! \class Logger 179 | //! 180 | //! \brief Class which manages logging of TensorRT tools and samples 181 | //! 182 | //! \details This class provides a common interface for TensorRT tools and samples to log information to the console, 183 | //! and supports logging two types of messages: 184 | //! 185 | //! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) 186 | //! - Test pass/fail messages 187 | //! 188 | //! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is 189 | //! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. 190 | //! 191 | //! In the future, this class could be extended to support dumping test results to a file in some standard format 192 | //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). 193 | //! 194 | //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger 195 | //! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT 196 | //! library and messages coming from the sample. 197 | //! 198 | //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the 199 | //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger 200 | //! object. 201 | 202 | class Logger : public nvinfer1::ILogger 203 | { 204 | public: 205 | Logger(Severity severity = Severity::kWARNING) 206 | : mReportableSeverity(severity) 207 | { 208 | } 209 | 210 | //! 211 | //! \enum TestResult 212 | //! \brief Represents the state of a given test 213 | //! 214 | enum class TestResult 215 | { 216 | kRUNNING, //!< The test is running 217 | kPASSED, //!< The test passed 218 | kFAILED, //!< The test failed 219 | kWAIVED //!< The test was waived 220 | }; 221 | 222 | //! 223 | //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger 224 | //! \return The nvinfer1::ILogger associated with this Logger 225 | //! 226 | //! TODO Once all samples are updated to use this method to register the logger with TensorRT, 227 | //! we can eliminate the inheritance of Logger from ILogger 228 | //! 229 | nvinfer1::ILogger& getTRTLogger() 230 | { 231 | return *this; 232 | } 233 | 234 | //! 235 | //! \brief Implementation of the nvinfer1::ILogger::log() virtual method 236 | //! 237 | //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the 238 | //! inheritance from nvinfer1::ILogger 239 | //! 240 | void log(Severity severity, const char* msg) TRT_NOEXCEPT override 241 | { 242 | LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; 243 | } 244 | 245 | //! 246 | //! \brief Method for controlling the verbosity of logging output 247 | //! 248 | //! 
\param severity The logger will only emit messages that have severity of this level or higher. 249 | //! 250 | void setReportableSeverity(Severity severity) 251 | { 252 | mReportableSeverity = severity; 253 | } 254 | 255 | //! 256 | //! \brief Opaque handle that holds logging information for a particular test 257 | //! 258 | //! This object is an opaque handle to information used by the Logger to print test results. 259 | //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used 260 | //! with Logger::reportTest{Start,End}(). 261 | //! 262 | class TestAtom 263 | { 264 | public: 265 | TestAtom(TestAtom&&) = default; 266 | 267 | private: 268 | friend class Logger; 269 | 270 | TestAtom(bool started, const std::string& name, const std::string& cmdline) 271 | : mStarted(started) 272 | , mName(name) 273 | , mCmdline(cmdline) 274 | { 275 | } 276 | 277 | bool mStarted; 278 | std::string mName; 279 | std::string mCmdline; 280 | }; 281 | 282 | //! 283 | //! \brief Define a test for logging 284 | //! 285 | //! \param[in] name The name of the test. This should be a string starting with 286 | //! "TensorRT" and containing dot-separated strings containing 287 | //! the characters [A-Za-z0-9_]. 288 | //! For example, "TensorRT.sample_googlenet" 289 | //! \param[in] cmdline The command line used to reproduce the test 290 | // 291 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 292 | //! 293 | static TestAtom defineTest(const std::string& name, const std::string& cmdline) 294 | { 295 | return TestAtom(false, name, cmdline); 296 | } 297 | 298 | //! 299 | //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments 300 | //! as input 301 | //! 302 | //! \param[in] name The name of the test 303 | //! \param[in] argc The number of command-line arguments 304 | //! \param[in] argv The array of command-line arguments (given as C strings) 305 | //! 306 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 307 | static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) 308 | { 309 | auto cmdline = genCmdlineString(argc, argv); 310 | return defineTest(name, cmdline); 311 | } 312 | 313 | //! 314 | //! \brief Report that a test has started. 315 | //! 316 | //! \pre reportTestStart() has not been called yet for the given testAtom 317 | //! 318 | //! \param[in] testAtom The handle to the test that has started 319 | //! 320 | static void reportTestStart(TestAtom& testAtom) 321 | { 322 | reportTestResult(testAtom, TestResult::kRUNNING); 323 | assert(!testAtom.mStarted); 324 | testAtom.mStarted = true; 325 | } 326 | 327 | //! 328 | //! \brief Report that a test has ended. 329 | //! 330 | //! \pre reportTestStart() has been called for the given testAtom 331 | //! 332 | //! \param[in] testAtom The handle to the test that has ended 333 | //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, 334 | //! TestResult::kFAILED, TestResult::kWAIVED 335 | //! 
336 | static void reportTestEnd(const TestAtom& testAtom, TestResult result) 337 | { 338 | assert(result != TestResult::kRUNNING); 339 | assert(testAtom.mStarted); 340 | reportTestResult(testAtom, result); 341 | } 342 | 343 | static int reportPass(const TestAtom& testAtom) 344 | { 345 | reportTestEnd(testAtom, TestResult::kPASSED); 346 | return EXIT_SUCCESS; 347 | } 348 | 349 | static int reportFail(const TestAtom& testAtom) 350 | { 351 | reportTestEnd(testAtom, TestResult::kFAILED); 352 | return EXIT_FAILURE; 353 | } 354 | 355 | static int reportWaive(const TestAtom& testAtom) 356 | { 357 | reportTestEnd(testAtom, TestResult::kWAIVED); 358 | return EXIT_SUCCESS; 359 | } 360 | 361 | static int reportTest(const TestAtom& testAtom, bool pass) 362 | { 363 | return pass ? reportPass(testAtom) : reportFail(testAtom); 364 | } 365 | 366 | Severity getReportableSeverity() const 367 | { 368 | return mReportableSeverity; 369 | } 370 | 371 | private: 372 | //! 373 | //! \brief returns an appropriate string for prefixing a log message with the given severity 374 | //! 375 | static const char* severityPrefix(Severity severity) 376 | { 377 | switch (severity) 378 | { 379 | case Severity::kINTERNAL_ERROR: return "[F] "; 380 | case Severity::kERROR: return "[E] "; 381 | case Severity::kWARNING: return "[W] "; 382 | case Severity::kINFO: return "[I] "; 383 | case Severity::kVERBOSE: return "[V] "; 384 | default: assert(0); return ""; 385 | } 386 | } 387 | 388 | //! 389 | //! \brief returns an appropriate string for prefixing a test result message with the given result 390 | //! 391 | static const char* testResultString(TestResult result) 392 | { 393 | switch (result) 394 | { 395 | case TestResult::kRUNNING: return "RUNNING"; 396 | case TestResult::kPASSED: return "PASSED"; 397 | case TestResult::kFAILED: return "FAILED"; 398 | case TestResult::kWAIVED: return "WAIVED"; 399 | default: assert(0); return ""; 400 | } 401 | } 402 | 403 | //! 404 | //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity 405 | //! 406 | static std::ostream& severityOstream(Severity severity) 407 | { 408 | return severity >= Severity::kINFO ? std::cout : std::cerr; 409 | } 410 | 411 | //! 412 | //! \brief method that implements logging test results 413 | //! 414 | static void reportTestResult(const TestAtom& testAtom, TestResult result) 415 | { 416 | severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " 417 | << testAtom.mCmdline << std::endl; 418 | } 419 | 420 | //! 421 | //! \brief generate a command line string from the given (argc, argv) values 422 | //! 423 | static std::string genCmdlineString(int argc, char const* const* argv) 424 | { 425 | std::stringstream ss; 426 | for (int i = 0; i < argc; i++) 427 | { 428 | if (i > 0) 429 | ss << " "; 430 | ss << argv[i]; 431 | } 432 | return ss.str(); 433 | } 434 | 435 | Severity mReportableSeverity; 436 | }; 437 | 438 | namespace 439 | { 440 | 441 | //! 442 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE 443 | //! 444 | //! Example usage: 445 | //! 446 | //! LOG_VERBOSE(logger) << "hello world" << std::endl; 447 | //! 448 | inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) 449 | { 450 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); 451 | } 452 | 453 | //! 454 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO 455 | //! 456 | //! Example usage: 457 | //! 
//!     LOG_INFO(logger) << "hello world" << std::endl;
//!
inline LogStreamConsumer LOG_INFO(const Logger& logger)
{
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO);
}

//!
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING
//!
//! Example usage:
//!
//!     LOG_WARN(logger) << "hello world" << std::endl;
//!
inline LogStreamConsumer LOG_WARN(const Logger& logger)
{
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING);
}

//!
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR
//!
//! Example usage:
//!
//!     LOG_ERROR(logger) << "hello world" << std::endl;
//!
inline LogStreamConsumer LOG_ERROR(const Logger& logger)
{
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR);
}

//!
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR
// ("fatal" severity)
//!
//! Example usage:
//!
//!     LOG_FATAL(logger) << "hello world" << std::endl;
//!
inline LogStreamConsumer LOG_FATAL(const Logger& logger)
{
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
}

} // anonymous namespace

#endif // TENSORRT_LOGGING_H
--------------------------------------------------------------------------------
/include/macros.h:
--------------------------------------------------------------------------------
#ifndef __MACROS_H
#define __MACROS_H

/// \file macros.h
/// \brief Header file defining macros for API export and compatibility.
///
/// This header file contains macros that facilitate
/// the export and import of functions for shared libraries,
/// as well as compatibility settings based on the NV TensorRT version.
///
/// \author Hamdi Boukamcha
/// \date 2024

#ifdef API_EXPORTS
#if defined(_MSC_VER)
#define API __declspec(dllexport) ///< Macro for exporting functions in Windows.
#else
#define API __attribute__((visibility("default"))) ///< Macro for exporting functions in non-Windows environments.
#endif
#else
#if defined(_MSC_VER)
#define API __declspec(dllimport) ///< Macro for importing functions in Windows.
#else
#define API ///< No import/export in non-Windows environments.
#endif
#endif // API_EXPORTS

#if NV_TENSORRT_MAJOR >= 8
#define TRT_NOEXCEPT noexcept ///< Macro for noexcept specification based on TensorRT version.
#define TRT_CONST_ENQUEUE const ///< Macro to define const enqueue for TensorRT version >= 8.
#else
#define TRT_NOEXCEPT ///< No exception specification for TensorRT version < 8.
#define TRT_CONST_ENQUEUE ///< No const enqueue definition for TensorRT version < 8.
#endif

#endif // __MACROS_H
--------------------------------------------------------------------------------
/include/speedSam.h:
--------------------------------------------------------------------------------
#pragma once

#include <opencv2/opencv.hpp>
#include "engineTRT.h"

/// \class SpeedSam
/// \brief A class for handling image predictions with an encoder and decoder model.
///
/// This class manages the process of encoding and decoding images,
/// allowing for the prediction of masks based on input images and points.
///
/// \author Hamdi Boukamcha
/// \date 2024
class SpeedSam
{

public:
    /// \brief Constructor for the SpeedSam class.
    ///
    /// \param encoderPath Path to the encoder model.
    /// \param decoderPath Path to the decoder model.
    SpeedSam(std::string encoderPath, std::string decoderPath);

    /// \brief Destructor for the SpeedSam class.
    ~SpeedSam();

    /// \brief Predicts masks based on the input image and points.
    ///
    /// \param image The input image for prediction.
    /// \param points The points used for mask prediction.
    /// \param labels The labels associated with the points.
    /// \return A matrix containing the predicted masks.
    Mat predict(Mat& image, std::vector<Point> points, std::vector<float> labels);

private:
    // Variables
    float* mFeatures;      ///< Pointer to the feature data.
    float* mMaskInput;     ///< Pointer to the mask input data.
    float* mHasMaskInput;  ///< Pointer to the mask existence input data.
    float* mIouPrediction; ///< Pointer to the IoU prediction data.
    float* mLowResMasks;   ///< Pointer to the low-resolution masks.

    EngineTRT* mImageEncoder; ///< Pointer to the image encoder module.
    EngineTRT* mMaskDecoder;  ///< Pointer to the mask decoder module.

    /// \brief Upscales the given mask to the target width and height.
    ///
    /// \param mask The mask to upscale.
    /// \param targetWidth The target width for upscaling.
    /// \param targetHeight The target height for upscaling.
    /// \param size The size of the mask (default is 256).
    void upscaleMask(Mat& mask, int targetWidth, int targetHeight, int size = 256);

    /// \brief Resizes the input image to match the model's dimensions.
    ///
    /// \param img The image to resize.
    /// \param modelWidth The width required by the model.
    /// \param modelHeight The height required by the model.
    /// \return A resized image matrix.
    Mat resizeImage(Mat& img, int modelWidth, int modelHeight);

    /// \brief Prepares the decoder input from the provided points.
    ///
    /// \param points The points to be converted into input data.
    /// \param pointData Pointer to the data array for points.
    /// \param numPoints The number of points.
    /// \param imageWidth The width of the input image.
    /// \param imageHeight The height of the input image.
    void prepareDecoderInput(std::vector<Point>& points, float* pointData, int numPoints, int imageWidth, int imageHeight);
};
--------------------------------------------------------------------------------
/include/utils.h:
--------------------------------------------------------------------------------
#pragma once
/// \file utils.h
/// \brief Header file defining color constants and utility functions.
///
/// This header file contains a set of predefined colors for the Cityscapes dataset,
/// a structure to hold clicked point data, and utility functions for overlaying masks
/// on images and handling mouse events.
///
/// \author Hamdi Boukamcha
/// \date 2024
#include <opencv2/opencv.hpp>
#include "speedSam.h" // needed for the SpeedSam& parameters used below
using namespace cv;
using namespace std;

// Global variables to store the selected point and image
Point selectedPoint;
bool pointSelected = false;
float resizeScale = 1.0f; // Scale factor for resizing

// Colors
const std::vector<cv::Scalar> CITYSCAPES_COLORS = {
    cv::Scalar(128, 64, 128),
    cv::Scalar(232, 35, 244),
    cv::Scalar(70, 70, 70),
    cv::Scalar(156, 102, 102),
    cv::Scalar(153, 153, 190),
    cv::Scalar(153, 153, 153),
    cv::Scalar(30, 170, 250),
    cv::Scalar(0, 220, 220),
    cv::Scalar(35, 142, 107),
    cv::Scalar(152, 251, 152),
    cv::Scalar(180, 130, 70),
    cv::Scalar(60, 20, 220),
    cv::Scalar(0, 0, 255),
    cv::Scalar(142, 0, 0),
    cv::Scalar(70, 0, 0),
    cv::Scalar(100, 60, 0),
    cv::Scalar(90, 0, 0),
    cv::Scalar(230, 0, 0),
    cv::Scalar(32, 11, 119),
    cv::Scalar(0, 74, 111),
    cv::Scalar(81, 0, 81)
};

/// \struct PointData
/// \brief Structure to hold clicked point coordinates.
///
/// This structure stores the coordinates of a clicked point
/// and a flag indicating whether the point has been clicked.
struct PointData {
    cv::Point point; ///< The coordinates of the clicked point.
    bool clicked;    ///< Flag indicating if the point was clicked.
};

// Mouse callback function to capture the clicked point
void mouseCallback(int event, int x, int y, int flags, void* param) {
    if (event == EVENT_LBUTTONDOWN) {
        selectedPoint = Point(x, y);
        pointSelected = true;
        cout << "Point selected (in resized image): " << selectedPoint << endl;
    }
}

// Example overlay function
void overlay(Mat& image, const Mat& mask) {
    // Placeholder for the overlay logic
    // This function should blend the mask with the original image
    addWeighted(image, 0.5, mask, 0.5, 0, image);
}

/// \brief Overlays a mask on the given image.
///
/// This function overlays a colored mask on the input image
/// using a specified transparency (alpha) and optionally shows
/// the contour edges of the mask.
///
/// \param image The image on which to overlay the mask.
/// \param mask The mask to overlay on the image.
/// \param color The color of the overlay mask (default is CITYSCAPES_COLORS[0]).
/// \param alpha The transparency level of the overlay (default is 0.8).
/// \param showEdge Whether to show the contour edges of the mask (default is true).
void overlay(Mat& image, Mat& mask, cv::Scalar color = cv::Scalar(128, 64, 128), float alpha = 0.8f, bool showEdge = true)
{
    // Draw mask
    Mat ucharMask(image.rows, image.cols, CV_8UC3, color);
    image.copyTo(ucharMask, mask <= 0);
    addWeighted(ucharMask, alpha, image, 1.0 - alpha, 0.0f, image);

    // Draw contour edge
    if (showEdge)
    {
        vector<vector<Point>> contours;
        vector<Vec4i> hierarchy;
        findContours(mask <= 0, contours, hierarchy, RETR_TREE, CHAIN_APPROX_NONE);
        drawContours(image, contours, -1, Scalar(255, 255, 255), 2);
    }
}

/// \brief Handles mouse events for clicking on the image.
///
/// This function processes mouse events, storing the clicked
/// point coordinates in the provided PointData structure.
///
/// \param event The type of mouse event.
/// \param x The x-coordinate of the mouse event.
/// \param y The y-coordinate of the mouse event.
/// \param flags Any relevant flags associated with the mouse event.
/// \param userdata User data pointer to store the PointData structure.
void onMouse(int event, int x, int y, int flags, void* userdata) {
    PointData* pd = (PointData*)userdata;
    if (event == cv::EVENT_LBUTTONDOWN) {
        // Save the clicked coordinates
        pd->point = cv::Point(x, y);
        pd->clicked = true;
    }
}

/// \brief Segments the image based on a user-selected point.
///
/// This function allows the user to click on a point within the image, which is then
/// used to perform segmentation using the SpeedSam model. The segmented result is
/// overlaid on the original image and displayed in the same window. The final image
/// is saved to the specified output path.
///
/// \param nanosam Reference to the SpeedSam model used for segmentation.
/// \param imagePath Path to the input image.
/// \param outputPath Path to save the segmented output image.
void segmentWithPoint(SpeedSam& nanosam, const string& imagePath, const string& outputPath) {
    // Load the image from the specified path
    Mat image = imread(imagePath);
    if (image.empty()) {
        cerr << "Error: Unable to load image from " << imagePath << endl;
        return;
    }

    // Resize the image for easier viewing
    Mat resizedImage;
    const int maxWidth = 800; // Maximum width for display
    if (image.cols > maxWidth) {
        resizeScale = static_cast<float>(maxWidth) / image.cols;
        resize(image, resizedImage, Size(), resizeScale, resizeScale);
    }
    else {
        resizedImage = image;
    }

    // Display the image in a window
    namedWindow("Select Point", WINDOW_AUTOSIZE);
    setMouseCallback("Select Point", mouseCallback, nullptr);
    imshow("Select Point", resizedImage);

    // Wait indefinitely until the user clicks a point
    while (!pointSelected) {
        waitKey(10); // Small delay to prevent high CPU usage
    }

    // Scale the selected point back to the original image size
    Point originalPoint(static_cast<int>(selectedPoint.x / resizeScale),
        static_cast<int>(selectedPoint.y / resizeScale));
    cout << "Point mapped to original image: " << originalPoint << endl;

    // Label indicating that the prompt point corresponds to the foreground class
    vector<float> labels = { 1.0f };

    // Perform prediction using the SpeedSam model at the specified prompt point
    auto mask = nanosam.predict(image, { originalPoint }, labels);

    // Overlay the segmentation mask on the original image
    overlay(image, mask);

    // Save the resulting image with the overlay to the specified output path
    imwrite(outputPath, image);

    if (image.cols > maxWidth) {
        resizeScale = static_cast<float>(maxWidth) / image.cols;
        resize(image, resizedImage, Size(), resizeScale, resizeScale);
    }
    else {
        resizedImage = image;
    }

    // Update the same window with the segmented image
    imshow("Select Point", resizedImage);
    waitKey(0);          // Wait for another key press to close the window
    destroyAllWindows(); // Close all OpenCV windows
}

/// \brief Segments an image using bounding box information and saves the result.
///
/// This function loads an image from the specified path,
/// lets the user draw a bounding box, performs segmentation with it,
/// and saves the segmented image with an overlay to the specified output path.
///
/// \param nanosam The SpeedSam model used for segmentation.
/// \param imagePath The path to the input image.
/// \param outputPath The path where the output image will be saved.
void segmentBbox(SpeedSam& nanosam, string imagePath, string outputPath) {
    // Load the image from the specified path
    auto image = imread(imagePath);

    // Check if the image was loaded successfully
    if (image.empty()) {
        cerr << "Error: Unable to load image from " << imagePath << endl;
        return;
    }

    // Create a window for user interaction
    namedWindow("Select and View Result", cv::WINDOW_AUTOSIZE);

    // Let the user select the bounding box
    cv::Rect bbox = selectROI("Select and View Result", image, false, false);

    // Check if a valid bounding box was selected
    if (bbox.width == 0 || bbox.height == 0) {
        cerr << "No valid bounding box selected." << endl;
        return;
    }

    // Convert the selected bounding box to a vector of points
    vector<Point> bboxPoints = {
        Point(bbox.x, bbox.y),                           // Top-left point
        Point(bbox.x + bbox.width, bbox.y + bbox.height) // Bottom-right point
    };

    // Labels corresponding to the bounding box classes
    // 2 : Bounding box top-left, 3 : Bounding box bottom-right
    vector<float> labels = { 2, 3 };

    // Perform prediction using the SpeedSam model with the given bounding box and labels
    auto mask = nanosam.predict(image, bboxPoints, labels);

    // Overlay the segmentation mask on the original image
    overlay(image, mask);

    // Draw the bounding box on the image
    rectangle(image, bboxPoints[0], bboxPoints[1], cv::Scalar(255, 255, 0), 3);

    // Display the updated image in the same window
    imshow("Select and View Result", image);
    waitKey(0);

    // Save the resulting image to the specified output path
    imwrite(outputPath, image);
}
--------------------------------------------------------------------------------
/model/SAM_encoder.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/model/SAM_encoder.onnx
--------------------------------------------------------------------------------
/model/SAM_mask_decoder.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hamdiboukamcha/SPEED-SAM-C-TENSORRT/8fe06fa1082e04c9388abc1ea909579edc53cfc3/model/SAM_mask_decoder.onnx
--------------------------------------------------------------------------------
/src/engineTRT.cpp:
--------------------------------------------------------------------------------
#include "engineTRT.h"
#include "logging.h"
#include "cuda_utils.h"
#include "config.h"
#include "macros.h"
#include <NvOnnxParser.h>     // nvonnxparser::createParser
#include <cuda_runtime_api.h>
#include <cassert>
#include <filesystem>         // std::filesystem::exists
#include <fstream>            // std::ifstream / std::ofstream
#include <iostream>
#include <sstream>
#include <string>

static Logger gLogger;

std::string getFileExtension(const std::string& filePath) {
    // Find the position of the last dot in the file path
    size_t dotPos = filePath.find_last_of(".");
    // If a dot is found, extract and return the substring after the dot as the file extension
    if (dotPos != std::string::npos) {
        return filePath.substr(dotPos + 1);
    }
    // Return an empty string if no extension is found
    return "";
}

EngineTRT::EngineTRT(string modelPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape, bool isFP16) {
    // Check if the model file has an ".onnx" extension
    if (getFileExtension(modelPath) == "onnx") {
        // If the file is an ONNX model, build the engine using the provided parameters
        cout << "Building Engine from " << modelPath << endl;
        build(modelPath, inputNames, outputNames, isDynamicShape, isFP16);
    }
    else {
        // If the file is not an ONNX model, deserialize an existing engine
        cout << "Deserializing Engine." << endl;
        deserializeEngine(modelPath, inputNames, outputNames);
    }
}

EngineTRT::~EngineTRT() {
    // Release the CUDA stream
    cudaStreamDestroy(mCudaStream);
    // Free GPU buffers allocated for inference
    for (int i = 0; i < mGpuBuffers.size(); i++)
        CUDA_CHECK(cudaFree(mGpuBuffers[i]));
    // Free CPU buffers
    for (int i = 0; i < mCpuBuffers.size(); i++)
        delete[] mCpuBuffers[i];

    // Clean up and destroy the TensorRT engine components
    delete mContext; // Destroy the execution context
    delete mEngine;  // Destroy the engine
    delete mRuntime; // Destroy the runtime
}

void EngineTRT::build(string onnxPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape, bool isFP16)
{
    // Check if the ONNX file exists. If not, print an error message and return.
    if (!std::filesystem::exists(onnxPath)) {
        std::cerr << "ONNX file not found: " << onnxPath << std::endl;
        return; // Early exit if the ONNX file is missing
    }

    // Create an inference builder for building the TensorRT engine.
    auto builder = createInferBuilder(gLogger);
    assert(builder != nullptr); // Ensure the builder is created successfully

    // Use explicit batch size, which is needed for ONNX models.
    const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    INetworkDefinition* network = builder->createNetworkV2(explicitBatch);
    assert(network != nullptr); // Ensure the network is created successfully

    // Create a builder configuration object to set options like FP16 precision.
    IBuilderConfig* config = builder->createBuilderConfig();
    assert(config != nullptr); // Ensure the config is created successfully

    // If dynamic shape support is needed, configure the optimization profile.
    if (isDynamicShape) // Only designed for the SAM mask decoder
    {
        // Create an optimization profile for dynamic input shapes.
        auto profile = builder->createOptimizationProfile();

        // Set the minimum, optimal, and maximum dimensions for the point-coordinates input.
        profile->setDimensions(inputNames[1].c_str(), OptProfileSelector::kMIN, Dims3{ 1, 1, 2 });
        profile->setDimensions(inputNames[1].c_str(), OptProfileSelector::kOPT, Dims3{ 1, 1, 2 });
        profile->setDimensions(inputNames[1].c_str(), OptProfileSelector::kMAX, Dims3{ 1, 10, 2 });

        // Set the minimum, optimal, and maximum dimensions for the point-labels input.
        profile->setDimensions(inputNames[2].c_str(), OptProfileSelector::kMIN, Dims2{ 1, 1 });
        profile->setDimensions(inputNames[2].c_str(), OptProfileSelector::kOPT, Dims2{ 1, 1 });
        profile->setDimensions(inputNames[2].c_str(), OptProfileSelector::kMAX, Dims2{ 1, 10 });

        // Add the optimization profile to the builder configuration.
        config->addOptimizationProfile(profile);
    }

    // Enable FP16 mode if specified.
    if (isFP16)
    {
        config->setFlag(BuilderFlag::kFP16); // Use mixed precision for faster inference
    }

    // Create a parser to convert the ONNX model to a TensorRT network.
    nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, gLogger);
    assert(parser != nullptr); // Ensure the parser is created successfully

    // Parse the ONNX model from the specified file.
    bool parsed = parser->parseFromFile(onnxPath.c_str(), static_cast<int>(gLogger.getReportableSeverity()));
    assert(parsed); // Ensure the ONNX model was parsed successfully

    // Serialize the built network into a binary plan for execution.
    IHostMemory* plan{ builder->buildSerializedNetwork(*network, *config) };
    assert(plan != nullptr); // Ensure the network was serialized successfully

    // Create a runtime object for deserializing the engine.
    mRuntime = createInferRuntime(gLogger);
    assert(mRuntime != nullptr); // Ensure the runtime is created successfully

    // Deserialize the serialized plan to create an execution engine.
    mEngine = mRuntime->deserializeCudaEngine(plan->data(), plan->size());
    assert(mEngine != nullptr); // Ensure the engine was deserialized successfully

    // Create an execution context for running inference.
    mContext = mEngine->createExecutionContext();
    assert(mContext != nullptr); // Ensure the context is created successfully

    // Clean up resources.
    delete network;
    delete config;
    delete parser;
    delete plan;

    // Initialize the engine with the input and output names.
    initialize(inputNames, outputNames);
}

void EngineTRT::saveEngine(const std::string& engineFilePath) {
    if (mEngine) {
        // Serialize the engine to a binary format.
        IHostMemory* serializedEngine = mEngine->serialize();
        std::ofstream engineFile(engineFilePath, std::ios::binary);
        if (engineFile) {
            // Write the serialized engine data to the specified file.
            engineFile.write(reinterpret_cast<const char*>(serializedEngine->data()), serializedEngine->size());
            std::cout << "Serialized engine saved to " << engineFilePath << std::endl;
        }
        serializedEngine->destroy(); // Destroy the serialized engine memory
    }
}

void EngineTRT::deserializeEngine(string engine_name, vector<string> inputNames, vector<string> outputNames)
{
    // Open the engine file in binary mode.
    std::ifstream file(engine_name, std::ios::binary);
    if (!file.good()) {
        std::cerr << "read " << engine_name << " error!" << std::endl;
        assert(false); // Trigger an assertion failure if the file cannot be opened
    }

    // Determine the size of the file and read the serialized engine data.
void EngineTRT::deserializeEngine(string engine_name, vector<string> inputNames, vector<string> outputNames)
{
    // Open the engine file in binary mode.
    std::ifstream file(engine_name, std::ios::binary);
    if (!file.good()) {
        std::cerr << "Failed to read engine file: " << engine_name << std::endl;
        assert(false); // Abort if the file cannot be opened
    }

    // Determine the size of the file and read the serialized engine data.
    size_t size = 0;
    file.seekg(0, file.end);
    size = file.tellg();
    file.seekg(0, file.beg);
    char* serializedEngine = new char[size];
    assert(serializedEngine); // Ensure memory allocation was successful
    file.read(serializedEngine, size);
    file.close();

    // Create a runtime object and deserialize the engine.
    mRuntime = createInferRuntime(gLogger);
    assert(mRuntime); // Ensure the runtime is created successfully
    mEngine = mRuntime->deserializeCudaEngine(serializedEngine, size);
    assert(mEngine);
    mContext = mEngine->createExecutionContext();
    assert(mContext);
    delete[] serializedEngine; // Free the serialized engine memory

    // The number of bindings must match the expected number of inputs and outputs.
    assert(static_cast<size_t>(mEngine->getNbBindings()) == inputNames.size() + outputNames.size());

    // Allocate buffers and record binding metadata.
    initialize(inputNames, outputNames);
}

void EngineTRT::initialize(vector<string> inputNames, vector<string> outputNames)
{
    // Verify that every expected input name resolves to an engine binding
    // (getBindingIndex returns -1 for unknown names).
    for (size_t i = 0; i < inputNames.size(); i++)
    {
        assert(mEngine->getBindingIndex(inputNames[i].c_str()) >= 0);
    }

    // Verify the output names the same way.
    for (size_t i = 0; i < outputNames.size(); i++)
    {
        assert(mEngine->getBindingIndex(outputNames[i].c_str()) >= 0);
    }

    // Resize the GPU and CPU buffer vectors to accommodate all engine bindings.
    mGpuBuffers.resize(mEngine->getNbBindings());
    mCpuBuffers.resize(mEngine->getNbBindings());

    // Loop through all bindings to allocate memory and store dimension information.
    for (int32_t i = 0; i < mEngine->getNbBindings(); ++i)
    {
        // Element count for the binding, sized for the worst case on dynamic axes.
        size_t binding_size = getSizeByDim(mEngine->getBindingDimensions(i));
        mBufferBindingSizes.push_back(binding_size);                 // Element count
        mBufferBindingBytes.push_back(binding_size * sizeof(float)); // Size in bytes

        // Allocate host memory for the CPU buffer.
        mCpuBuffers[i] = new float[binding_size];

        // Allocate device memory for the GPU buffer.
        CUDA_CHECK(cudaMalloc(&mGpuBuffers[i], mBufferBindingBytes[i]));

        // Store input and output dimensions separately.
        if (mEngine->bindingIsInput(i))
        {
            mInputDims.push_back(mEngine->getBindingDimensions(i));
        }
        else
        {
            mOutputDims.push_back(mEngine->getBindingDimensions(i));
        }
    }

    // Create a CUDA stream for asynchronous memory transfers.
    CUDA_CHECK(cudaStreamCreate(&mCudaStream));
}
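// Aside (illustrative sketch, not used by the project): how the worst-case
// sizing above plays out for the decoder's dynamic "point_coords" binding of
// shape (1, -1, 2). getSizeByDim() (defined below) substitutes MAX_NUM_PROMPTS
// from include/config.h for the -1 axis; the value 10 here is an assumption
// that mirrors the kMAX profile bound set in build().
static size_t exampleWorstCaseElems()
{
    const int dims[3] = { 1, -1, 2 };  // batch, dynamic prompt count, (x, y)
    size_t elems = 1;
    for (int d : dims)
        elems *= (d == -1) ? 10 : d;   // 1 * 10 * 2 = 20 floats (80 bytes)
    return elems;
}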
bool EngineTRT::infer()
{
    // Copy data from host (CPU) input buffers to device (GPU) input buffers
    // asynchronously, then wait for the transfers to finish before inference
    // (executeV2 does not run on mCudaStream, so the copies must be complete).
    copyInputToDeviceAsync(mCudaStream);
    CUDA_CHECK(cudaStreamSynchronize(mCudaStream));

    // Perform inference using TensorRT, passing the GPU buffers.
    bool status = mContext->executeV2(mGpuBuffers.data());
    if (!status)
    {
        // If inference fails, print an error message and return false.
        cout << "inference error!" << endl;
        return false;
    }

    // Copy the results from device (GPU) output buffers back to host (CPU)
    // buffers, and wait for the copy to finish before callers read them.
    copyOutputToHostAsync(mCudaStream);
    CUDA_CHECK(cudaStreamSynchronize(mCudaStream));

    // Return true if inference was successful.
    return true;
}

void EngineTRT::copyInputToDeviceAsync(const cudaStream_t& stream)
{
    // Asynchronous host-to-device copy of the input buffers.
    memcpyBuffers(true, false, true, stream);
}

void EngineTRT::copyOutputToHostAsync(const cudaStream_t& stream)
{
    // Asynchronous device-to-host copy of the output buffers:
    // copyInput = false, deviceToHost = true, async = true, on the given stream.
    memcpyBuffers(false, true, true, stream);
}

void EngineTRT::memcpyBuffers(const bool copyInput, const bool deviceToHost, const bool async, const cudaStream_t& stream)
{
    // Loop through all bindings (inputs and outputs) in the TensorRT engine.
    for (int i = 0; i < mEngine->getNbBindings(); i++)
    {
        // Pick the destination and source pointers based on the copy direction.
        void* dstPtr = deviceToHost ? mCpuBuffers[i] : mGpuBuffers[i];
        const void* srcPtr = deviceToHost ? mGpuBuffers[i] : mCpuBuffers[i];
        // Size of the buffer in bytes.
        const size_t byteSize = mBufferBindingBytes[i];
        // Memory copy kind matching the direction.
        const cudaMemcpyKind memcpyType = deviceToHost ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice;

        // Copy only the requested side: inputs when copyInput is set, outputs otherwise.
        if ((copyInput && mEngine->bindingIsInput(i)) || (!copyInput && !mEngine->bindingIsInput(i)))
        {
            if (async)
            {
                // Asynchronous memory copy on the given CUDA stream.
                CUDA_CHECK(cudaMemcpyAsync(dstPtr, srcPtr, byteSize, memcpyType, stream));
            }
            else
            {
                // Synchronous memory copy.
                CUDA_CHECK(cudaMemcpy(dstPtr, srcPtr, byteSize, memcpyType));
            }
        }
    }
}
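// Aside (illustrative sketch, not wired into the class above): an alternative,
// fully stream-ordered inference path using IExecutionContext::enqueueV2 from
// TensorRT 8.x. Copies enqueued on the same stream before and after it are
// ordered automatically, so a single synchronization at the end suffices.
static bool inferStreamOrdered(IExecutionContext* ctx, void** gpuBuffers, cudaStream_t stream)
{
    if (!ctx->enqueueV2(gpuBuffers, stream, nullptr)) // enqueue inference on 'stream'
        return false;
    return cudaStreamSynchronize(stream) == cudaSuccess; // wait for D2H copies too
}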
size_t EngineTRT::getSizeByDim(const Dims& dims)
{
    size_t size = 1;

    // Multiply all dimensions to compute the total element count.
    for (int32_t i = 0; i < dims.nbDims; ++i)
    {
        // A -1 (dynamic) dimension is sized for the worst case, MAX_NUM_PROMPTS.
        if (dims.d[i] == -1)
            size *= MAX_NUM_PROMPTS;
        else
            size *= dims.d[i];
    }

    return size;
}

void EngineTRT::setInput(Mat& image)
{
    // The image is expected to already match the model input size
    // (see SpeedSam::resizeImage in src/speedSam.cpp).
    int i = 0; // Flattened pixel index within one channel plane

    // Iterate over each pixel of the BGR input image.
    for (int row = 0; row < image.rows; ++row)
    {
        // Pointer to the start of the row in the image data.
        uchar* uc_pixel = image.data + row * image.step;

        for (int col = 0; col < image.cols; ++col)
        {
            // Convert BGR to RGB, scale to [0, 1], and normalize with the
            // ImageNet mean/std, writing planar CHW data into mCpuBuffers[0].
            mCpuBuffers[0][i] = ((float)uc_pixel[2] / 255.0f - 0.485f) / 0.229f;                               // R plane
            mCpuBuffers[0][i + image.rows * image.cols] = ((float)uc_pixel[1] / 255.0f - 0.456f) / 0.224f;     // G plane
            mCpuBuffers[0][i + 2 * image.rows * image.cols] = ((float)uc_pixel[0] / 255.0f - 0.406f) / 0.225f; // B plane

            uc_pixel += 3; // Move to the next pixel
            ++i;           // Increment index
        }
    }
}

void EngineTRT::setInput(float* features, float* imagePointCoords, float* imagePointLabels, float* maskInput, float* hasMaskInput, int numPoints)
{
    // Release the previous prompt buffers and allocate new ones sized for numPoints.
    delete[] mCpuBuffers[1];
    delete[] mCpuBuffers[2];
    mCpuBuffers[1] = new float[numPoints * 2]; // Point coordinates
    mCpuBuffers[2] = new float[numPoints];     // Point labels

    // Re-allocate the matching GPU buffers, freeing the old ones to avoid a leak.
    CUDA_CHECK(cudaFree(mGpuBuffers[1]));
    CUDA_CHECK(cudaFree(mGpuBuffers[2]));
    CUDA_CHECK(cudaMalloc(&mGpuBuffers[1], sizeof(float) * numPoints * 2)); // Coordinates
    CUDA_CHECK(cudaMalloc(&mGpuBuffers[2], sizeof(float) * numPoints));     // Labels

    // Update the binding sizes in bytes for the dynamic inputs.
    mBufferBindingBytes[1] = sizeof(float) * numPoints * 2;
    mBufferBindingBytes[2] = sizeof(float) * numPoints;

    // Copy input data into the CPU buffers.
    memcpy(mCpuBuffers[0], features, mBufferBindingBytes[0]);
    memcpy(mCpuBuffers[1], imagePointCoords, sizeof(float) * numPoints * 2);
    memcpy(mCpuBuffers[2], imagePointLabels, sizeof(float) * numPoints);
    memcpy(mCpuBuffers[3], maskInput, mBufferBindingBytes[3]);
    memcpy(mCpuBuffers[4], hasMaskInput, mBufferBindingBytes[4]);

    // Select the optimization profile and set the runtime shapes of the dynamic inputs.
    mContext->setOptimizationProfileAsync(0, mCudaStream);
    mContext->setBindingDimensions(1, Dims3{ 1, numPoints, 2 }); // Coordinates
    mContext->setBindingDimensions(2, Dims2{ 1, numPoints });    // Labels
}

void EngineTRT::getOutput(float* features)
{
    // Copy the encoder's output features from the CPU buffer to the caller's memory.
    memcpy(features, mCpuBuffers[1], mBufferBindingBytes[1]);
}

void EngineTRT::getOutput(float* iouPrediction, float* lowResolutionMasks)
{
    // Copy the decoder's low-resolution masks and IoU predictions from the CPU buffers.
    memcpy(lowResolutionMasks, mCpuBuffers[5], mBufferBindingBytes[5]);
    memcpy(iouPrediction, mCpuBuffers[6], mBufferBindingBytes[6]);
}

--------------------------------------------------------------------------------
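Taken together, EngineTRT runs in three steps per call: stage inputs into the CPU buffers with setInput(), execute with infer(), and read back results with getOutput(). A minimal encoder round trip might look like the sketch below; HIDDEN_DIM, FEATURE_WIDTH, and FEATURE_HEIGHT are constants from include/config.h (values not shown in this listing), and letterboxedBgrImage() is a hypothetical helper standing in for an image already resized to the model input, as SpeedSam::resizeImage does in src/speedSam.cpp.

    cv::Mat input = letterboxedBgrImage(); // hypothetical: BGR image at model input size
    EngineTRT encoder("model/SAM_encoder.onnx", { "image" }, { "image_embeddings" }, false, true);
    std::vector<float> features(HIDDEN_DIM * FEATURE_WIDTH * FEATURE_HEIGHT);
    encoder.setInput(input);                // BGR -> normalized planar RGB in the CPU buffer
    if (encoder.infer())                    // H2D copies, executeV2, D2H copies
        encoder.getOutput(features.data()); // image_embeddings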
/src/main.cpp:
--------------------------------------------------------------------------------
#include "speedSam.h"
#include "utils.h"

int main()
{
    // Optional: set the base path that model/ and assets/ are resolved against
    std::string path = "";

    // Build the engines from the ONNX files
    SpeedSam speedSam(path + "model/SAM_encoder.onnx", path + "model/SAM_mask_decoder.onnx");

    /* Segmentation examples */

    // Demo 1: Segment using a point
    segmentWithPoint(speedSam, path + "assets/dog.jpg", path + "assets/dog_mask.jpg");

    // Demo 2: Segment using a bounding box
    segmentBbox(speedSam, path + "assets/dogs.jpg", path + "assets/dogs_mask.jpg");

    return 0;
}
--------------------------------------------------------------------------------
/src/speedSam.cpp:
--------------------------------------------------------------------------------
#include "speedSam.h"
#include "config.h"

using namespace std;

SpeedSam::SpeedSam(string encoderPath, string decoderPath)
{
    // Initialize the image encoder and mask decoder using the provided model paths
    mImageEncoder = new EngineTRT(encoderPath,
        { "image" },            // Encoder input names
        { "image_embeddings" }, // Encoder output names
        false,                  // Static input shape
        true);                  // FP16 precision

    mMaskDecoder = new EngineTRT(decoderPath,
        { "image_embeddings", "point_coords", "point_labels", "mask_input", "has_mask_input" }, // Decoder input names
        { "iou_predictions", "low_res_masks" }, // Decoder output names
        true,                   // Dynamic input shape (variable prompt count)
        false);                 // FP32 precision

    // Allocate memory for model features and inputs
    mFeatures = new float[HIDDEN_DIM * FEATURE_WIDTH * FEATURE_HEIGHT];
    mMaskInput = new float[HIDDEN_DIM * HIDDEN_DIM];
    mHasMaskInput = new float;                                      // Flag for mask input presence
    mIouPrediction = new float[NUM_LABELS];                         // IoU prediction output
    mLowResMasks = new float[NUM_LABELS * HIDDEN_DIM * HIDDEN_DIM]; // Low-resolution masks output
}

SpeedSam::~SpeedSam()
{
    // Clean up dynamically allocated memory
    if (mFeatures) delete[] mFeatures;
    if (mMaskInput) delete[] mMaskInput;
    if (mHasMaskInput) delete mHasMaskInput;
    if (mIouPrediction) delete[] mIouPrediction;
    if (mLowResMasks) delete[] mLowResMasks;

    if (mImageEncoder) delete mImageEncoder;
    if (mMaskDecoder) delete mMaskDecoder;
}

Mat SpeedSam::predict(Mat& image, vector<Point> points, vector<float> labels)
{
    // If no points are provided, return an empty (all-zero) mask
    if (points.size() == 0) return cv::Mat::zeros(image.rows, image.cols, CV_32FC1);

    // Preprocess the input image for the encoder
    auto resizedImage = resizeImage(image, MODEL_INPUT_WIDTH, MODEL_INPUT_HEIGHT);

    // Run the image encoder
    mImageEncoder->setInput(resizedImage);
    mImageEncoder->infer();
    mImageEncoder->getOutput(mFeatures);

    // Prepare decoder input data for the specified points
    auto pointData = new float[2 * points.size()]; // Scaled point coordinates
    prepareDecoderInput(points, pointData, points.size(), image.cols, image.rows);

    // Run the mask decoder
    mMaskDecoder->setInput(mFeatures, pointData, labels.data(), mMaskInput, mHasMaskInput, points.size());
    mMaskDecoder->infer();
    mMaskDecoder->getOutput(mIouPrediction, mLowResMasks);

    // Post-process the output mask
    Mat imgMask(HIDDEN_DIM, HIDDEN_DIM, CV_32FC1, mLowResMasks);
    upscaleMask(imgMask, image.cols, image.rows); // Upscale to the original image size

    delete[] pointData; // Free the point coordinate buffer

    return imgMask; // Return the segmented mask
}
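// Aside (illustrative sketch, not part of the project): a minimal point-prompt
// call, similar to what utils.h's segmentWithPoint presumably does (its
// implementation is not shown in this listing). The click position, label
// convention (1 = positive point), and output path are illustrative.
static void examplePointPrompt(SpeedSam& speedSam)
{
    Mat image = cv::imread("assets/dog.jpg");
    vector<Point> points = { Point(image.cols / 2, image.rows / 2) }; // foreground click
    vector<float> labels = { 1.0f };                                  // 1 = positive point
    Mat mask = speedSam.predict(image, points, labels);               // CV_32FC1 mask
    cv::imwrite("assets/dog_mask.jpg", mask > 0);                     // binarize and save
}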
void SpeedSam::prepareDecoderInput(vector<Point>& points, float* pointData, int numPoints, int imageWidth, int imageHeight)
{
    // Scaling factor from the original image to the model input
    // (the float cast prevents integer division from truncating the scale)
    float scale = MODEL_INPUT_WIDTH / (float)max(imageWidth, imageHeight);

    // Scale the point coordinates
    for (int i = 0; i < numPoints; i++)
    {
        pointData[i * 2] = (float)points[i].x * scale;     // X coordinate
        pointData[i * 2 + 1] = (float)points[i].y * scale; // Y coordinate
    }

    // Zero the mask input and mark it as absent
    for (int i = 0; i < HIDDEN_DIM * HIDDEN_DIM; i++)
    {
        mMaskInput[i] = 0;
    }
    *mHasMaskInput = 0;
}

Mat SpeedSam::resizeImage(Mat& img, int inputWidth, int inputHeight)
{
    int w, h;
    float aspectRatio = (float)img.cols / (float)img.rows; // Aspect ratio of the source

    // Determine the new dimensions while preserving the aspect ratio
    if (aspectRatio >= 1)
    {
        w = inputWidth;
        h = int(inputHeight / aspectRatio);
    }
    else
    {
        w = int(inputWidth * aspectRatio);
        h = inputHeight;
    }

    // Resize the original image into the aspect-preserving size
    Mat re(h, w, CV_8UC3);
    cv::resize(img, re, re.size(), 0, 0, INTER_LINEAR);

    // Paste it into the top-left corner of a black canvas of the model input size
    Mat out(inputHeight, inputWidth, CV_8UC3, Scalar(0, 0, 0));
    re.copyTo(out(Rect(0, 0, re.cols, re.rows)));

    return out; // Return the letterboxed image
}

void SpeedSam::upscaleMask(Mat& mask, int targetWidth, int targetHeight, int size)
{
    int limX, limY;
    // The valid (non-padded) region of the low-res mask mirrors the letterboxing above
    if (targetWidth > targetHeight)
    {
        limX = size;
        limY = size * targetHeight / targetWidth;
    }
    else
    {
        limX = size * targetWidth / targetHeight;
        limY = size;
    }

    // Crop the valid region and resize it to the target dimensions
    cv::resize(mask(Rect(0, 0, limX, limY)), mask, Size(targetWidth, targetHeight));
}
--------------------------------------------------------------------------------
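For concreteness, here is a small standalone example of the coordinate scaling and mask-region arithmetic above. The constants are assumptions standing in for the real values in include/config.h (MODEL_INPUT_WIDTH and the low-res mask size), so treat the numbers as illustrative only:

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int modelInput = 1024, maskSize = 256;              // assumed config.h values
        const int imgW = 1280, imgH = 960;                        // example image size
        float scale = modelInput / (float)std::max(imgW, imgH);   // 0.8f
        std::printf("click (640,480) -> (%.0f,%.0f)\n", 640 * scale, 480 * scale); // (512,384)
        std::printf("valid mask region: %d x %d\n", maskSize, maskSize * imgH / imgW); // 256 x 192
        return 0;
    }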