├── CMakeLists.txt
├── README.md
├── images
    ├── bus.jpg
    └── zidane.jpg
├── l.jpg
├── logging.h
├── m.jpg
├── main1_onnx2trt.cpp
├── main2_trt_infer.cpp
├── models
    ├── yolov8n-seg.onnx
    └── yolov8s-seg.onnx
├── n.jpg
├── output.jpg
├── s.jpg
├── utils.h
└── x.jpg


/CMakeLists.txt:
--------------------------------------------------------------------------------
# Build script for the YOLOv8 instance-segmentation TensorRT demo.
# It produces two executables:
#   onnx2trt  - built from main1_onnx2trt.cpp
#   trt_infer - built from main2_trt_infer.cpp

# enable_language(CUDA) below requires CMake 3.8 or newer.
cmake_minimum_required(VERSION 3.8)

project(yolov8_seg)

# option() takes (<name> "<help text>" <value>); the help text was missing,
# so "OFF" used to be interpreted as the description.
option(CUDA_USE_STATIC_CUDA_RUNTIME "Link against the static CUDA runtime" OFF)

# Plain C++11 (equivalent to the former hand-written -std=c++11 flag).
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

# Default to Release, but let the user pick another configuration.
# Note: Release defines NDEBUG, which disables every assert() in the sources.
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release)
endif()

# Compiler flags must be declared before the targets they apply to.
add_compile_options(-O2 -pthread)

include_directories(${PROJECT_SOURCE_DIR})

# threads (link-time counterpart of -pthread)
find_package(Threads REQUIRED)

# opencv
find_package(OpenCV REQUIRED)
include_directories(${OpenCV_INCLUDE_DIRS})

# include and link dirs of cuda and tensorrt, you need to adapt them if yours are different
# cuda
#include_directories(/usr/local/cuda/include)
#link_directories(/usr/local/cuda/lib64)
# tensorrt
#include_directories(/usr/include/aarch64-linux-gnu/)
#link_directories(/usr/lib/aarch64-linux-gnu/)

find_package(CUDA REQUIRED)
message(STATUS "    libraries: ${CUDA_LIBRARIES}")
message(STATUS "    include path: ${CUDA_INCLUDE_DIRS}")
include_directories(${CUDA_INCLUDE_DIRS})
enable_language(CUDA)


add_executable(onnx2trt ${PROJECT_SOURCE_DIR}/main1_onnx2trt.cpp)
target_link_libraries(onnx2trt nvinfer)
target_link_libraries(onnx2trt nvonnxparser)
target_link_libraries(onnx2trt cudart)
target_link_libraries(onnx2trt ${OpenCV_LIBS})
target_link_libraries(onnx2trt Threads::Threads)

add_executable(trt_infer ${PROJECT_SOURCE_DIR}/main2_trt_infer.cpp)
target_link_libraries(trt_infer nvinfer)
target_link_libraries(trt_infer nvonnxparser)
target_link_libraries(trt_infer cudart)
target_link_libraries(trt_infer nvinfer_plugin)
target_link_libraries(trt_infer ${OpenCV_LIBS})
target_link_libraries(trt_infer Threads::Threads)

-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Yolov8-instance-seg-tensorrt 2 | Based on YOLOv8, this project provides pt -> ONNX -> TensorRT model conversion and inference code in C++. 3 | 4 | mkdir build 5 | cd build 6 | cmake .. 7 | make 8 | sudo ./onnx2trt ../models/yolov8n-seg.onnx ../models/yolov8n-seg.engine 9 | sudo ./trt_infer ../models/yolov8n-seg.engine ../images/bus.jpg 10 | ![image](https://github.com/fish-kong/Yolov8-instance-seg-tensorrt/blob/main/x.jpg) 11 | -------------------------------------------------------------------------------- /images/bus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov8-instance-seg-tensorrt/58bfd1fe27437386bf486d82c1447a85e75f8c30/images/bus.jpg -------------------------------------------------------------------------------- /images/zidane.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov8-instance-seg-tensorrt/58bfd1fe27437386bf486d82c1447a85e75f8c30/images/zidane.jpg -------------------------------------------------------------------------------- /l.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov8-instance-seg-tensorrt/58bfd1fe27437386bf486d82c1447a85e75f8c30/l.jpg -------------------------------------------------------------------------------- /logging.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef TENSORRT_LOGGING_H 18 | #define TENSORRT_LOGGING_H 19 | 20 | #include "NvInferRuntimeCommon.h" 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | using Severity = nvinfer1::ILogger::Severity; 30 | 31 | class LogStreamConsumerBuffer : public std::stringbuf 32 | { 33 | public: 34 | LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) 35 | : mOutput(stream) 36 | , mPrefix(prefix) 37 | , mShouldLog(shouldLog) 38 | { 39 | } 40 | 41 | LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) 42 | : mOutput(other.mOutput) 43 | { 44 | } 45 | 46 | ~LogStreamConsumerBuffer() 47 | { 48 | // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence 49 | // std::streambuf::pptr() gives a pointer to the current position of the output sequence 50 | // if the pointer to the beginning is not equal to the pointer to the current position, 51 | // call putOutput() to log the output to the stream 52 | if (pbase() != pptr()) 53 | { 54 | putOutput(); 55 | } 56 | } 57 | 58 | // synchronizes the stream buffer and returns 0 on success 59 | // synchronizing the stream buffer consists of inserting the buffer contents into the stream, 60 | // resetting the buffer and flushing the stream 61 | virtual int sync() 62 | { 63 | putOutput(); 64 | return 0; 65 | } 66 | 67 | void putOutput() 68 | { 69 | if (mShouldLog) 70 | { 71 | // prepend timestamp 72 | std::time_t timestamp = std::time(nullptr); 73 | 
tm* tm_local = std::localtime(×tamp); 74 | std::cout << "["; 75 | std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/"; 76 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/"; 77 | std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-"; 78 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":"; 79 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":"; 80 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] "; 81 | // std::stringbuf::str() gets the string contents of the buffer 82 | // insert the buffer contents pre-appended by the appropriate prefix into the stream 83 | mOutput << mPrefix << str(); 84 | // set the buffer to empty 85 | str(""); 86 | // flush the stream 87 | mOutput.flush(); 88 | } 89 | } 90 | 91 | void setShouldLog(bool shouldLog) 92 | { 93 | mShouldLog = shouldLog; 94 | } 95 | 96 | private: 97 | std::ostream& mOutput; 98 | std::string mPrefix; 99 | bool mShouldLog; 100 | }; 101 | 102 | //! 103 | //! \class LogStreamConsumerBase 104 | //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer 105 | //! 106 | class LogStreamConsumerBase 107 | { 108 | public: 109 | LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) 110 | : mBuffer(stream, prefix, shouldLog) 111 | { 112 | } 113 | 114 | protected: 115 | LogStreamConsumerBuffer mBuffer; 116 | }; 117 | 118 | //! 119 | //! \class LogStreamConsumer 120 | //! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages. 121 | //! Order of base classes is LogStreamConsumerBase and then std::ostream. 122 | //! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field 123 | //! in LogStreamConsumer and then the address of the buffer is passed to std::ostream. 124 | //! 
This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. 125 | //! Please do not change the order of the parent classes. 126 | //! 127 | class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream 128 | { 129 | public: 130 | //! \brief Creates a LogStreamConsumer which logs messages with level severity. 131 | //! Reportable severity determines if the messages are severe enough to be logged. 132 | LogStreamConsumer(Severity reportableSeverity, Severity severity) 133 | : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) 134 | , std::ostream(&mBuffer) // links the stream buffer with the stream 135 | , mShouldLog(severity <= reportableSeverity) 136 | , mSeverity(severity) 137 | { 138 | } 139 | 140 | LogStreamConsumer(LogStreamConsumer&& other) 141 | : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) 142 | , std::ostream(&mBuffer) // links the stream buffer with the stream 143 | , mShouldLog(other.mShouldLog) 144 | , mSeverity(other.mSeverity) 145 | { 146 | } 147 | 148 | void setReportableSeverity(Severity reportableSeverity) 149 | { 150 | mShouldLog = mSeverity <= reportableSeverity; 151 | mBuffer.setShouldLog(mShouldLog); 152 | } 153 | 154 | private: 155 | static std::ostream& severityOstream(Severity severity) 156 | { 157 | return severity >= Severity::kINFO ? std::cout : std::cerr; 158 | } 159 | 160 | static std::string severityPrefix(Severity severity) 161 | { 162 | switch (severity) 163 | { 164 | case Severity::kINTERNAL_ERROR: return "[F] "; 165 | case Severity::kERROR: return "[E] "; 166 | case Severity::kWARNING: return "[W] "; 167 | case Severity::kINFO: return "[I] "; 168 | case Severity::kVERBOSE: return "[V] "; 169 | default: assert(0); return ""; 170 | } 171 | } 172 | 173 | bool mShouldLog; 174 | Severity mSeverity; 175 | }; 176 | 177 | //! \class Logger 178 | //! 179 | //! 
\brief Class which manages logging of TensorRT tools and samples 180 | //! 181 | //! \details This class provides a common interface for TensorRT tools and samples to log information to the console, 182 | //! and supports logging two types of messages: 183 | //! 184 | //! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) 185 | //! - Test pass/fail messages 186 | //! 187 | //! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is 188 | //! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. 189 | //! 190 | //! In the future, this class could be extended to support dumping test results to a file in some standard format 191 | //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). 192 | //! 193 | //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger 194 | //! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT 195 | //! library and messages coming from the sample. 196 | //! 197 | //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the 198 | //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger 199 | //! object. 200 | 201 | class Logger : public nvinfer1::ILogger 202 | { 203 | public: 204 | Logger(Severity severity = Severity::kWARNING) 205 | : mReportableSeverity(severity) 206 | { 207 | } 208 | 209 | //! 210 | //! \enum TestResult 211 | //! \brief Represents the state of a given test 212 | //! 213 | enum class TestResult 214 | { 215 | kRUNNING, //!< The test is running 216 | kPASSED, //!< The test passed 217 | kFAILED, //!< The test failed 218 | kWAIVED //!< The test was waived 219 | }; 220 | 221 | //! 222 | //! 
\brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger 223 | //! \return The nvinfer1::ILogger associated with this Logger 224 | //! 225 | //! TODO Once all samples are updated to use this method to register the logger with TensorRT, 226 | //! we can eliminate the inheritance of Logger from ILogger 227 | //! 228 | nvinfer1::ILogger& getTRTLogger() 229 | { 230 | return *this; 231 | } 232 | 233 | //! 234 | //! \brief Implementation of the nvinfer1::ILogger::log() virtual method 235 | //! 236 | //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the 237 | //! inheritance from nvinfer1::ILogger 238 | //! 239 | void log(Severity severity, const char* msg) noexcept override 240 | { 241 | LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; 242 | } 243 | 244 | //! 245 | //! \brief Method for controlling the verbosity of logging output 246 | //! 247 | //! \param severity The logger will only emit messages that have severity of this level or higher. 248 | //! 249 | void setReportableSeverity(Severity severity) 250 | { 251 | mReportableSeverity = severity; 252 | } 253 | 254 | //! 255 | //! \brief Opaque handle that holds logging information for a particular test 256 | //! 257 | //! This object is an opaque handle to information used by the Logger to print test results. 258 | //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used 259 | //! with Logger::reportTest{Start,End}(). 260 | //! 261 | class TestAtom 262 | { 263 | public: 264 | TestAtom(TestAtom&&) = default; 265 | 266 | private: 267 | friend class Logger; 268 | 269 | TestAtom(bool started, const std::string& name, const std::string& cmdline) 270 | : mStarted(started) 271 | , mName(name) 272 | , mCmdline(cmdline) 273 | { 274 | } 275 | 276 | bool mStarted; 277 | std::string mName; 278 | std::string mCmdline; 279 | }; 280 | 281 | //! 282 | //! 
\brief Define a test for logging 283 | //! 284 | //! \param[in] name The name of the test. This should be a string starting with 285 | //! "TensorRT" and containing dot-separated strings containing 286 | //! the characters [A-Za-z0-9_]. 287 | //! For example, "TensorRT.sample_googlenet" 288 | //! \param[in] cmdline The command line used to reproduce the test 289 | // 290 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 291 | //! 292 | static TestAtom defineTest(const std::string& name, const std::string& cmdline) 293 | { 294 | return TestAtom(false, name, cmdline); 295 | } 296 | 297 | //! 298 | //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments 299 | //! as input 300 | //! 301 | //! \param[in] name The name of the test 302 | //! \param[in] argc The number of command-line arguments 303 | //! \param[in] argv The array of command-line arguments (given as C strings) 304 | //! 305 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 306 | static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) 307 | { 308 | auto cmdline = genCmdlineString(argc, argv); 309 | return defineTest(name, cmdline); 310 | } 311 | 312 | //! 313 | //! \brief Report that a test has started. 314 | //! 315 | //! \pre reportTestStart() has not been called yet for the given testAtom 316 | //! 317 | //! \param[in] testAtom The handle to the test that has started 318 | //! 319 | static void reportTestStart(TestAtom& testAtom) 320 | { 321 | reportTestResult(testAtom, TestResult::kRUNNING); 322 | assert(!testAtom.mStarted); 323 | testAtom.mStarted = true; 324 | } 325 | 326 | //! 327 | //! \brief Report that a test has ended. 328 | //! 329 | //! \pre reportTestStart() has been called for the given testAtom 330 | //! 331 | //! \param[in] testAtom The handle to the test that has ended 332 | //! \param[in] result The result of the test. 
Should be one of TestResult::kPASSED, 333 | //! TestResult::kFAILED, TestResult::kWAIVED 334 | //! 335 | static void reportTestEnd(const TestAtom& testAtom, TestResult result) 336 | { 337 | assert(result != TestResult::kRUNNING); 338 | assert(testAtom.mStarted); 339 | reportTestResult(testAtom, result); 340 | } 341 | 342 | static int reportPass(const TestAtom& testAtom) 343 | { 344 | reportTestEnd(testAtom, TestResult::kPASSED); 345 | return EXIT_SUCCESS; 346 | } 347 | 348 | static int reportFail(const TestAtom& testAtom) 349 | { 350 | reportTestEnd(testAtom, TestResult::kFAILED); 351 | return EXIT_FAILURE; 352 | } 353 | 354 | static int reportWaive(const TestAtom& testAtom) 355 | { 356 | reportTestEnd(testAtom, TestResult::kWAIVED); 357 | return EXIT_SUCCESS; 358 | } 359 | 360 | static int reportTest(const TestAtom& testAtom, bool pass) 361 | { 362 | return pass ? reportPass(testAtom) : reportFail(testAtom); 363 | } 364 | 365 | Severity getReportableSeverity() const 366 | { 367 | return mReportableSeverity; 368 | } 369 | 370 | private: 371 | //! 372 | //! \brief returns an appropriate string for prefixing a log message with the given severity 373 | //! 374 | static const char* severityPrefix(Severity severity) 375 | { 376 | switch (severity) 377 | { 378 | case Severity::kINTERNAL_ERROR: return "[F] "; 379 | case Severity::kERROR: return "[E] "; 380 | case Severity::kWARNING: return "[W] "; 381 | case Severity::kINFO: return "[I] "; 382 | case Severity::kVERBOSE: return "[V] "; 383 | default: assert(0); return ""; 384 | } 385 | } 386 | 387 | //! 388 | //! \brief returns an appropriate string for prefixing a test result message with the given result 389 | //! 
390 | static const char* testResultString(TestResult result) 391 | { 392 | switch (result) 393 | { 394 | case TestResult::kRUNNING: return "RUNNING"; 395 | case TestResult::kPASSED: return "PASSED"; 396 | case TestResult::kFAILED: return "FAILED"; 397 | case TestResult::kWAIVED: return "WAIVED"; 398 | default: assert(0); return ""; 399 | } 400 | } 401 | 402 | //! 403 | //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity 404 | //! 405 | static std::ostream& severityOstream(Severity severity) 406 | { 407 | return severity >= Severity::kINFO ? std::cout : std::cerr; 408 | } 409 | 410 | //! 411 | //! \brief method that implements logging test results 412 | //! 413 | static void reportTestResult(const TestAtom& testAtom, TestResult result) 414 | { 415 | severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " 416 | << testAtom.mCmdline << std::endl; 417 | } 418 | 419 | //! 420 | //! \brief generate a command line string from the given (argc, argv) values 421 | //! 422 | static std::string genCmdlineString(int argc, char const* const* argv) 423 | { 424 | std::stringstream ss; 425 | for (int i = 0; i < argc; i++) 426 | { 427 | if (i > 0) 428 | ss << " "; 429 | ss << argv[i]; 430 | } 431 | return ss.str(); 432 | } 433 | 434 | Severity mReportableSeverity; 435 | }; 436 | 437 | namespace 438 | { 439 | 440 | //! 441 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE 442 | //! 443 | //! Example usage: 444 | //! 445 | //! LOG_VERBOSE(logger) << "hello world" << std::endl; 446 | //! 447 | inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) 448 | { 449 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); 450 | } 451 | 452 | //! 453 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO 454 | //! 455 | //! Example usage: 456 | //! 457 | //! 
LOG_INFO(logger) << "hello world" << std::endl; 458 | //! 459 | inline LogStreamConsumer LOG_INFO(const Logger& logger) 460 | { 461 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); 462 | } 463 | 464 | //! 465 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING 466 | //! 467 | //! Example usage: 468 | //! 469 | //! LOG_WARN(logger) << "hello world" << std::endl; 470 | //! 471 | inline LogStreamConsumer LOG_WARN(const Logger& logger) 472 | { 473 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); 474 | } 475 | 476 | //! 477 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR 478 | //! 479 | //! Example usage: 480 | //! 481 | //! LOG_ERROR(logger) << "hello world" << std::endl; 482 | //! 483 | inline LogStreamConsumer LOG_ERROR(const Logger& logger) 484 | { 485 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); 486 | } 487 | 488 | //! 489 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR 490 | // ("fatal" severity) 491 | //! 492 | //! Example usage: 493 | //! 494 | //! LOG_FATAL(logger) << "hello world" << std::endl; 495 | //! 
496 | inline LogStreamConsumer LOG_FATAL(const Logger& logger) 497 | { 498 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); 499 | } 500 | 501 | } // anonymous namespace 502 | 503 | #endif // TENSORRT_LOGGING_H 504 | -------------------------------------------------------------------------------- /m.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov8-instance-seg-tensorrt/58bfd1fe27437386bf486d82c1447a85e75f8c30/m.jpg -------------------------------------------------------------------------------- /main1_onnx2trt.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "logging.h" 3 | #include "NvOnnxParser.h" 4 | #include "NvInfer.h" 5 | #include 6 | 7 | using namespace nvinfer1; 8 | using namespace nvonnxparser; 9 | 10 | static Logger gLogger; 11 | int main(int argc,char** argv) { 12 | if (argc < 2) { 13 | argv[1] = "../../models/yolov8n-seg.onnx"; 14 | argv[2] = "../../models/yolov8n-seg.engine"; 15 | } 16 | // 1 onnx解析器 17 | IBuilder* builder = createInferBuilder(gLogger); 18 | const auto explicitBatch = 1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); 19 | INetworkDefinition* network = builder->createNetworkV2(explicitBatch); 20 | 21 | nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, gLogger); 22 | 23 | const char* onnx_filename = argv[1]; 24 | parser->parseFromFile(onnx_filename, static_cast(Logger::Severity::kWARNING)); 25 | for (int i = 0; i < parser->getNbErrors(); ++i) 26 | { 27 | std::cout << parser->getError(i)->desc() << std::endl; 28 | } 29 | std::cout << "successfully load the onnx model" << std::endl; 30 | 31 | // 2build the engine 32 | unsigned int maxBatchSize = 1; 33 | builder->setMaxBatchSize(maxBatchSize); 34 | IBuilderConfig* config = builder->createBuilderConfig(); 35 | config->setMaxWorkspaceSize(1 << 20); 36 | 
//config->setMaxWorkspaceSize(128 * (1 << 20)); // 16MB 37 | config->setFlag(BuilderFlag::kFP16); 38 | ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); 39 | 40 | // 3serialize Model 41 | IHostMemory *gieModelStream = engine->serialize(); 42 | std::ofstream p(argv[2], std::ios::binary); 43 | if (!p) 44 | { 45 | std::cerr << "could not open plan output file" << std::endl; 46 | return -1; 47 | } 48 | p.write(reinterpret_cast(gieModelStream->data()), gieModelStream->size()); 49 | gieModelStream->destroy(); 50 | 51 | 52 | std::cout << "successfully generate the trt engine model" << std::endl; 53 | return 0; 54 | } -------------------------------------------------------------------------------- /main2_trt_infer.cpp: -------------------------------------------------------------------------------- 1 | #include "NvInfer.h" 2 | #include "cuda_runtime_api.h" 3 | #include "NvInferPlugin.h" 4 | #include "logging.h" 5 | #include 6 | #include "utils.h" 7 | #include 8 | using namespace nvinfer1; 9 | using namespace cv; 10 | 11 | // stuff we know about the network and the input/output blobs 12 | static const int INPUT_H = 640; 13 | static const int INPUT_W = 640; 14 | static const int _segWidth = 160; 15 | static const int _segHeight = 160; 16 | static const int _segChannels = 32; 17 | static const int CLASSES = 80; 18 | static const int Num_box = 8400; 19 | static const int OUTPUT_SIZE = Num_box * (CLASSES+4 + _segChannels);//output0 20 | static const int OUTPUT_SIZE1 = _segChannels * _segWidth * _segHeight ;//output1 21 | 22 | 23 | static const float CONF_THRESHOLD = 0.1; 24 | static const float NMS_THRESHOLD = 0.5; 25 | static const float MASK_THRESHOLD = 0.5; 26 | const char* INPUT_BLOB_NAME = "images"; 27 | const char* OUTPUT_BLOB_NAME = "output0";//detect 28 | const char* OUTPUT_BLOB_NAME1 = "output1";//mask 29 | 30 | 31 | struct OutputSeg { 32 | int id; //结果类别id 33 | float confidence; //结果置信度 34 | cv::Rect box; //矩形框 35 | cv::Mat boxMask; 
//矩形框内mask,节省内存空间和加快速度 36 | }; 37 | 38 | void DrawPred(Mat& img,std:: vector result) { 39 | //生成随机颜色 40 | std::vector color; 41 | srand(time(0)); 42 | for (int i = 0; i < CLASSES; i++) { 43 | int b = rand() % 256; 44 | int g = rand() % 256; 45 | int r = rand() % 256; 46 | color.push_back(Scalar(b, g, r)); 47 | } 48 | Mat mask = img.clone(); 49 | for (int i = 0; i < result.size(); i++) { 50 | int left, top; 51 | left = result[i].box.x; 52 | top = result[i].box.y; 53 | int color_num = i; 54 | rectangle(img, result[i].box, color[result[i].id], 2, 8); 55 | 56 | mask(result[i].box).setTo(color[result[i].id], result[i].boxMask); 57 | char label[100]; 58 | sprintf(label, "%d:%.2f", result[i].id, result[i].confidence); 59 | 60 | //std::string label = std::to_string(result[i].id) + ":" + std::to_string(result[i].confidence); 61 | int baseLine; 62 | Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine); 63 | top = max(top, labelSize.height); 64 | putText(img, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 1, color[result[i].id], 2); 65 | } 66 | 67 | addWeighted(img, 0.5, mask, 0.8, 1, img); //将mask加在原图上面 68 | 69 | 70 | } 71 | 72 | 73 | 74 | static Logger gLogger; 75 | void doInference(IExecutionContext& context, float* input, float* output, float* output1, int batchSize) 76 | { 77 | const ICudaEngine& engine = context.getEngine(); 78 | 79 | // Pointers to input and output device buffers to pass to engine. 80 | // Engine requires exactly IEngine::getNbBindings() number of buffers. 81 | assert(engine.getNbBindings() == 3); 82 | void* buffers[3]; 83 | 84 | // In order to bind the buffers, we need to know the names of the input and output tensors. 
85 | // Note that indices are guaranteed to be less than IEngine::getNbBindings() 86 | const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME); 87 | const int outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME); 88 | const int outputIndex1 = engine.getBindingIndex(OUTPUT_BLOB_NAME1); 89 | 90 | // Create GPU buffers on device 91 | CHECK(cudaMalloc(&buffers[inputIndex], batchSize * 3 * INPUT_H * INPUT_W * sizeof(float)));// 92 | CHECK(cudaMalloc(&buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float))); 93 | CHECK(cudaMalloc(&buffers[outputIndex1], batchSize * OUTPUT_SIZE1 * sizeof(float))); 94 | // cudaMalloc分配内存 cudaFree释放内存 cudaMemcpy或 cudaMemcpyAsync 在主机和设备之间传输数据 95 | // cudaMemcpy cudaMemcpyAsync 显式地阻塞传输 显式地非阻塞传输 96 | // Create stream 97 | cudaStream_t stream; 98 | CHECK(cudaStreamCreate(&stream)); 99 | 100 | // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host 101 | CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream)); 102 | context.enqueue(batchSize, buffers, stream, nullptr); 103 | CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream)); 104 | CHECK(cudaMemcpyAsync(output1, buffers[outputIndex1], batchSize * OUTPUT_SIZE1 * sizeof(float), cudaMemcpyDeviceToHost, stream)); 105 | cudaStreamSynchronize(stream); 106 | 107 | // Release stream and buffers 108 | cudaStreamDestroy(stream); 109 | CHECK(cudaFree(buffers[inputIndex])); 110 | CHECK(cudaFree(buffers[outputIndex])); 111 | CHECK(cudaFree(buffers[outputIndex1])); 112 | } 113 | 114 | 115 | 116 | int main(int argc, char** argv) 117 | { 118 | if (argc < 2) { 119 | argv[1] = "../models/yolov8n-seg.engine"; 120 | argv[2] = "../images/bus.jpg"; 121 | } 122 | // create a model using the API directly and serialize it to a stream 123 | char* trtModelStream{ nullptr }; //char* trtModelStream==nullptr; 开辟空指针后 
要和new配合使用,比如89行 trtModelStream = new char[size] 124 | size_t size{ 0 };//与int固定四个字节不同有所不同,size_t的取值range是目标平台下最大可能的数组尺寸,一些平台下size_t的范围小于int的正数范围,又或者大于unsigned int. 使用Int既有可能浪费,又有可能范围不够大。 125 | 126 | std::ifstream file(argv[1], std::ios::binary); 127 | if (file.good()) { 128 | std::cout << "load engine success" << std::endl; 129 | file.seekg(0, file.end);//指向文件的最后地址 130 | size = file.tellg();//把文件长度告诉给size 131 | //std::cout << "\nfile:" << argv[1] << " size is"; 132 | //std::cout << size << ""; 133 | 134 | file.seekg(0, file.beg);//指回文件的开始地址 135 | trtModelStream = new char[size];//开辟一个char 长度是文件的长度 136 | assert(trtModelStream);// 137 | file.read(trtModelStream, size);//将文件内容传给trtModelStream 138 | file.close();//关闭 139 | } 140 | else { 141 | std::cout << "load engine failed" << std::endl; 142 | return 1; 143 | } 144 | 145 | 146 | Mat src = imread(argv[2], 1); 147 | if (src.empty()) { std::cout << "image load faild" << std::endl; return 1; } 148 | int img_width = src.cols; 149 | int img_height = src.rows; 150 | std::cout << "宽高:" << img_width << " " << img_height << std::endl; 151 | // Subtract mean from image 152 | static float data[3 * INPUT_H * INPUT_W]; 153 | Mat pr_img0, pr_img; 154 | std::vector padsize; 155 | pr_img = preprocess_img(src, INPUT_H, INPUT_W, padsize); // Resize 156 | int newh = padsize[0], neww = padsize[1], padh = padsize[2], padw = padsize[3]; 157 | float ratio_h = (float)src.rows / newh; 158 | float ratio_w = (float)src.cols / neww; 159 | int i = 0;// [1,3,INPUT_H,INPUT_W] 160 | //std::cout << "pr_img.step" << pr_img.step << std::endl; 161 | for (int row = 0; row < INPUT_H; ++row) { 162 | uchar* uc_pixel = pr_img.data + row * pr_img.step;//pr_img.step=widthx3 就是每一行有width个3通道的值 163 | for (int col = 0; col < INPUT_W; ++col) 164 | { 165 | 166 | data[i] = (float)uc_pixel[2] / 255.0; 167 | data[i + INPUT_H * INPUT_W] = (float)uc_pixel[1] / 255.0; 168 | data[i + 2 * INPUT_H * INPUT_W] = (float)uc_pixel[0] / 255.; 169 | uc_pixel += 3; 170 | ++i; 171 | 
} 172 | } 173 | 174 | IRuntime* runtime = createInferRuntime(gLogger); 175 | assert(runtime != nullptr); 176 | bool didInitPlugins = initLibNvInferPlugins(nullptr, ""); 177 | ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size, nullptr); 178 | assert(engine != nullptr); 179 | IExecutionContext* context = engine->createExecutionContext(); 180 | assert(context != nullptr); 181 | delete[] trtModelStream; 182 | 183 | // Run inference 184 | static float prob[OUTPUT_SIZE]; 185 | static float prob1[OUTPUT_SIZE1]; 186 | 187 | //for (int i = 0; i < 10; i++) {//计算10次的推理速度 188 | // auto start = std::chrono::system_clock::now(); 189 | // doInference(*context, data, prob, prob1, 1); 190 | // auto end = std::chrono::system_clock::now(); 191 | // std::cout << std::chrono::duration_cast(end - start).count() << "ms" << std::endl; 192 | // } 193 | auto start = std::chrono::system_clock::now(); 194 | doInference(*context, data, prob, prob1, 1); 195 | auto end = std::chrono::system_clock::now(); 196 | std::cout << "推理时间:" << std::chrono::duration_cast(end - start).count() << "ms" << std::endl; 197 | 198 | std::vector classIds;//结果id数组 199 | std::vector confidences;//结果每个id对应置信度数组 200 | std::vector boxes;//每个id矩形框 201 | std::vector picked_proposals; //后续计算mask 202 | 203 | 204 | 205 | // 处理box 206 | int net_length = CLASSES + 4 + _segChannels; 207 | cv::Mat out1 = cv::Mat(net_length, Num_box, CV_32F, prob); 208 | 209 | start = std::chrono::system_clock::now(); 210 | for (int i = 0; i < Num_box; i++) { 211 | //输出是1*net_length*Num_box;所以每个box的属性是每隔Num_box取一个值,共net_length个值 212 | cv::Mat scores = out1(Rect(i, 4, 1, CLASSES)).clone(); 213 | Point classIdPoint; 214 | double max_class_socre; 215 | minMaxLoc(scores, 0, &max_class_socre, 0, &classIdPoint); 216 | max_class_socre = (float)max_class_socre; 217 | if (max_class_socre >= CONF_THRESHOLD) { 218 | cv::Mat temp_proto = out1(Rect(i, 4 + CLASSES, 1, _segChannels)).clone(); 219 | 
picked_proposals.push_back(temp_proto.t()); 220 | float x = (out1.at(0, i) - padw) * ratio_w; //cx 221 | float y = (out1.at(1, i) - padh) * ratio_h; //cy 222 | float w = out1.at(2, i) * ratio_w; //w 223 | float h = out1.at(3, i) * ratio_h; //h 224 | int left = MAX((x - 0.5 * w), 0); 225 | int top = MAX((y - 0.5 * h), 0); 226 | int width = (int)w; 227 | int height = (int)h; 228 | if (width <= 0 || height <= 0) { continue; } 229 | 230 | classIds.push_back(classIdPoint.y); 231 | confidences.push_back(max_class_socre); 232 | boxes.push_back(Rect(left, top, width, height)); 233 | } 234 | 235 | } 236 | //执行非最大抑制以消除具有较低置信度的冗余重叠框(NMS) 237 | std::vector nms_result; 238 | cv::dnn::NMSBoxes(boxes, confidences, CONF_THRESHOLD, NMS_THRESHOLD, nms_result); 239 | std::vector temp_mask_proposals; 240 | std::vector output; 241 | Rect holeImgRect(0, 0, src.cols, src.rows); 242 | for (int i = 0; i < nms_result.size(); ++i) { 243 | int idx = nms_result[i]; 244 | OutputSeg result; 245 | result.id = classIds[idx]; 246 | result.confidence = confidences[idx]; 247 | result.box = boxes[idx]& holeImgRect; 248 | output.push_back(result); 249 | temp_mask_proposals.push_back(picked_proposals[idx]); 250 | } 251 | 252 | // 处理mask 253 | Mat maskProposals; 254 | for (int i = 0; i < temp_mask_proposals.size(); ++i) 255 | maskProposals.push_back(temp_mask_proposals[i]); 256 | 257 | Mat protos = Mat(_segChannels, _segWidth * _segHeight, CV_32F, prob1); 258 | Mat matmulRes = (maskProposals * protos).t();//n*32 32*25600 A*B是以数学运算中矩阵相乘的方式实现的,要求A的列数等于B的行数时 259 | Mat masks = matmulRes.reshape(output.size(), { _segWidth,_segHeight });//n*160*160 260 | 261 | std::vector maskChannels; 262 | cv::split(masks, maskChannels); 263 | Rect roi(int((float)padw / INPUT_W * _segWidth), int((float)padh / INPUT_H * _segHeight), int(_segWidth - padw / 2), int(_segHeight - padh / 2)); 264 | for (int i = 0; i < output.size(); ++i) { 265 | Mat dest, mask; 266 | cv::exp(-maskChannels[i], dest);//sigmoid 267 | dest = 1.0 / 
(1.0 + dest);//160*160 268 | dest = dest(roi); 269 | resize(dest, mask, cv::Size(src.cols, src.rows), INTER_NEAREST); 270 | //crop----截取box中的mask作为该box对应的mask 271 | Rect temp_rect = output[i].box; 272 | mask = mask(temp_rect) > MASK_THRESHOLD; 273 | output[i].boxMask = mask; 274 | } 275 | end = std::chrono::system_clock::now(); 276 | std::cout << "后处理时间:" << std::chrono::duration_cast(end - start).count() << "ms" << std::endl; 277 | 278 | DrawPred(src, output); 279 | cv::imshow("output.jpg", src); 280 | char c = cv::waitKey(0); 281 | 282 | // Destroy the engine 283 | context->destroy(); 284 | engine->destroy(); 285 | runtime->destroy(); 286 | 287 | system("pause"); 288 | return 0; 289 | } 290 | -------------------------------------------------------------------------------- /models/yolov8n-seg.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov8-instance-seg-tensorrt/58bfd1fe27437386bf486d82c1447a85e75f8c30/models/yolov8n-seg.onnx -------------------------------------------------------------------------------- /models/yolov8s-seg.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov8-instance-seg-tensorrt/58bfd1fe27437386bf486d82c1447a85e75f8c30/models/yolov8s-seg.onnx -------------------------------------------------------------------------------- /n.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov8-instance-seg-tensorrt/58bfd1fe27437386bf486d82c1447a85e75f8c30/n.jpg -------------------------------------------------------------------------------- /output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov8-instance-seg-tensorrt/58bfd1fe27437386bf486d82c1447a85e75f8c30/output.jpg 
-------------------------------------------------------------------------------- /s.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov8-instance-seg-tensorrt/58bfd1fe27437386bf486d82c1447a85e75f8c30/s.jpg -------------------------------------------------------------------------------- /utils.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov8-instance-seg-tensorrt/58bfd1fe27437386bf486d82c1447a85e75f8c30/utils.h -------------------------------------------------------------------------------- /x.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov8-instance-seg-tensorrt/58bfd1fe27437386bf486d82c1447a85e75f8c30/x.jpg --------------------------------------------------------------------------------