# Repository layout:
#   ├── CMakeLists.txt
#   ├── README.md
#   ├── images
#   │   ├── bus.jpg
#   │   └── zidane.jpg
#   ├── logging.h
#   ├── main1_onnx2trt.cpp
#   ├── main2_trt_infer.cpp
#   ├── models
#   │   ├── yolov5s-seg.engine
#   │   └── yolov5s-seg.onnx
#   ├── output.jpg
#   └── utils.h
#
# /CMakeLists.txt:
# --------------------------------------------------------------------------------
# Build script for the two yolov5-seg TensorRT command-line tools:
#   onnx2trt  - built from main1_onnx2trt.cpp
#   trt_infer - built from main2_trt_infer.cpp

# enable_language(CUDA) further down requires CMake 3.8 or newer, so the
# declared minimum has to match what the script actually uses.
cmake_minimum_required(VERSION 3.8)

project(yolov5_seg)

# option() is (<variable> "<help text>" [value]); without the help text the
# word OFF would be taken as the description instead of the default value.
option(CUDA_USE_STATIC_CUDA_RUNTIME "Link against the static CUDA runtime" OFF)

# Request C++11 through CMake instead of add_definitions(-std=c++11), which is
# meant for preprocessor definitions and duplicated this setting.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Default to an optimized build, but respect -DCMAKE_BUILD_TYPE=... from the
# command line (the Release configuration already enables optimization, so no
# extra -O2 flag is added).
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
  set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
endif()

include_directories(${PROJECT_SOURCE_DIR})

# opencv
find_package(OpenCV REQUIRED)
include_directories(${OpenCV_INCLUDE_DIRS})

# include and link dirs of cuda and tensorrt, you need to adapt them if yours are different
# cuda
#include_directories(/usr/local/cuda/include)
#link_directories(/usr/local/cuda/lib64)
# tensorrt
#include_directories(/usr/include/aarch64-linux-gnu/)
#link_directories(/usr/lib/aarch64-linux-gnu/)

find_package(CUDA REQUIRED)
message(STATUS " libraries: ${CUDA_LIBRARIES}")
message(STATUS " include path: ${CUDA_INCLUDE_DIRS}")
include_directories(${CUDA_INCLUDE_DIRS})
enable_language(CUDA)

# Threading support: replaces add_definitions(-pthread), which only affected
# compilation and never reached the link line.
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)

add_executable(onnx2trt ${PROJECT_SOURCE_DIR}/main1_onnx2trt.cpp)
target_link_libraries(onnx2trt nvinfer)
target_link_libraries(onnx2trt nvonnxparser)
# Use the runtime located by find_package(CUDA) rather than a bare "cudart",
# which only links when the CUDA lib directory is on the default search path.
target_link_libraries(onnx2trt ${CUDA_LIBRARIES})
target_link_libraries(onnx2trt ${OpenCV_LIBS})
target_link_libraries(onnx2trt Threads::Threads)

add_executable(trt_infer ${PROJECT_SOURCE_DIR}/main2_trt_infer.cpp)
target_link_libraries(trt_infer nvinfer)
target_link_libraries(trt_infer nvonnxparser)
target_link_libraries(trt_infer ${CUDA_LIBRARIES})
target_link_libraries(trt_infer nvinfer_plugin)
target_link_libraries(trt_infer ${OpenCV_LIBS})
target_link_libraries(trt_infer Threads::Threads)

# --------------------------------------------------------------------------------
# /README.md:
-------------------------------------------------------------------------------- 1 | # General yolov5-seg inference acceleration program 2 | 3 | Environment: CUDA 10.2, cuDNN 8.2.4, TensorRT 8.0.1.6, OpenCV 4.5.4 4 | 5 | ## How to Run 6 | 7 | ### 1. Train yolov5s-seg on your own dataset 8 | Please follow the official code (yolov5 v6.2): 9 | 10 | python segment/train.py --data coco128-seg.yaml --weights yolov5s-seg.pt --cfg yolov5s-seg.yaml 11 | 12 | ### 2. Generate the TensorRT engine 13 | #### 2.1 Official code to TensorRT 14 | python export.py --data coco128-seg.yaml --weights yolov5s-seg.pt --cfg yolov5s-seg.yaml --include engine 15 | 16 | #### 2.2 ONNX to TensorRT 17 | 18 | ##### 2.2.1 Generate the ONNX model 19 | You can export the engine directly with the official yolov5 v6.2 code, but an engine exported that way can only be used on the computer that built it. I suggest exporting the ONNX model first and then building the engine with the code provided here. The advantage is that even if you change the computer configuration, you can regenerate the engine you need as long as you have the ONNX file. 20 | 21 | python export.py --data coco128-seg.yaml --weights yolov5s-seg.pt --cfg yolov5s-seg.yaml --include onnx 22 | 23 | A file 'yolov5s-seg.onnx' will be generated. 24 | 25 | ##### 2.2.2 Generate the engine 26 | cd into the downloaded repository 27 | copy 'yolov5s-seg.onnx' to models/ 28 | 29 | mkdir build 30 | cd build 31 | cmake .. 32 | make 33 | // the executables onnx2trt and trt_infer will be generated 34 | sudo ./onnx2trt ../models/yolov5s-seg.onnx ../models/yolov5s-seg.engine 35 | // a file 'yolov5s-seg.engine' will be generated. 
36 | 37 | ### 3. Use the TensorRT engine 38 | 39 | sudo ./trt_infer ../models/yolov5s-seg.engine ../images/bus.jpg 40 | 41 | ``` 42 | for (int i = 0; i < 10; i++) { // time 10 inference runs 43 | auto start = std::chrono::system_clock::now(); 44 | doInference(*context, data, prob, prob1, 1); 45 | auto end = std::chrono::system_clock::now(); 46 | std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl; 47 | } 48 | ``` 49 | The inference time is stable at 10 ms (GTX 1080 Ti). 50 | 51 | 52 | ![image](https://github.com/fish-kong/Yolov5-instance-seg-tensorrt/blob/main/output.jpg) 53 | -------------------------------------------------------------------------------- /images/bus.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov5-instance-seg-tensorrt/bda2aca7604e5543ea95d32b07db6076efb73354/images/bus.jpg -------------------------------------------------------------------------------- /images/zidane.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov5-instance-seg-tensorrt/bda2aca7604e5543ea95d32b07db6076efb73354/images/zidane.jpg -------------------------------------------------------------------------------- /logging.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef TENSORRT_LOGGING_H 18 | #define TENSORRT_LOGGING_H 19 | 20 | #include "NvInferRuntimeCommon.h" 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | using Severity = nvinfer1::ILogger::Severity; 30 | 31 | class LogStreamConsumerBuffer : public std::stringbuf 32 | { 33 | public: 34 | LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) 35 | : mOutput(stream) 36 | , mPrefix(prefix) 37 | , mShouldLog(shouldLog) 38 | { 39 | } 40 | 41 | LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) 42 | : mOutput(other.mOutput) 43 | { 44 | } 45 | 46 | ~LogStreamConsumerBuffer() 47 | { 48 | // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence 49 | // std::streambuf::pptr() gives a pointer to the current position of the output sequence 50 | // if the pointer to the beginning is not equal to the pointer to the current position, 51 | // call putOutput() to log the output to the stream 52 | if (pbase() != pptr()) 53 | { 54 | putOutput(); 55 | } 56 | } 57 | 58 | // synchronizes the stream buffer and returns 0 on success 59 | // synchronizing the stream buffer consists of inserting the buffer contents into the stream, 60 | // resetting the buffer and flushing the stream 61 | virtual int sync() 62 | { 63 | putOutput(); 64 | return 0; 65 | } 66 | 67 | void putOutput() 68 | { 69 | if (mShouldLog) 70 | { 71 | // prepend timestamp 72 | std::time_t timestamp = std::time(nullptr); 73 | tm* tm_local = std::localtime(×tamp); 74 | std::cout << "["; 75 | std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/"; 76 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/"; 77 | std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-"; 78 | std::cout << 
std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":"; 79 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":"; 80 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] "; 81 | // std::stringbuf::str() gets the string contents of the buffer 82 | // insert the buffer contents pre-appended by the appropriate prefix into the stream 83 | mOutput << mPrefix << str(); 84 | // set the buffer to empty 85 | str(""); 86 | // flush the stream 87 | mOutput.flush(); 88 | } 89 | } 90 | 91 | void setShouldLog(bool shouldLog) 92 | { 93 | mShouldLog = shouldLog; 94 | } 95 | 96 | private: 97 | std::ostream& mOutput; 98 | std::string mPrefix; 99 | bool mShouldLog; 100 | }; 101 | 102 | //! 103 | //! \class LogStreamConsumerBase 104 | //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer 105 | //! 106 | class LogStreamConsumerBase 107 | { 108 | public: 109 | LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) 110 | : mBuffer(stream, prefix, shouldLog) 111 | { 112 | } 113 | 114 | protected: 115 | LogStreamConsumerBuffer mBuffer; 116 | }; 117 | 118 | //! 119 | //! \class LogStreamConsumer 120 | //! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages. 121 | //! Order of base classes is LogStreamConsumerBase and then std::ostream. 122 | //! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field 123 | //! in LogStreamConsumer and then the address of the buffer is passed to std::ostream. 124 | //! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. 125 | //! Please do not change the order of the parent classes. 126 | //! 127 | class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream 128 | { 129 | public: 130 | //! 
\brief Creates a LogStreamConsumer which logs messages with level severity. 131 | //! Reportable severity determines if the messages are severe enough to be logged. 132 | LogStreamConsumer(Severity reportableSeverity, Severity severity) 133 | : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) 134 | , std::ostream(&mBuffer) // links the stream buffer with the stream 135 | , mShouldLog(severity <= reportableSeverity) 136 | , mSeverity(severity) 137 | { 138 | } 139 | 140 | LogStreamConsumer(LogStreamConsumer&& other) 141 | : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) 142 | , std::ostream(&mBuffer) // links the stream buffer with the stream 143 | , mShouldLog(other.mShouldLog) 144 | , mSeverity(other.mSeverity) 145 | { 146 | } 147 | 148 | void setReportableSeverity(Severity reportableSeverity) 149 | { 150 | mShouldLog = mSeverity <= reportableSeverity; 151 | mBuffer.setShouldLog(mShouldLog); 152 | } 153 | 154 | private: 155 | static std::ostream& severityOstream(Severity severity) 156 | { 157 | return severity >= Severity::kINFO ? std::cout : std::cerr; 158 | } 159 | 160 | static std::string severityPrefix(Severity severity) 161 | { 162 | switch (severity) 163 | { 164 | case Severity::kINTERNAL_ERROR: return "[F] "; 165 | case Severity::kERROR: return "[E] "; 166 | case Severity::kWARNING: return "[W] "; 167 | case Severity::kINFO: return "[I] "; 168 | case Severity::kVERBOSE: return "[V] "; 169 | default: assert(0); return ""; 170 | } 171 | } 172 | 173 | bool mShouldLog; 174 | Severity mSeverity; 175 | }; 176 | 177 | //! \class Logger 178 | //! 179 | //! \brief Class which manages logging of TensorRT tools and samples 180 | //! 181 | //! \details This class provides a common interface for TensorRT tools and samples to log information to the console, 182 | //! and supports logging two types of messages: 183 | //! 184 | //! 
- Debugging messages with an associated severity (info, warning, error, or internal error/fatal) 185 | //! - Test pass/fail messages 186 | //! 187 | //! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is 188 | //! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. 189 | //! 190 | //! In the future, this class could be extended to support dumping test results to a file in some standard format 191 | //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). 192 | //! 193 | //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger 194 | //! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT 195 | //! library and messages coming from the sample. 196 | //! 197 | //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the 198 | //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger 199 | //! object. 200 | 201 | class Logger : public nvinfer1::ILogger 202 | { 203 | public: 204 | Logger(Severity severity = Severity::kWARNING) 205 | : mReportableSeverity(severity) 206 | { 207 | } 208 | 209 | //! 210 | //! \enum TestResult 211 | //! \brief Represents the state of a given test 212 | //! 213 | enum class TestResult 214 | { 215 | kRUNNING, //!< The test is running 216 | kPASSED, //!< The test passed 217 | kFAILED, //!< The test failed 218 | kWAIVED //!< The test was waived 219 | }; 220 | 221 | //! 222 | //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger 223 | //! \return The nvinfer1::ILogger associated with this Logger 224 | //! 225 | //! 
TODO Once all samples are updated to use this method to register the logger with TensorRT, 226 | //! we can eliminate the inheritance of Logger from ILogger 227 | //! 228 | nvinfer1::ILogger& getTRTLogger() 229 | { 230 | return *this; 231 | } 232 | 233 | //! 234 | //! \brief Implementation of the nvinfer1::ILogger::log() virtual method 235 | //! 236 | //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the 237 | //! inheritance from nvinfer1::ILogger 238 | //! 239 | void log(Severity severity, const char* msg) noexcept override 240 | { 241 | LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; 242 | } 243 | 244 | //! 245 | //! \brief Method for controlling the verbosity of logging output 246 | //! 247 | //! \param severity The logger will only emit messages that have severity of this level or higher. 248 | //! 249 | void setReportableSeverity(Severity severity) 250 | { 251 | mReportableSeverity = severity; 252 | } 253 | 254 | //! 255 | //! \brief Opaque handle that holds logging information for a particular test 256 | //! 257 | //! This object is an opaque handle to information used by the Logger to print test results. 258 | //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used 259 | //! with Logger::reportTest{Start,End}(). 260 | //! 261 | class TestAtom 262 | { 263 | public: 264 | TestAtom(TestAtom&&) = default; 265 | 266 | private: 267 | friend class Logger; 268 | 269 | TestAtom(bool started, const std::string& name, const std::string& cmdline) 270 | : mStarted(started) 271 | , mName(name) 272 | , mCmdline(cmdline) 273 | { 274 | } 275 | 276 | bool mStarted; 277 | std::string mName; 278 | std::string mCmdline; 279 | }; 280 | 281 | //! 282 | //! \brief Define a test for logging 283 | //! 284 | //! \param[in] name The name of the test. This should be a string starting with 285 | //! 
"TensorRT" and containing dot-separated strings containing 286 | //! the characters [A-Za-z0-9_]. 287 | //! For example, "TensorRT.sample_googlenet" 288 | //! \param[in] cmdline The command line used to reproduce the test 289 | // 290 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 291 | //! 292 | static TestAtom defineTest(const std::string& name, const std::string& cmdline) 293 | { 294 | return TestAtom(false, name, cmdline); 295 | } 296 | 297 | //! 298 | //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments 299 | //! as input 300 | //! 301 | //! \param[in] name The name of the test 302 | //! \param[in] argc The number of command-line arguments 303 | //! \param[in] argv The array of command-line arguments (given as C strings) 304 | //! 305 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 306 | static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) 307 | { 308 | auto cmdline = genCmdlineString(argc, argv); 309 | return defineTest(name, cmdline); 310 | } 311 | 312 | //! 313 | //! \brief Report that a test has started. 314 | //! 315 | //! \pre reportTestStart() has not been called yet for the given testAtom 316 | //! 317 | //! \param[in] testAtom The handle to the test that has started 318 | //! 319 | static void reportTestStart(TestAtom& testAtom) 320 | { 321 | reportTestResult(testAtom, TestResult::kRUNNING); 322 | assert(!testAtom.mStarted); 323 | testAtom.mStarted = true; 324 | } 325 | 326 | //! 327 | //! \brief Report that a test has ended. 328 | //! 329 | //! \pre reportTestStart() has been called for the given testAtom 330 | //! 331 | //! \param[in] testAtom The handle to the test that has ended 332 | //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, 333 | //! TestResult::kFAILED, TestResult::kWAIVED 334 | //! 
335 | static void reportTestEnd(const TestAtom& testAtom, TestResult result) 336 | { 337 | assert(result != TestResult::kRUNNING); 338 | assert(testAtom.mStarted); 339 | reportTestResult(testAtom, result); 340 | } 341 | 342 | static int reportPass(const TestAtom& testAtom) 343 | { 344 | reportTestEnd(testAtom, TestResult::kPASSED); 345 | return EXIT_SUCCESS; 346 | } 347 | 348 | static int reportFail(const TestAtom& testAtom) 349 | { 350 | reportTestEnd(testAtom, TestResult::kFAILED); 351 | return EXIT_FAILURE; 352 | } 353 | 354 | static int reportWaive(const TestAtom& testAtom) 355 | { 356 | reportTestEnd(testAtom, TestResult::kWAIVED); 357 | return EXIT_SUCCESS; 358 | } 359 | 360 | static int reportTest(const TestAtom& testAtom, bool pass) 361 | { 362 | return pass ? reportPass(testAtom) : reportFail(testAtom); 363 | } 364 | 365 | Severity getReportableSeverity() const 366 | { 367 | return mReportableSeverity; 368 | } 369 | 370 | private: 371 | //! 372 | //! \brief returns an appropriate string for prefixing a log message with the given severity 373 | //! 374 | static const char* severityPrefix(Severity severity) 375 | { 376 | switch (severity) 377 | { 378 | case Severity::kINTERNAL_ERROR: return "[F] "; 379 | case Severity::kERROR: return "[E] "; 380 | case Severity::kWARNING: return "[W] "; 381 | case Severity::kINFO: return "[I] "; 382 | case Severity::kVERBOSE: return "[V] "; 383 | default: assert(0); return ""; 384 | } 385 | } 386 | 387 | //! 388 | //! \brief returns an appropriate string for prefixing a test result message with the given result 389 | //! 390 | static const char* testResultString(TestResult result) 391 | { 392 | switch (result) 393 | { 394 | case TestResult::kRUNNING: return "RUNNING"; 395 | case TestResult::kPASSED: return "PASSED"; 396 | case TestResult::kFAILED: return "FAILED"; 397 | case TestResult::kWAIVED: return "WAIVED"; 398 | default: assert(0); return ""; 399 | } 400 | } 401 | 402 | //! 403 | //! 
\brief returns an appropriate output stream (cout or cerr) to use with the given severity 404 | //! 405 | static std::ostream& severityOstream(Severity severity) 406 | { 407 | return severity >= Severity::kINFO ? std::cout : std::cerr; 408 | } 409 | 410 | //! 411 | //! \brief method that implements logging test results 412 | //! 413 | static void reportTestResult(const TestAtom& testAtom, TestResult result) 414 | { 415 | severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " 416 | << testAtom.mCmdline << std::endl; 417 | } 418 | 419 | //! 420 | //! \brief generate a command line string from the given (argc, argv) values 421 | //! 422 | static std::string genCmdlineString(int argc, char const* const* argv) 423 | { 424 | std::stringstream ss; 425 | for (int i = 0; i < argc; i++) 426 | { 427 | if (i > 0) 428 | ss << " "; 429 | ss << argv[i]; 430 | } 431 | return ss.str(); 432 | } 433 | 434 | Severity mReportableSeverity; 435 | }; 436 | 437 | namespace 438 | { 439 | 440 | //! 441 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE 442 | //! 443 | //! Example usage: 444 | //! 445 | //! LOG_VERBOSE(logger) << "hello world" << std::endl; 446 | //! 447 | inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) 448 | { 449 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); 450 | } 451 | 452 | //! 453 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO 454 | //! 455 | //! Example usage: 456 | //! 457 | //! LOG_INFO(logger) << "hello world" << std::endl; 458 | //! 459 | inline LogStreamConsumer LOG_INFO(const Logger& logger) 460 | { 461 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); 462 | } 463 | 464 | //! 465 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING 466 | //! 467 | //! Example usage: 468 | //! 469 | //! 
LOG_WARN(logger) << "hello world" << std::endl; 470 | //! 471 | inline LogStreamConsumer LOG_WARN(const Logger& logger) 472 | { 473 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); 474 | } 475 | 476 | //! 477 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR 478 | //! 479 | //! Example usage: 480 | //! 481 | //! LOG_ERROR(logger) << "hello world" << std::endl; 482 | //! 483 | inline LogStreamConsumer LOG_ERROR(const Logger& logger) 484 | { 485 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); 486 | } 487 | 488 | //! 489 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR 490 | // ("fatal" severity) 491 | //! 492 | //! Example usage: 493 | //! 494 | //! LOG_FATAL(logger) << "hello world" << std::endl; 495 | //! 496 | inline LogStreamConsumer LOG_FATAL(const Logger& logger) 497 | { 498 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); 499 | } 500 | 501 | } // anonymous namespace 502 | 503 | #endif // TENSORRT_LOGGING_H 504 | -------------------------------------------------------------------------------- /main1_onnx2trt.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "logging.h" 3 | #include "NvOnnxParser.h" 4 | #include "NvInfer.h" 5 | #include 6 | 7 | using namespace nvinfer1; 8 | using namespace nvonnxparser; 9 | 10 | static Logger gLogger; 11 | int main(int argc,char** argv) { 12 | if (argc < 2){ 13 | argv[1] = "../models/yolov5s-seg.onnx"; 14 | argv[2] = "../models/yolov5s-seg.engine"; 15 | } 16 | // 1 ����onnxģ�� 17 | IBuilder* builder = createInferBuilder(gLogger); 18 | const auto explicitBatch = 1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); 19 | INetworkDefinition* network = builder->createNetworkV2(explicitBatch); 20 | 21 | nvonnxparser::IParser* parser = 
nvonnxparser::createParser(*network, gLogger); 22 | 23 | const char* onnx_filename = argv[1]; 24 | parser->parseFromFile(onnx_filename, static_cast(Logger::Severity::kWARNING)); 25 | for (int i = 0; i < parser->getNbErrors(); ++i) 26 | { 27 | std::cout << parser->getError(i)->desc() << std::endl; 28 | } 29 | std::cout << "successfully load the onnx model" << std::endl; 30 | 31 | // 2��build the engine 32 | unsigned int maxBatchSize = 1; 33 | builder->setMaxBatchSize(maxBatchSize); 34 | IBuilderConfig* config = builder->createBuilderConfig(); 35 | //config->setMaxWorkspaceSize(1 << 20); 36 | config->setMaxWorkspaceSize(128 * (1 << 20)); // 16MB 37 | config->setFlag(BuilderFlag::kFP16); 38 | ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); 39 | 40 | // 3��serialize Model 41 | IHostMemory *gieModelStream = engine->serialize(); 42 | std::ofstream p(argv[2], std::ios::binary); 43 | if (!p) 44 | { 45 | std::cerr << "could not open plan output file" << std::endl; 46 | return -1; 47 | } 48 | p.write(reinterpret_cast(gieModelStream->data()), gieModelStream->size()); 49 | gieModelStream->destroy(); 50 | 51 | 52 | std::cout << "successfully generate the trt engine model" << std::endl; 53 | return 0; 54 | } -------------------------------------------------------------------------------- /main2_trt_infer.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov5-instance-seg-tensorrt/bda2aca7604e5543ea95d32b07db6076efb73354/main2_trt_infer.cpp -------------------------------------------------------------------------------- /models/yolov5s-seg.engine: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov5-instance-seg-tensorrt/bda2aca7604e5543ea95d32b07db6076efb73354/models/yolov5s-seg.engine -------------------------------------------------------------------------------- /models/yolov5s-seg.onnx: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov5-instance-seg-tensorrt/bda2aca7604e5543ea95d32b07db6076efb73354/models/yolov5s-seg.onnx -------------------------------------------------------------------------------- /output.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov5-instance-seg-tensorrt/bda2aca7604e5543ea95d32b07db6076efb73354/output.jpg -------------------------------------------------------------------------------- /utils.h: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fish-kong/Yolov5-instance-seg-tensorrt/bda2aca7604e5543ea95d32b07db6076efb73354/utils.h --------------------------------------------------------------------------------