├── .gitignore
├── CMakeLists.txt
├── PINetTensorrt.cpp
├── README.md
├── common
│   ├── BatchStream.h
│   ├── EntropyCalibrator.h
│   ├── ErrorRecorder.h
│   ├── argsParser.h
│   ├── buffers.h
│   ├── common.h
│   ├── dumpTFWts.py
│   ├── getOptions.cpp
│   ├── getOptions.h
│   ├── half.h
│   ├── logger.cpp
│   ├── logger.h
│   ├── logging.h
│   ├── parserOnnxConfig.h
│   ├── sampleConfig.h
│   ├── sampleEngines.cpp
│   ├── sampleEngines.h
│   ├── sampleOptions.cpp
│   ├── sampleOptions.h
│   └── sampleUtils.h
└── data
    └── 1492638000682869180
        ├── 1.jpg
        ├── 10.jpg
        ├── 11.jpg
        ├── 12.jpg
        ├── 13.jpg
        ├── 14.jpg
        ├── 15.jpg
        ├── 16.jpg
        ├── 17.jpg
        ├── 18.jpg
        ├── 19.jpg
        ├── 2.jpg
        ├── 20.jpg
        ├── 3.jpg
        ├── 4.jpg
        ├── 5.jpg
        ├── 6.jpg
        ├── 7.jpg
        ├── 8.jpg
        └── 9.jpg

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
build
.vscode

--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.10)

project(PINetTensorrt)

# add_compile_options("-g")
add_compile_options("-O2")

find_package(OpenCV REQUIRED)

set(TEGRA_LIB_DIR /usr/lib/aarch64-linux-gnu/tegra)
set(CUDA_INSTALL_DIR /usr/local/cuda/)
set(CUDA_INCLUDE_DIR ${CUDA_INSTALL_DIR}/include)
set(CUDA_LIB_DIR ${CUDA_INSTALL_DIR}/lib64)

include_directories(common ${CUDA_INCLUDE_DIR} ${OpenCV_INCLUDE_DIRS})

aux_source_directory(. SRCS)
aux_source_directory(common COMMON_SRCS)

link_directories(${CUDA_LIB_DIR} ${TEGRA_LIB_DIR})
add_executable(${PROJECT_NAME} ${COMMON_SRCS} ${SRCS})

set(CUDA_LIB cuda cudnn cublas cudart culibos)
set(NV_LIB nvinfer nvparsers nvinfer_plugin nvonnxparser)

target_link_libraries(${PROJECT_NAME} ${CUDA_LIB} ${NV_LIB} ${OpenCV_LIBS})

--------------------------------------------------------------------------------
/PINetTensorrt.cpp:
--------------------------------------------------------------------------------
#include "argsParser.h"
#include "buffers.h"
#include "common.h"
#include "logger.h"
#include "parserOnnxConfig.h"

#include "NvInfer.h"
#include <cuda_runtime_api.h>

#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include <chrono>
#include <dirent.h>
#include <unistd.h>

#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>

namespace {
const std::string gSampleName = "TensorRT.onnx_PINet";

const int output_base_index = 3;
const float threshold_point = 0.81f;
const float threshold_instance = 0.22f;
const int resize_ratio = 8;

int64 total_inference_execute_elapsed_time = 0;
int64 total_inference_execute_times = 0;

using LaneLine = std::vector<cv::Point2f>;
using LaneLines = std::vector<LaneLine>;

// Converts planar CHW float data into an interleaved multi-channel cv::Mat,
// zeroing every element whose grid cell is not set in the mask.
cv::Mat chwDataToMat(int channelNum, int height, int width, float* data, cv::Mat& mask) {
    std::vector<cv::Mat> channels(channelNum);
    int data_size = width * height;
    for (int c = 0; c < channelNum; ++c) {
        float* channel_data = data + data_size * c;
        cv::Mat channel(height, width, CV_32FC1);
        for (int h = 0; h < height; ++h) {
            for (int w = 0; w < width; ++w, ++channel_data) {
                channel.at<float>(h, w) = *channel_data * (int)mask.at<uchar>(h, w);
            }
        }
        channels[c] = channel;
    }

    cv::Mat mergedMat;
    cv::merge(channels.data(), channelNum, mergedMat);
    return mergedMat;
}

void getFiles(std::string root_dir, std::string ext, std::vector<std::string>& files) {
    DIR *dir;
    struct dirent *ptr;

    if ((dir = opendir(root_dir.c_str())) == NULL) {
        gLogInfo << "Open dir error..." << std::endl;
        return;
    }

    while ((ptr = readdir(dir)) != NULL) {
        if (strcmp(ptr->d_name, ".") == 0 || strcmp(ptr->d_name, "..") == 0) {
            continue;
        } else if (ptr->d_type == DT_REG) { // regular file
            char* dot = strrchr(ptr->d_name, '.');
            if (dot && !strcasecmp(dot, ext.c_str())) {
                std::string filename(root_dir);
                filename.append("/").append(ptr->d_name);
                files.push_back(filename);
            }
        } else if (ptr->d_type == DT_LNK) { // symbolic link
            continue;
        } else if (ptr->d_type == DT_DIR) { // directory
            std::string dir_path(root_dir);
            dir_path.append("/").append(ptr->d_name);
            getFiles(dir_path.c_str(), ext, files);
        }
    }

    closedir(dir);
}
} // namespace

//! \brief The PINetTensorrt class implements the ONNX PINet sample
//!
//! \details It creates the network using an ONNX model
//!
class PINetTensorrt
{
    template <typename T>
    using UniquePtr = std::unique_ptr<T, common::InferDeleter>;

public:
    PINetTensorrt(const common::OnnxParams& params)
        : mParams(params)
        , mEngine(nullptr)
    {
    }

    //!
    //! \brief Builds the network engine
    //!
    bool build();

    //!
    //! \brief Runs the TensorRT inference engine for this sample
    //!
    bool infer();

    void setImageFile(const std::string& imageFileName) {
        mImageFileName = imageFileName;
    }

private:
    common::OnnxParams mParams;               //!< The parameters for the sample.

    nvinfer1::Dims mInputDims;                //!< The dimensions of the input to the network.
    std::vector<nvinfer1::Dims> mOutputDims;  //!< The dimensions of the outputs of the network.
    std::string mImageFileName;               //!< The image file to run lane detection on
    cv::Mat mInputImage;

    std::shared_ptr<nvinfer1::ICudaEngine> mEngine; //!< The TensorRT engine used to run the network

    //!
    //! \brief Parses an ONNX model of PINet and creates a TensorRT network
    //!
    bool constructNetwork(UniquePtr<nvinfer1::IBuilder>& builder,
        UniquePtr<nvinfer1::INetworkDefinition>& network, UniquePtr<nvinfer1::IBuilderConfig>& config,
        UniquePtr<nvonnxparser::IParser>& parser);

    //!
    //! \brief Reads the input and stores the result in a managed buffer
    //!
    bool processInput(const common::BufferManager& buffers);

    //!
    //! \brief Post-processes the network outputs into lane lines and verifies the result
    //!
    bool verifyOutput(const common::BufferManager& buffers);

    void generatePostData(float* confidence_data, float* offsets_data, float* instance_data, cv::Mat& mask, cv::Mat& offsets, cv::Mat& features);

    LaneLines generateLaneLine(float* confidence_data, float* offsets_data, float* instance_data);
};
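// Typical call sequence (mirrors main() below): build the engine once, then
// call setImageFile()/infer() for each image:
//
//     PINetTensorrt sample(params);
//     sample.build();               // parse the ONNX model, build the engine
//     sample.setImageFile("1.jpg"); // hypothetical image path
//     sample.infer();               // preprocess, execute, post-process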
//!
//! \brief Creates the network, configures the builder and creates the network engine
//!
//! \details This function creates the PINet network by parsing the ONNX model and builds
//!          the engine that will be used to run inference (mEngine)
//!
//! \return Returns true if the engine was created successfully and false otherwise
//!
bool PINetTensorrt::build()
{
    auto builder = UniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(gLogger.getTRTLogger()));
    if (!builder)
    {
        return false;
    }

    auto network = UniquePtr<nvinfer1::INetworkDefinition>(builder->createNetwork());
    if (!network)
    {
        return false;
    }

    auto config = UniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
    if (!config)
    {
        return false;
    }

    auto parser = UniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, gLogger.getTRTLogger()));
    if (!parser)
    {
        return false;
    }

    auto constructed = constructNetwork(builder, network, config, parser);
    if (!constructed)
    {
        return false;
    }

    mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(builder->buildEngineWithConfig(*network, *config), common::InferDeleter());
    if (!mEngine)
    {
        return false;
    }

    if (gLogger.getReportableSeverity() == Logger::Severity::kVERBOSE) {
        for (int i = 0; i < network->getNbInputs(); ++i) {
            nvinfer1::Dims dim = network->getInput(i)->getDimensions();
            gLogInfo << "InputDims: " << i << " " << dim.d[0] << " " << dim.d[1] << " " << dim.d[2] << std::endl;
        }

        for (int i = 0; i < network->getNbOutputs(); ++i) {
            nvinfer1::Dims dim = network->getOutput(i)->getDimensions();
            gLogInfo << "OutputDims: " << i << " " << dim.d[0] << " " << dim.d[1] << " " << dim.d[2] << std::endl;
        }
    }

    assert(network->getNbInputs() == 1);
    mInputDims = network->getInput(0)->getDimensions();
    assert(mInputDims.nbDims == 3);

    assert(network->getNbOutputs() == 6);
    for (int i = 0; i < network->getNbOutputs(); ++i) {
        nvinfer1::Dims dim = network->getOutput(i)->getDimensions();
        mOutputDims.push_back(dim);
        assert(dim.nbDims == 3);
    }

    return true;
}

//!
//! \brief Uses the ONNX parser to create the PINet network and marks the
//!        output layers
//!
//! \param network Pointer to the network that will be populated with the PINet network
//!
//! \param builder Pointer to the engine builder
//!
bool PINetTensorrt::constructNetwork(UniquePtr<nvinfer1::IBuilder>& builder,
    UniquePtr<nvinfer1::INetworkDefinition>& network, UniquePtr<nvinfer1::IBuilderConfig>& config,
    UniquePtr<nvonnxparser::IParser>& parser)
{
    auto parsed = parser->parseFromFile(mParams.onnxFileName.c_str(), static_cast<int>(gLogger.getReportableSeverity()));
    if (!parsed)
    {
        return false;
    }

    builder->setMaxBatchSize(mParams.batchSize);
    config->setMaxWorkspaceSize(1 << 30);
    if (mParams.fp16)
    {
        config->setFlag(BuilderFlag::kFP16);
    }
    if (mParams.int8)
    {
        config->setFlag(BuilderFlag::kINT8);
        common::setAllTensorScales(network.get(), 127.0f, 127.0f);
    }

    common::enableDLA(builder.get(), config.get(), mParams.dlaCore);

    return true;
}
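// Note: buildEngineWithConfig() can take a long time on first run. A common
// extension (a sketch only, not used by this sample; assumes TensorRT 6 APIs)
// is to serialize the built engine once and reload it on later runs:
//
//     nvinfer1::IHostMemory* blob = mEngine->serialize();
//     std::ofstream out("pinet.engine", std::ios::binary);
//     out.write(static_cast<const char*>(blob->data()), blob->size());
//     blob->destroy();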
//!
//! \brief Runs the TensorRT inference engine for this sample
//!
//! \details This function is the main execution function of the sample. It allocates the buffer,
//!          sets inputs and executes the engine.
//!
bool PINetTensorrt::infer()
{
    // Create RAII buffer manager object
    common::BufferManager buffers(mEngine, mParams.batchSize);

    auto context = UniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
    if (!context)
    {
        return false;
    }

    // Read the input data into the managed buffers
    assert(mParams.inputTensorNames.size() == 1);
    if (!processInput(buffers))
    {
        return false;
    }

    auto inferenceBeginTime = std::chrono::high_resolution_clock::now();
    // Memcpy from host input buffers to device input buffers
    buffers.copyInputToDevice();

    bool status = context->execute(mParams.batchSize, buffers.getDeviceBindings().data());
    if (!status)
    {
        return false;
    }

    // Memcpy from device output buffers to host output buffers
    buffers.copyOutputToHost();

    auto inference_execute_elapsed_time = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - inferenceBeginTime);
    total_inference_execute_elapsed_time += inference_execute_elapsed_time.count();
    ++total_inference_execute_times;

    //gLogInfo << "inference elapsed time: " << inference_execute_elapsed_time.count() / 1000.f << " milliseconds" << std::endl;

    // Verify results
    if (!verifyOutput(buffers))
    {
        return false;
    }

    return true;
}

//!
//! \brief Reads the input and stores the result in a managed buffer
//!
bool PINetTensorrt::processInput(const common::BufferManager& buffers)
{
    const int inputC = mInputDims.d[0];
    const int inputH = mInputDims.d[1];
    const int inputW = mInputDims.d[2];

    cv::Mat image = cv::imread(mImageFileName, cv::IMREAD_COLOR);
    assert(inputC == image.channels());
    cv::resize(image, image, cv::Size(inputW, inputH));

    mInputImage = image;

    // Convert the interleaved HWC image to planar CHW floats in [0, 1]
    float* hostDataBuffer = static_cast<float*>(buffers.getHostBuffer(mParams.inputTensorNames[0]));
    uchar* imageData = image.ptr();
    for (int c = 0; c < inputC; ++c) {
        for (unsigned j = 0, volChl = inputH * inputW; j < volChl; ++j) {
            hostDataBuffer[c * volChl + j] = float(imageData[j * inputC + c]) / 255.f;
        }
    }

    return true;
}

void PINetTensorrt::generatePostData(float* confidence_data, float* offsets_data, float* instance_data, cv::Mat& mask, cv::Mat& offsets, cv::Mat& features)
{
    const nvinfer1::Dims& dim = mOutputDims[output_base_index];              // 1 x 32 x 64
    const nvinfer1::Dims& offset_dim = mOutputDims[output_base_index + 1];   // 2 x 32 x 64
    const nvinfer1::Dims& instance_dim = mOutputDims[output_base_index + 2]; // 4 x 32 x 64

    mask = cv::Mat::zeros(dim.d[1], dim.d[2], CV_8UC1);
    float* confidence_ptr = confidence_data;
    for (int i = 0; i < dim.d[1]; ++i) {
        for (int j = 0; j < dim.d[2]; ++j, ++confidence_ptr) {
            if (*confidence_ptr > threshold_point) {
                mask.at<uchar>(i, j) = 1;
            }
        }
    }

    if (gLogger.getReportableSeverity() == Logger::Severity::kVERBOSE) {
        gLogInfo << "Output mask:" << std::endl;
        for (int i = 0; i < dim.d[1]; ++i) {
            for (int j = 0; j < dim.d[2]; ++j) {
                gLogInfo << (int)mask.at<uchar>(i, j);
            }
            gLogInfo << std::endl;
        }

        cv::Mat maskImage = mInputImage.clone();
        cv::Scalar color(0, 0, 255);
        for (int i = 0; i < dim.d[1]; ++i) {
            for (int j = 0; j < dim.d[2]; ++j) {
                if ((int)mask.at<uchar>(i, j)) {
                    cv::circle(maskImage, cv::Point2f(j * resize_ratio, i * resize_ratio), 3, color, -1);
                }
            }
        }
        cv::imshow("mask", maskImage);
        cv::waitKey(0);
    }

    offsets = chwDataToMat(offset_dim.d[0], offset_dim.d[1], offset_dim.d[2], offsets_data, mask);
    features = chwDataToMat(instance_dim.d[0], instance_dim.d[1], instance_dim.d[2], instance_data, mask);

    if (gLogger.getReportableSeverity() == Logger::Severity::kVERBOSE) {
        gLogInfo << "Output offset:" << std::endl;
        for (int i = 0; i < dim.d[1]; ++i) {
            for (int j = 0; j < dim.d[2]; ++j) {
                gLogInfo << (offsets.at<cv::Vec2f>(i, j)[0] ? 1 : 0);
            }
            gLogInfo << std::endl;
        }

        cv::Mat offsetImage = mInputImage.clone();
        cv::Scalar color(0, 0, 255);
        for (int i = 0; i < dim.d[1]; ++i) {
            for (int j = 0; j < dim.d[2]; ++j) {
                if ((int)mask.at<uchar>(i, j)) {
                    cv::Vec2f pointOffset = offsets.at<cv::Vec2f>(i, j);
                    cv::Point2f point(pointOffset[1] + j, pointOffset[0] + i);
                    cv::circle(offsetImage, point * resize_ratio, 3, color, -1);
                }
            }
        }
        cv::imshow("offset", offsetImage);
        cv::waitKey(0);

        gLogInfo << "Output instance:" << std::endl;
        for (int i = 0; i < dim.d[1]; ++i) {
            for (int j = 0; j < dim.d[2]; ++j) {
                gLogInfo << (features.at<cv::Vec4f>(i, j)[0] ? 1 : 0);
            }
            gLogInfo << std::endl;
        }
    }
}
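// Post-processing overview: each grid cell that survives the confidence mask
// carries a 2-D offset (sub-cell point position) and a 4-D instance embedding.
// generateLaneLine() below clusters those cells greedily: a point joins the
// first lane whose running-mean embedding lies within threshold_instance of
// its own embedding (squared L2 distance via delta.dot(delta)); otherwise it
// starts a new lane. Lanes that end up with fewer than two points are dropped.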
LaneLines PINetTensorrt::generateLaneLine(float* confidence_data, float* offsets_data, float* instance_data)
{
    const nvinfer1::Dims& dim = mOutputDims[output_base_index]; // 1 x 32 x 64

    cv::Mat mask, offsets, features;
    generatePostData(confidence_data, offsets_data, instance_data, mask, offsets, features);

    LaneLines laneLines;
    std::vector<cv::Vec4f> laneFeatures;

    // Returns the index of the first lane whose mean embedding is close enough,
    // or -1 if the point should start a new lane.
    auto findNearestFeature = [&laneFeatures](const cv::Vec4f& feature) -> int {
        for (size_t i = 0; i < laneFeatures.size(); ++i) {
            auto delta = laneFeatures[i] - feature;
            if (delta.dot(delta) <= threshold_instance) {
                return static_cast<int>(i);
            }
        }
        return -1;
    };

    for (int i = 0; i < dim.d[1]; ++i) {
        for (int j = 0; j < dim.d[2]; ++j) {
            if ((int)mask.at<uchar>(i, j) == 0) {
                continue;
            }

            const cv::Vec2f& offset = offsets.at<cv::Vec2f>(i, j);
            cv::Point2f point(offset[1] + j, offset[0] + i);
            if (point.x > dim.d[2] || point.x < 0.f) continue;
            if (point.y > dim.d[1] || point.y < 0.f) continue;

            const cv::Vec4f& feature = features.at<cv::Vec4f>(i, j);
            int lane_index = findNearestFeature(feature);

            if (lane_index == -1) {
                laneLines.emplace_back(LaneLine({point}));
                laneFeatures.emplace_back(feature);
            } else {
                auto& laneline = laneLines[lane_index];
                auto& lanefeature = laneFeatures[lane_index];

                auto point_size = laneline.size();

                // Incremental mean of the lane's embedding over its points
                lanefeature = lanefeature.mul(cv::Vec4f::all(static_cast<float>(point_size))) + feature;
                lanefeature = lanefeature.mul(cv::Vec4f::all(1.f / (point_size + 1)));
                laneline.emplace_back(point);
            }
        }
    }

    // Discard degenerate lanes with fewer than two points
    for (auto itr = laneLines.begin(); itr != laneLines.end();) {
        if ((*itr).size() < 2) {
            itr = laneLines.erase(itr);
        } else {
            ++itr;
        }
    }

    return laneLines;
}

//!
//! \brief Post-processes the outputs into lane lines and verifies the result
//!
//! \return whether the output matches expectations
//!
bool PINetTensorrt::verifyOutput(const common::BufferManager& buffers)
{
    float *confidence, *offset, *instance;
    confidence = static_cast<float*>(buffers.getHostBuffer(mParams.outputTensorNames[output_base_index + 0]));
    offset = static_cast<float*>(buffers.getHostBuffer(mParams.outputTensorNames[output_base_index + 1]));
    instance = static_cast<float*>(buffers.getHostBuffer(mParams.outputTensorNames[output_base_index + 2]));

    nvinfer1::Dims confidenceDims = mOutputDims[output_base_index + 0];
    nvinfer1::Dims offsetDims = mOutputDims[output_base_index + 1];
    nvinfer1::Dims instanceDims = mOutputDims[output_base_index + 2];

    assert(confidenceDims.d[0] == 1);
    assert(offsetDims.d[0] == 2);
    assert(instanceDims.d[0] == 4);

    LaneLines lanelines = generateLaneLine(confidence, offset, instance);
    if (lanelines.empty())
        return false;

    const cv::Scalar color[] = {{255, 0, 0}, {  0, 255, 0}, {  0, 0, 255},
                                {255, 255, 0}, {255, 0, 255}, {  0, 255, 255},
                                {255, 255, 255}, {100, 255, 0}, {100, 0, 255},
                                {255, 100, 0}, {  0, 100, 255}, {255, 0, 100},
                                {  0, 255, 100}};
    const int colorNum = sizeof(color) / sizeof(color[0]);

    cv::Mat lanelineImage = mInputImage;
    for (size_t i = 0; i < lanelines.size(); ++i) {
        for (const auto& point : lanelines[i]) {
            cv::circle(lanelineImage, cv::Point2f(point * resize_ratio), 3, color[i % colorNum], -1);
        }
    }

    if (gLogger.getReportableSeverity() == Logger::Severity::kINFO) {
        cv::imwrite("lanelines.jpg", lanelineImage);

        cv::imshow("lanelines", lanelineImage);
        cv::waitKey(0);
    }

    return true;
}
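// The exported ONNX model has six outputs: confidence (1 channel), offset
// (2 channels) and instance embedding (4 channels) for each of two output
// branches (PINet is a stacked-hourglass network, so presumably one set per
// hourglass stage). output_base_index = 3 selects the last three tensors,
// i.e. the final stage's predictions ("1679", "1686", "1693" below).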
//!
//! \brief Initializes members of the params struct using the command line args
//!
common::OnnxParams initializeSampleParams(const common::Args& args)
{
    common::OnnxParams params;
    if (args.dataDirs.empty()) { //!< Use default directories if user hasn't provided directory paths
        params.dataDirs.push_back("./data/1492638000682869180");
    } else { //!< Use the data directory provided by the user
        params.dataDirs = args.dataDirs;
    }

    char pwd[1024] = {0};
    getcwd(pwd, sizeof(pwd));

    params.onnxFileName = "pinet.onnx";
    params.inputTensorNames.push_back("0");
    params.batchSize = 1;
    params.outputTensorNames.push_back("1431");
    params.outputTensorNames.push_back("1438");
    params.outputTensorNames.push_back("1445");
    params.outputTensorNames.push_back("1679");
    params.outputTensorNames.push_back("1686");
    params.outputTensorNames.push_back("1693");
    params.dlaCore = args.useDLACore;
    params.int8 = args.runInInt8;
    params.fp16 = args.runInFp16;

    return params;
}

//!
//! \brief Prints the help information for running this sample
//!
void printHelpInfo()
{
    std::cout << "Usage: ./PINetTensorrt [-h or --help] [-d or --datadir=<path to data directory>] [--useDLACore=<int>]" << std::endl;
    std::cout << "--help          Display help information" << std::endl;
    std::cout << "--datadir       Specify path to a data directory, overriding the default. This option can be used multiple times to add multiple directories. If no data directories are given, the default is ./data/1492638000682869180" << std::endl;
    std::cout << "--useDLACore=N  Specify a DLA engine for layers that support DLA. Value can range from 0 to n-1, where n is the number of DLA engines on the platform." << std::endl;
    std::cout << "--int8          Run in Int8 mode." << std::endl;
    std::cout << "--fp16          Run in FP16 mode." << std::endl;
}

int main(int argc, char** argv)
{
    common::Args args;
    bool argsOK = common::parseArgs(args, argc, argv);
    if (!argsOK)
    {
        gLogError << "Invalid arguments" << std::endl;
        printHelpInfo();
        return EXIT_FAILURE;
    }
    if (args.help)
    {
        printHelpInfo();
        return EXIT_SUCCESS;
    }

    setReportableSeverity(Logger::Severity::kINFO);
    auto test = gLogger.defineTest(gSampleName, argc, argv);

    gLogger.reportTestStart(test);

    common::OnnxParams onnx_args = initializeSampleParams(args);
    PINetTensorrt sample(onnx_args);

    gLogInfo << "Building and running a GPU inference engine for ONNX PINet" << std::endl;

    if (!sample.build())
    {
        return gLogger.reportFail(test);
    }

    std::vector<std::string> filenames;
    filenames.reserve(20480);
    for (size_t i = 0; i < onnx_args.dataDirs.size(); i++) {
        getFiles(onnx_args.dataDirs[i], ".jpg", filenames);
    }

    auto inference_begin_time = std::chrono::high_resolution_clock::now();

    for (const auto& filename : filenames) {
        sample.setImageFile(filename);
        if (!sample.infer()) {
            gLogger.reportFail(test);
        }
    }

    auto inference_elapsed_time = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now() - inference_begin_time);

    gLogger.reportPass(test);

    gLogInfo << std::endl;

    gLogInfo << "total inference time   : " << inference_elapsed_time.count() / 1000.f << " milliseconds" << std::endl;
    if (filenames.size()) {
        gLogInfo << "total inference count  : " << filenames.size() << std::endl;
        gLogInfo << "average inference time : " << inference_elapsed_time.count() / 1000.f / filenames.size() << " milliseconds" << std::endl;
    }

    if (total_inference_execute_times > 0) {
        gLogInfo << "total execute elapsed time   : " << total_inference_execute_elapsed_time / 1000.f << " milliseconds" << std::endl << std::endl;
        gLogInfo << "inference execute count      : " << total_inference_execute_times << std::endl;
        gLogInfo << "average execute elapsed time : " << total_inference_execute_elapsed_time / 1000.f / total_inference_execute_times << " milliseconds" << std::endl << std::endl;
    }

    return 0;
}

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Key points estimation and point instance segmentation approach for lane detection

- Paper: Key Points Estimation and Point Instance Segmentation Approach for Lane Detection
- Paper link: https://arxiv.org/abs/2002.06604
- Authors: Yeongmin Ko, Jiwon Jun, Donghwuy Ko, Moongu Jeon (Gwangju Institute of Science and Technology)

- This repository is a TensorRT implementation of [PINet](https://github.com/koyeongmin/PINet).

## Dependency

- TensorRT 6.0
- OpenCV

## Convert

- Clone the [PINet](https://github.com/koyeongmin/PINet) source code:

```shell
git clone https://github.com/koyeongmin/PINet.git
```

You can convert the PyTorch weights file to an ONNX file as follows:

- Insert this method at the end of the `Agent` class in agent.py:

```python
    def export_onnx(self, input_image, filename):
        torch_out = torch.onnx.export(self.lane_detection_network, input_image, filename, verbose=True)
```

- Run this script to convert the weights file to ONNX (use PyTorch 1.0.1):

```python
import torch
import agent

batch_size = 1
input_shape = (3, 256, 512)
dummy_input = torch.randn(batch_size, *input_shape, device='cuda')
lane_agent = agent.Agent()
lane_agent.load_weights(640, "tensor(0.2298)")
lane_agent.cuda()
lane_agent.evaluate_mode()
lane_agent.export_onnx(dummy_input, "pinet.onnx")
```
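- Optionally, sanity-check the exported model before building the engine. This is a quick sketch (not part of the original workflow) that assumes the `onnx` Python package is installed; the printed output names should match the ones hard-coded in `initializeSampleParams`:

```python
import onnx

model = onnx.load("pinet.onnx")
onnx.checker.check_model(model)              # raises if the graph is malformed
print([o.name for o in model.graph.output])  # expect six outputs, e.g. "1431" ... "1693"
```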
## Run

- Run this program with an image directory:

```shell
./PINetTensorrt --datadir=<image directory>
```

- Or run this program with the default images:

```shell
./PINetTensorrt
```

## Test

### Objects

- PyTorch implementation of PINet
- TensorRT C++ implementation of PINet

### Purpose

- TensorRT performance on the x86 architecture
- TensorRT performance on Xavier

### Dataset

- data source: TuSimple dataset, 0531 directory
- image format: jpg
- image size: 1280 x 720
- image channels: RGB
- image count: 14300
- disk space: 3 GB

---
### X86 Computer

- OS: Ubuntu 18.04
- CPU: AMD Ryzen 7 3700X 8-Core Processor
- CPU frequency: 3600 MHz
- RAM: 32 GB, 3200 MHz
- video card: Nvidia Titan
- VRAM: 6 GB
- disk: Seagate BarraCuda, 7200 rpm

### Xavier

- OS: Ubuntu 18.04
- CPU: ARMv8 Processor rev 0 (v8l)
- CPU frequency: 2036 MHz
- RAM: 16 GB

### Test

##### Explain

- end to end: elapsed time to read an image, run inference, post-process, and draw the lane line result onto the image
- execute: elapsed time to copy host RAM to device VRAM, execute inference, and copy device VRAM back to host RAM
- total end to end: sum of the per-image end-to-end times over the whole dataset
- total execute: sum of the per-image execute times over the whole dataset

```
end to end = total end to end / image count
execute = total execute / image count
```
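As a worked example of these formulas: in run 1 of the x86 TensorRT table below, a total end-to-end time of 335.970 s over the 14300-image dataset gives 335.970 s / 14300 ≈ 23.5 ms of end-to-end time per image.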
##### X86 && PyTorch Implementation

| NO.     | total end to end | end to end (ms) | total execute (s) | execute (ms) |
| ------- | ---------------- | --------------- | ----------------- | ------------ |
| 1       | 39m54.79s        | 167.46          | 229.92            | 16.07        |
| 2       | 38m28.37s        | 161.42          | 222.29            | 15.54        |
| 3       | 38m04.18s        | 159.73          | 224.73            | 15.71        |
| 4       | 37m40.54s        | 158.08          | 218.91            | 15.30        |
| 5       | 38m05.84s        | 159.84          | 223.84            | 15.65        |
| average | 38m26.74s        | 161.31          | 223.94            | 15.65        |

##### X86 && TensorRT C++ Implementation

| NO.     | total end to end (s) | end to end (ms) | total execute (s) | execute (ms) |
| ------- | -------------------- | --------------- | ----------------- | ------------ |
| 1       | 335.970              | 23.494          | 152.362           | 10.654       |
| 2       | 346.983              | 24.264          | 154.873           | 10.83        |
| 3       | 338.014              | 23.637          | 153.296           | 10.72        |
| 4       | 342.812              | 23.972          | 154.606           | 10.811       |
| 5       | 343.489              | 24.02           | 154.693           | 10.817       |
| average | 341.45               | 23.88           | 153.966           | 10.77        |

##### Xavier && TensorRT C++ Implementation

| NO.     | total end to end (s) | end to end (ms) | total execute (s) | execute (ms) |
| ------- | -------------------- | --------------- | ----------------- | ------------ |
| 1       | 709.816              | 49.637          | 289.398           | 20.237       |
| 2       | 652.201              | 45.608          | 287.493           | 20.104       |
| 3       | 651.780              | 45.578          | 290.308           | 20.301       |
| 4       | 650.099              | 45.461          | 287.789           | 20.125       |
| 5       | 657.376              | 45.97           | 287.569           | 20.109       |
| average | 664.254              | 46.45           | 288.51            | 20.175       |

##### Result

- Inference execute time on x86: the TensorRT C++ implementation is about 1.5x faster than the PyTorch implementation (10.77 ms vs 15.65 ms on average).
- End-to-end time on x86: the TensorRT C++ implementation is about 6.8x faster than the PyTorch implementation (23.88 ms vs 161.31 ms on average).
- Inference execute time on Xavier averages about 20 ms, roughly 2x slower than on the x86 machine.

--------------------------------------------------------------------------------
/common/BatchStream.h:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef BATCH_STREAM_H
#define BATCH_STREAM_H

#include "NvInfer.h"
#include "common.h"
#include <algorithm>
#include <fstream>
#include <stdio.h>
#include <vector>

class IBatchStream
{
public:
    virtual void reset(int firstBatch) = 0;
    virtual bool next() = 0;
    virtual void skip(int skipCount) = 0;
    virtual float* getBatch() = 0;
    virtual float* getLabels() = 0;
    virtual int getBatchesRead() const = 0;
    virtual int getBatchSize() const = 0;
    virtual nvinfer1::Dims getDims() const = 0;
    virtual nvinfer1::Dims getImageDims() const = 0;
};

class MNISTBatchStream : public IBatchStream
{
public:
    MNISTBatchStream(int batchSize, int maxBatches, const std::string& dataFile, const std::string& labelsFile,
        const std::vector<std::string>& directories)
        : mBatchSize{batchSize}
        , mMaxBatches{maxBatches}
        , mDims{3, 1, 28, 28} //!< We already know the dimensions of MNIST images.
    {
        readDataFile(locateFile(dataFile, directories));
        readLabelsFile(locateFile(labelsFile, directories));
    }

    void reset(int firstBatch) override
    {
        mBatchCount = firstBatch;
    }

    bool next() override
    {
        if (mBatchCount >= mMaxBatches)
        {
            return false;
        }
        ++mBatchCount;
        return true;
    }

    void skip(int skipCount) override
    {
        mBatchCount += skipCount;
    }

    float* getBatch() override
    {
        return mData.data() + (mBatchCount * mBatchSize * common::volume(mDims));
    }

    float* getLabels() override
    {
        return mLabels.data() + (mBatchCount * mBatchSize);
    }

    int getBatchesRead() const override
    {
        return mBatchCount;
    }

    int getBatchSize() const override
    {
        return mBatchSize;
    }

    nvinfer1::Dims getDims() const override
    {
        return mDims;
    }

    nvinfer1::Dims getImageDims() const override
    {
        return Dims3{1, 28, 28};
    }

private:
    void readDataFile(const std::string& dataFilePath)
    {
        std::ifstream file{dataFilePath.c_str(), std::ios::binary};

        int magicNumber, numImages, imageH, imageW;
        file.read(reinterpret_cast<char*>(&magicNumber), sizeof(magicNumber));
        // All values in the MNIST files are big endian.
        magicNumber = common::swapEndianness(magicNumber);
        assert(magicNumber == 2051 && "Magic Number does not match the expected value for an MNIST image set");

        // Read number of images and dimensions
        file.read(reinterpret_cast<char*>(&numImages), sizeof(numImages));
        file.read(reinterpret_cast<char*>(&imageH), sizeof(imageH));
        file.read(reinterpret_cast<char*>(&imageW), sizeof(imageW));

        numImages = common::swapEndianness(numImages);
        imageH = common::swapEndianness(imageH);
        imageW = common::swapEndianness(imageW);

        // The MNIST data is made up of unsigned bytes, so we need to cast to float and normalize.
        int numElements = numImages * imageH * imageW;
        std::vector<uint8_t> rawData(numElements);
        file.read(reinterpret_cast<char*>(rawData.data()), numElements * sizeof(uint8_t));
        mData.resize(numElements);
        std::transform(
            rawData.begin(), rawData.end(), mData.begin(), [](uint8_t val) { return static_cast<float>(val) / 255.f; });
    }

    void readLabelsFile(const std::string& labelsFilePath)
    {
        std::ifstream file{labelsFilePath.c_str(), std::ios::binary};
        int magicNumber, numImages;
        file.read(reinterpret_cast<char*>(&magicNumber), sizeof(magicNumber));
        // All values in the MNIST files are big endian.
        magicNumber = common::swapEndianness(magicNumber);
        assert(magicNumber == 2049 && "Magic Number does not match the expected value for an MNIST labels file");

        file.read(reinterpret_cast<char*>(&numImages), sizeof(numImages));
        numImages = common::swapEndianness(numImages);

        std::vector<uint8_t> rawLabels(numImages);
        file.read(reinterpret_cast<char*>(rawLabels.data()), numImages * sizeof(uint8_t));
        mLabels.resize(numImages);
        std::transform(
            rawLabels.begin(), rawLabels.end(), mLabels.begin(), [](uint8_t val) { return static_cast<float>(val); });
    }

    int mBatchSize{0};
    int mBatchCount{0}; //!< The batch that will be read on the next invocation of next()
    int mMaxBatches{0};
    Dims mDims{};
    std::vector<float> mData{};
    std::vector<float> mLabels{};
};

class BatchStream : public IBatchStream
{
public:
    BatchStream(
        int batchSize, int maxBatches, std::string prefix, std::string suffix, std::vector<std::string> directories)
        : mBatchSize(batchSize)
        , mMaxBatches(maxBatches)
        , mPrefix(prefix)
        , mSuffix(suffix)
        , mDataDir(directories)
    {
        FILE* file = fopen(locateFile(mPrefix + std::string("0") + mSuffix, mDataDir).c_str(), "rb");
        assert(file != nullptr);
        int d[4];
        size_t readSize = fread(d, sizeof(int), 4, file);
        assert(readSize == 4);
        mDims.nbDims = 4;  // The number of dimensions.
        mDims.d[0] = d[0]; // Batch Size
        mDims.d[1] = d[1]; // Channels
        mDims.d[2] = d[2]; // Height
        mDims.d[3] = d[3]; // Width
        assert(mDims.d[0] > 0 && mDims.d[1] > 0 && mDims.d[2] > 0 && mDims.d[3] > 0);
        fclose(file);

        mImageSize = mDims.d[1] * mDims.d[2] * mDims.d[3];
        mBatch.resize(mBatchSize * mImageSize, 0);
        mLabels.resize(mBatchSize, 0);
        mFileBatch.resize(mDims.d[0] * mImageSize, 0);
        mFileLabels.resize(mDims.d[0], 0);
        reset(0);
    }

    BatchStream(int batchSize, int maxBatches, std::string prefix, std::vector<std::string> directories)
        : BatchStream(batchSize, maxBatches, prefix, ".batch", directories)
    {
    }

    // This constructor expects that the dimensions include the batch dimension.
    BatchStream(int maxBatches, nvinfer1::Dims dims, std::string listFile, std::vector<std::string> directories)
        : mBatchSize(dims.d[0])
        , mMaxBatches(maxBatches)
        , mDims(dims)
        , mListFile(listFile)
        , mDataDir(directories)
    {
        mImageSize = mDims.d[1] * mDims.d[2] * mDims.d[3];
        mBatch.resize(mBatchSize * mImageSize, 0);
        mLabels.resize(mBatchSize, 0);
        mFileBatch.resize(mDims.d[0] * mImageSize, 0);
        mFileLabels.resize(mDims.d[0], 0);
        reset(0);
    }

    // Resets data members
    void reset(int firstBatch) override
    {
        mBatchCount = 0;
        mFileCount = 0;
        mFileBatchPos = mDims.d[0];
        skip(firstBatch);
    }

    // Advance to next batch and return true, or return false if there is no batch left.
    bool next() override
    {
        if (mBatchCount == mMaxBatches)
        {
            return false;
        }

        for (int csize = 1, batchPos = 0; batchPos < mBatchSize; batchPos += csize, mFileBatchPos += csize)
        {
            assert(mFileBatchPos > 0 && mFileBatchPos <= mDims.d[0]);
            if (mFileBatchPos == mDims.d[0] && !update())
            {
                return false;
            }

            // Copy the smaller of: elements left to fulfill the request, or elements left in the file buffer.
            csize = std::min(mBatchSize - batchPos, mDims.d[0] - mFileBatchPos);
            std::copy_n(
                getFileBatch() + mFileBatchPos * mImageSize, csize * mImageSize, getBatch() + batchPos * mImageSize);
            std::copy_n(getFileLabels() + mFileBatchPos, csize, getLabels() + batchPos);
        }
        mBatchCount++;
        return true;
    }

    // Skips the batches
    void skip(int skipCount) override
    {
        if (mBatchSize >= mDims.d[0] && mBatchSize % mDims.d[0] == 0 && mFileBatchPos == mDims.d[0])
        {
            mFileCount += skipCount * mBatchSize / mDims.d[0];
            return;
        }

        int x = mBatchCount;
        for (int i = 0; i < skipCount; i++)
        {
            next();
        }
        mBatchCount = x;
    }

    float* getBatch() override
    {
        return mBatch.data();
    }

    float* getLabels() override
    {
        return mLabels.data();
    }

    int getBatchesRead() const override
    {
        return mBatchCount;
    }

    int getBatchSize() const override
    {
        return mBatchSize;
    }

    nvinfer1::Dims getDims() const override
    {
        return mDims;
    }

    nvinfer1::Dims getImageDims() const override
    {
        return Dims3{mDims.d[1], mDims.d[2], mDims.d[3]};
    }

private:
    float* getFileBatch()
    {
        return mFileBatch.data();
    }

    float* getFileLabels()
    {
        return mFileLabels.data();
    }

    bool update()
    {
        if (mListFile.empty())
        {
            std::string inputFileName = locateFile(mPrefix + std::to_string(mFileCount++) + mSuffix, mDataDir);
            FILE* file = fopen(inputFileName.c_str(), "rb");
            if (!file)
            {
                return false;
            }

            int d[4];
            size_t readSize = fread(d, sizeof(int), 4, file);
            assert(readSize == 4);
            assert(mDims.d[0] == d[0] && mDims.d[1] == d[1] && mDims.d[2] == d[2] && mDims.d[3] == d[3]);
            size_t readInputCount = fread(getFileBatch(), sizeof(float), mDims.d[0] * mImageSize, file);
            assert(readInputCount == size_t(mDims.d[0] * mImageSize));
            size_t readLabelCount = fread(getFileLabels(), sizeof(float), mDims.d[0], file);
            assert(readLabelCount == 0 || readLabelCount == size_t(mDims.d[0]));

            fclose(file);
        }
        else
        {
            std::vector<std::string> fNames;
            std::ifstream file(locateFile(mListFile, mDataDir), std::ios::binary);
            if (!file)
            {
                return false;
            }

            gLogInfo << "Batch #" << mFileCount << std::endl;
            file.seekg(((mBatchCount * mBatchSize)) * 7); // each record in the list file is 7 bytes

            for (int i = 1; i <= mBatchSize; i++)
            {
                std::string sName;
                std::getline(file, sName);
                sName = sName + ".ppm";
                gLogInfo << "Calibrating with file " << sName << std::endl;
                fNames.emplace_back(sName);
            }

            mFileCount++;

            const int imageC = 3;
            const int imageH = 300;
            const int imageW = 300;
            std::vector<common::PPM<imageC, imageH, imageW>> ppms(fNames.size());
            for (uint32_t i = 0; i < fNames.size(); ++i)
            {
                readPPMFile(locateFile(fNames[i], mDataDir), ppms[i]);
            }

            std::vector<float> data(common::volume(mDims));
            const float scale = 2.0 / 255.0;
            const float bias = 1.0;
            long int volChl = mDims.d[2] * mDims.d[3];

            // Normalize input data
            for (int i = 0, volImg = mDims.d[1] * mDims.d[2] * mDims.d[3]; i < mBatchSize; ++i)
            {
                for (int c = 0; c < mDims.d[1]; ++c)
                {
                    for (int j = 0; j < volChl; ++j)
                    {
                        data[i * volImg + c * volChl + j] = scale * float(ppms[i].buffer[j * mDims.d[1] + c]) - bias;
                    }
                }
            }

            std::copy_n(data.data(), mDims.d[0] * mImageSize, getFileBatch());
        }

        mFileBatchPos = 0;
        return true;
    }

    int mBatchSize{0};
    int mMaxBatches{0};
    int mBatchCount{0};
    int mFileCount{0};
    int mFileBatchPos{0};
    int mImageSize{0};
    std::vector<float> mBatch;         //!< Data for the batch
    std::vector<float> mLabels;        //!< Labels for the batch
    std::vector<float> mFileBatch;     //!< Data read from the current batch file
    std::vector<float> mFileLabels;    //!< Labels read from the current batch file
    std::string mPrefix;               //!< Batch file name prefix
    std::string mSuffix;               //!< Batch file name suffix
    nvinfer1::Dims mDims;              //!< Input dimensions
    std::string mListFile;             //!< File name of the list of image names
    std::vector<std::string> mDataDir; //!< Directories where the files can be found
};

#endif

--------------------------------------------------------------------------------
/common/EntropyCalibrator.h:
--------------------------------------------------------------------------------
/*
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef ENTROPY_CALIBRATOR_H
#define ENTROPY_CALIBRATOR_H

#include "BatchStream.h"
#include "NvInfer.h"

//! \class EntropyCalibratorImpl
//!
//! \brief Implements common functionality for Entropy calibrators.
//!
template <class TBatchStream>
class EntropyCalibratorImpl
{
public:
    EntropyCalibratorImpl(
        TBatchStream stream, int firstBatch, std::string networkName, const char* inputBlobName, bool readCache = true)
        : mStream{stream}
        , mCalibrationTableName("CalibrationTable" + networkName)
        , mInputBlobName(inputBlobName)
        , mReadCache(readCache)
    {
        nvinfer1::Dims imageDims = mStream.getImageDims();
        mInputCount = common::volume(imageDims) * mStream.getBatchSize();
        CHECK(cudaMalloc(&mDeviceInput, mInputCount * sizeof(float)));
        mStream.reset(firstBatch);
    }

    virtual ~EntropyCalibratorImpl()
    {
        CHECK(cudaFree(mDeviceInput));
    }

    int getBatchSize() const
    {
        return mStream.getBatchSize();
    }

    bool getBatch(void* bindings[], const char* names[], int nbBindings)
    {
        if (!mStream.next())
        {
            return false;
        }
        CHECK(cudaMemcpy(mDeviceInput, mStream.getBatch(), mInputCount * sizeof(float), cudaMemcpyHostToDevice));
        assert(!strcmp(names[0], mInputBlobName));
        bindings[0] = mDeviceInput;
        return true;
    }

    const void* readCalibrationCache(size_t& length)
    {
        mCalibrationCache.clear();
        std::ifstream input(mCalibrationTableName, std::ios::binary);
        input >> std::noskipws;
        if (mReadCache && input.good())
        {
            std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
                std::back_inserter(mCalibrationCache));
        }
        length = mCalibrationCache.size();
        return length ? mCalibrationCache.data() : nullptr;
    }

    void writeCalibrationCache(const void* cache, size_t length)
    {
        std::ofstream output(mCalibrationTableName, std::ios::binary);
        output.write(reinterpret_cast<const char*>(cache), length);
    }

private:
    TBatchStream mStream;
    size_t mInputCount;
    std::string mCalibrationTableName;
    const char* mInputBlobName;
    bool mReadCache{true};
    void* mDeviceInput{nullptr};
    std::vector<char> mCalibrationCache;
};

//! \class Int8EntropyCalibrator2
//!
//! \brief Implements Entropy calibrator 2.
//!        CalibrationAlgoType is kENTROPY_CALIBRATION_2.
//!
template <class TBatchStream>
class Int8EntropyCalibrator2 : public IInt8EntropyCalibrator2
{
public:
    Int8EntropyCalibrator2(
        TBatchStream stream, int firstBatch, const char* networkName, const char* inputBlobName, bool readCache = true)
        : mImpl(stream, firstBatch, networkName, inputBlobName, readCache)
    {
    }

    int getBatchSize() const override
    {
        return mImpl.getBatchSize();
    }

    bool getBatch(void* bindings[], const char* names[], int nbBindings) override
    {
        return mImpl.getBatch(bindings, names, nbBindings);
    }

    const void* readCalibrationCache(size_t& length) override
    {
        return mImpl.readCalibrationCache(length);
    }

    void writeCalibrationCache(const void* cache, size_t length) override
    {
        mImpl.writeCalibrationCache(cache, length);
    }

private:
    EntropyCalibratorImpl<TBatchStream> mImpl;
};
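// Usage sketch (not wired up in this sample; assuming the TensorRT 6 builder
// API): an INT8 build would pass a calibrator to the builder config, e.g.
//
//     MNISTBatchStream stream(batchSize, maxBatches, dataFile, labelsFile, dirs);
//     Int8EntropyCalibrator2<MNISTBatchStream> calibrator(stream, 0, "PINet", inputBlobName);
//     config->setFlag(nvinfer1::BuilderFlag::kINT8);
//     config->setInt8Calibrator(&calibrator);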
#endif // ENTROPY_CALIBRATOR_H

--------------------------------------------------------------------------------
/common/ErrorRecorder.h:
--------------------------------------------------------------------------------
/*
 * Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
 *
 * NOTICE TO LICENSEE:
 *
 * This source code and/or documentation ("Licensed Deliverables") are
 * subject to NVIDIA intellectual property rights under U.S. and
 * international Copyright laws.
 *
 * These Licensed Deliverables contained herein is PROPRIETARY and
 * CONFIDENTIAL to NVIDIA and is being provided under the terms and
 * conditions of a form of NVIDIA software license agreement by and
 * between NVIDIA and Licensee ("License Agreement") or electronically
 * accepted by Licensee. Notwithstanding any terms or conditions to
 * the contrary in the License Agreement, reproduction or disclosure
 * of the Licensed Deliverables to any third party without the express
 * written consent of NVIDIA is prohibited.
 *
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
 * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
 * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
 * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
 * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
 * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
 * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
 * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
 * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
 * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
 * OF THESE LICENSED DELIVERABLES.
 *
 * U.S. Government End Users. These Licensed Deliverables are a
 * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
 * 1995), consisting of "commercial computer software" and "commercial
 * computer software documentation" as such terms are used in 48
 * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
 * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
 * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
 * U.S. Government End Users acquire the Licensed Deliverables with
 * only those rights set forth herein.
 *
 * Any use of the Licensed Deliverables in individual and commercial
 * software must include, in the user documentation and internal
 * comments to the code, the above Disclaimer and U.S. Government End
 * Users Notice.
 */
#ifndef ERROR_RECORDER_H
#define ERROR_RECORDER_H
#include <atomic>
#include <cstdint>
#include <exception>
#include <mutex>
#include <vector>
#include "NvInferRuntimeCommon.h"
using namespace nvinfer1;
//!
//! A simple implementation of the IErrorRecorder interface for
//! use by samples. This interface also can be used as a reference
//! implementation.
//! The sample Error recorder is based on a vector that pairs the error
//! code and the error string into a single element. It also uses
//! standard mutexes and atomics in order to make sure that the code
//! works in a multi-threaded environment.
//! SampleErrorRecorder is not intended for use in automotive safety
//! environments.
//!
class SampleErrorRecorder : public IErrorRecorder
{
    using errorPair = std::pair<ErrorCode, std::string>;
    using errorStack = std::vector<errorPair>;

public:
    SampleErrorRecorder() = default;

    virtual ~SampleErrorRecorder() noexcept {}
    int32_t getNbErrors() const noexcept final
    {
        return mErrorStack.size();
    }
    ErrorCode getErrorCode(int32_t errorIdx) const noexcept final
    {
        return indexCheck(errorIdx) ? ErrorCode::kINVALID_ARGUMENT : (*this)[errorIdx].first;
    };
    IErrorRecorder::ErrorDesc getErrorDesc(int32_t errorIdx) const noexcept final
    {
        return indexCheck(errorIdx) ? "errorIdx out of range." : (*this)[errorIdx].second.c_str();
    }
    // This class can never overflow since we have dynamic resize via std::vector usage.
    bool hasOverflowed() const noexcept final
    {
        return false;
    }

    // Empty the errorStack.
    void clear() noexcept final
    {
        try
        {
            // grab a lock so that there is no addition while clearing.
            std::lock_guard<std::mutex> guard(mStackLock);
            mErrorStack.clear();
        }
        catch (const std::exception& e)
        {
            getLogger()->log(ILogger::Severity::kINTERNAL_ERROR, e.what());
        }
    };

    //! Simple helper function that reports whether the error stack is empty.
    bool empty() const noexcept
    {
        return mErrorStack.empty();
    }

    bool reportError(ErrorCode val, IErrorRecorder::ErrorDesc desc) noexcept final
    {
        try
        {
            std::lock_guard<std::mutex> guard(mStackLock);
            mErrorStack.push_back(errorPair(val, desc));
        }
        catch (const std::exception& e)
        {
            getLogger()->log(ILogger::Severity::kINTERNAL_ERROR, e.what());
        }
        // All errors are considered fatal.
        return true;
    }

    // Atomically increment or decrement the ref counter.
    IErrorRecorder::RefCount incRefCount() noexcept final
    {
        return ++mRefCount;
    }
    IErrorRecorder::RefCount decRefCount() noexcept final
    {
        return --mRefCount;
    }

private:
    // Simple helper functions.
    const errorPair& operator[](size_t index) const noexcept
    {
        return mErrorStack[index];
    }

    bool indexCheck(int32_t index) const noexcept
    {
        // By converting signed to unsigned, we only need a single check since
        // negative numbers turn into large positive greater than the size.
        size_t sIndex = index;
        return sIndex >= mErrorStack.size();
    }
    // Mutex to hold when locking mErrorStack.
    std::mutex mStackLock;

    // Reference count of the class. Destruction of the class when mRefCount
    // is not zero causes undefined behavior.
    std::atomic<int32_t> mRefCount{0};

    // The error stack that holds the errors recorded by TensorRT.
    errorStack mErrorStack;
}; // class SampleErrorRecorder
#endif // ERROR_RECORDER_H

--------------------------------------------------------------------------------
/common/argsParser.h:
--------------------------------------------------------------------------------
/*
 * Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
 *
 * NOTICE TO LICENSEE:
 *
 * This source code and/or documentation ("Licensed Deliverables") are
 * subject to NVIDIA intellectual property rights under U.S. and
 * international Copyright laws.
 *
 * These Licensed Deliverables contained herein is PROPRIETARY and
 * CONFIDENTIAL to NVIDIA and is being provided under the terms and
 * conditions of a form of NVIDIA software license agreement by and
 * between NVIDIA and Licensee ("License Agreement") or electronically
 * accepted by Licensee. Notwithstanding any terms or conditions to
 * the contrary in the License Agreement, reproduction or disclosure
 * of the Licensed Deliverables to any third party without the express
 * written consent of NVIDIA is prohibited.
 *
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
 * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
 * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
 * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
 * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
 * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
 * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
 * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
 * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
 * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
 * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
 * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
 * OF THESE LICENSED DELIVERABLES.
 *
 * U.S. Government End Users. These Licensed Deliverables are a
 * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
 * 1995), consisting of "commercial computer software" and "commercial
 * computer software documentation" as such terms are used in 48
 * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
 * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
 * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
 * U.S. Government End Users acquire the Licensed Deliverables with
 * only those rights set forth herein.
 *
 * Any use of the Licensed Deliverables in individual and commercial
 * software must include, in the user documentation and internal
 * comments to the code, the above Disclaimer and U.S. Government End
 * Users Notice.
 */

#ifndef TENSORRT_ARGS_PARSER_H
#define TENSORRT_ARGS_PARSER_H

#include <string>
#include <vector>
#ifdef _MSC_VER
#include "..\common\windows\getopt.h"
#else
#include <getopt.h>
#endif
#include <iostream>

namespace common
{

//!
//! \brief The SampleParams structure groups the basic parameters required by
//!        all sample networks.
//!
struct Params
{
    int batchSize{1};                  //!< Number of inputs in a batch
    int dlaCore{-1};                   //!< Specify the DLA core to run network on.
    bool int8{false};                  //!< Allow running the network in Int8 mode.
    bool fp16{false};                  //!< Allow running the network in FP16 mode.
    std::vector<std::string> dataDirs; //!< Directory paths where sample data files are stored
    std::vector<std::string> inputTensorNames;
    std::vector<std::string> outputTensorNames;
};

//!
//! \brief The CaffeParams structure groups the additional parameters required by
//!        networks that use caffe
//!
struct CaffeParams : public Params
{
    std::string prototxtFileName; //!< Filename of prototxt design file of a network
    std::string weightsFileName;  //!< Filename of trained weights file of a network
    std::string meanFileName;     //!< Filename of mean file of a network
};

//!
//! \brief The OnnxParams structure groups the additional parameters required by
//!        networks that use ONNX
//!
struct OnnxParams : public Params
{
    std::string onnxFileName; //!< Filename of ONNX file of a network
};

//!
//! \brief The UffSampleParams structure groups the additional parameters required by
//!        networks that use Uff
//!
struct UffParams : public Params
{
    std::string uffFileName; //!< Filename of uff file of a network
};

//!
//! \brief Struct to maintain command-line arguments.
//!
struct Args
{
    bool runInInt8{false};
    bool runInFp16{false};
    bool help{false};
    int useDLACore{-1};
    std::vector<std::string> dataDirs;
};

//!
//! \brief Populates the Args struct with the provided command-line parameters.
//!
//! \throw invalid_argument if any of the arguments are not valid
//!
//! \return boolean If return value is true, execution can continue, otherwise program should exit
//!
inline bool parseArgs(Args& args, int argc, char* argv[])
{
    while (1)
    {
        int arg;
        static struct option long_options[] = {
            {"help", no_argument, 0, 'h'},
            {"datadir", required_argument, 0, 'd'},
            {"int8", no_argument, 0, 'i'},
            {"fp16", no_argument, 0, 'f'},
            {"useDLACore", required_argument, 0, 'u'},
            {nullptr, 0, nullptr, 0}};
        int option_index = 0;
        arg = getopt_long(argc, argv, "hd:iu", long_options, &option_index);
        if (arg == -1)
        {
            break;
        }

        switch (arg)
        {
        case 'h':
            args.help = true;
            return true;
        case 'd':
            if (optarg)
            {
                args.dataDirs.push_back(optarg);
            }
            else
            {
                std::cerr << "ERROR: --datadir requires option argument" << std::endl;
                return false;
            }
            break;
        case 'i':
            args.runInInt8 = true;
            break;
        case 'f':
            args.runInFp16 = true;
            break;
        case 'u':
            if (optarg)
            {
                args.useDLACore = std::stoi(optarg);
            }
            break;
        default:
            return false;
        }
    }
    return true;
}

} // namespace common

#endif // TENSORRT_ARGS_PARSER_H

--------------------------------------------------------------------------------
/common/buffers.h:
--------------------------------------------------------------------------------
/*
 * Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
 *
 * NOTICE TO LICENSEE:
 *
 * This source code and/or documentation ("Licensed Deliverables") are
 * subject to NVIDIA intellectual property rights under U.S. and
 * international Copyright laws.
 *
 * These Licensed Deliverables contained herein is PROPRIETARY and
 * CONFIDENTIAL to NVIDIA and is being provided under the terms and
 * conditions of a form of NVIDIA software license agreement by and
 * between NVIDIA and Licensee ("License Agreement") or electronically
 * accepted by Licensee. Notwithstanding any terms or conditions to
Notwithstanding any terms or conditions to 15 | * the contrary in the License Agreement, reproduction or disclosure 16 | * of the Licensed Deliverables to any third party without the express 17 | * written consent of NVIDIA is prohibited. 18 | * 19 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 20 | * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE 21 | * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS 22 | * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. 23 | * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED 24 | * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, 25 | * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. 26 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 27 | * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY 28 | * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY 29 | * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 30 | * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS 31 | * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE 32 | * OF THESE LICENSED DELIVERABLES. 33 | * 34 | * U.S. Government End Users. These Licensed Deliverables are a 35 | * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT 36 | * 1995), consisting of "commercial computer software" and "commercial 37 | * computer software documentation" as such terms are used in 48 38 | * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government 39 | * only as a commercial end item. Consistent with 48 C.F.R.12.212 and 40 | * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all 41 | * U.S. Government End Users acquire the Licensed Deliverables with 42 | * only those rights set forth herein. 43 | * 44 | * Any use of the Licensed Deliverables in individual and commercial 45 | * software must include, in the user documentation and internal 46 | * comments to the code, the above Disclaimer and U.S. Government End 47 | * Users Notice. 48 | */ 49 | 50 | #ifndef TENSORRT_BUFFERS_H 51 | #define TENSORRT_BUFFERS_H 52 | 53 | #include "NvInfer.h" 54 | #include "half.h" 55 | #include "common.h" 56 | #include 57 | #include 58 | #include 59 | #include 60 | #include 61 | #include 62 | #include 63 | #include 64 | #include 65 | 66 | using namespace std; 67 | 68 | namespace common 69 | { 70 | 71 | //! 72 | //! \brief The GenericBuffer class is a templated class for buffers. 73 | //! 74 | //! \details This templated RAII (Resource Acquisition Is Initialization) class handles the allocation, 75 | //! deallocation, querying of buffers on both the device and the host. 76 | //! It can handle data of arbitrary types because it stores byte buffers. 77 | //! The template parameters AllocFunc and FreeFunc are used for the 78 | //! allocation and deallocation of the buffer. 79 | //! AllocFunc must be a functor that takes in (void** ptr, size_t size) 80 | //! and returns bool. ptr is a pointer to where the allocated buffer address should be stored. 81 | //! size is the amount of memory in bytes to allocate. 82 | //! The boolean indicates whether or not the memory allocation was successful. 83 | //! FreeFunc must be a functor that takes in (void* ptr) and returns void. 84 | //! ptr is the allocated buffer address. It must work with nullptr input. 85 | //! 86 | template 87 | class GenericBuffer 88 | { 89 | public: 90 | //! 91 | //! \brief Construct an empty buffer. 92 | //! 
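//!
//! The AllocFunc/FreeFunc contract described above can also be satisfied with
//! CUDA pinned (page-locked) host memory, which speeds up cudaMemcpyAsync.
//! A sketch with hypothetical functors, not part of this header:
//! \code
//!     class PinnedHostAllocator
//!     {
//!     public:
//!         bool operator()(void** ptr, size_t size) const { return cudaMallocHost(ptr, size) == cudaSuccess; }
//!     };
//!     class PinnedHostFree
//!     {
//!     public:
//!         void operator()(void* ptr) const { cudaFreeHost(ptr); }
//!     };
//!     using PinnedHostBuffer = GenericBuffer<PinnedHostAllocator, PinnedHostFree>;
//! \endcode
//!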
93 | GenericBuffer(nvinfer1::DataType type = nvinfer1::DataType::kFLOAT) 94 | : mSize(0) 95 | , mCapacity(0) 96 | , mType(type) 97 | , mBuffer(nullptr) 98 | { 99 | } 100 | 101 | //! 102 | //! \brief Construct a buffer with the specified allocation size in bytes. 103 | //! 104 | GenericBuffer(size_t size, nvinfer1::DataType type) 105 | : mSize(size) 106 | , mCapacity(size) 107 | , mType(type) 108 | { 109 | if (!allocFn(&mBuffer, this->nbBytes())) 110 | { 111 | throw std::bad_alloc(); 112 | } 113 | } 114 | 115 | GenericBuffer(GenericBuffer&& buf) 116 | : mSize(buf.mSize) 117 | , mCapacity(buf.mCapacity) 118 | , mType(buf.mType) 119 | , mBuffer(buf.mBuffer) 120 | { 121 | buf.mSize = 0; 122 | buf.mCapacity = 0; 123 | buf.mType = nvinfer1::DataType::kFLOAT; 124 | buf.mBuffer = nullptr; 125 | } 126 | 127 | GenericBuffer& operator=(GenericBuffer&& buf) 128 | { 129 | if (this != &buf) 130 | { 131 | freeFn(mBuffer); 132 | mSize = buf.mSize; 133 | mCapacity = buf.mCapacity; 134 | mType = buf.mType; 135 | mBuffer = buf.mBuffer; 136 | // Reset buf. 137 | buf.mSize = 0; 138 | buf.mCapacity = 0; 139 | buf.mBuffer = nullptr; 140 | } 141 | return *this; 142 | } 143 | 144 | //! 145 | //! \brief Returns pointer to underlying array. 146 | //! 147 | void* data() { return mBuffer; } 148 | 149 | //! 150 | //! \brief Returns pointer to underlying array. 151 | //! 152 | const void* data() const 153 | { 154 | return mBuffer; 155 | } 156 | 157 | //! 158 | //! \brief Returns the size (in number of elements) of the buffer. 159 | //! 160 | size_t size() const 161 | { 162 | return mSize; 163 | } 164 | 165 | //! 166 | //! \brief Returns the size (in bytes) of the buffer. 167 | //! 168 | size_t nbBytes() const 169 | { 170 | return this->size() * common::getElementSize(mType); 171 | } 172 | 173 | //! 174 | //! \brief Resizes the buffer. This is a no-op if the new size is smaller than or equal to the current capacity. 175 | //! 176 | void resize(size_t newSize) 177 | { 178 | mSize = newSize; 179 | if (mCapacity < newSize) 180 | { 181 | freeFn(mBuffer); 182 | if (!allocFn(&mBuffer, this->nbBytes())) 183 | { 184 | throw std::bad_alloc{}; 185 | } 186 | mCapacity = newSize; 187 | } 188 | } 189 | 190 | //! 191 | //! \brief Overload of resize that accepts Dims 192 | //! 193 | void resize(const nvinfer1::Dims& dims) 194 | { 195 | return this->resize(common::volume(dims)); 196 | } 197 | 198 | ~GenericBuffer() 199 | { 200 | freeFn(mBuffer); 201 | } 202 | 203 | private: 204 | size_t mSize{0}, mCapacity{0}; 205 | nvinfer1::DataType mType; 206 | void* mBuffer; 207 | AllocFunc allocFn; 208 | FreeFunc freeFn; 209 | }; 210 | 211 | class DeviceAllocator 212 | { 213 | public: 214 | bool operator()(void** ptr, size_t size) const { return cudaMalloc(ptr, size) == cudaSuccess; } 215 | }; 216 | 217 | class DeviceFree 218 | { 219 | public: 220 | void operator()(void* ptr) const { cudaFree(ptr); } 221 | }; 222 | 223 | class HostAllocator 224 | { 225 | public: 226 | bool operator()(void** ptr, size_t size) const 227 | { 228 | *ptr = malloc(size); 229 | return *ptr != nullptr; 230 | } 231 | }; 232 | 233 | class HostFree 234 | { 235 | public: 236 | void operator()(void* ptr) const { free(ptr); } 237 | }; 238 | 239 | using DeviceBuffer = GenericBuffer; 240 | using HostBuffer = GenericBuffer; 241 | 242 | //! 243 | //! \brief The ManagedBuffer class groups together a pair of corresponding device and host buffers. 244 | //! 
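//!
//! A sketch of using the buffer aliases above directly, round-tripping n
//! floats through device memory (assumes a CUDA context is initialized):
//! \code
//!     const size_t n = 1024;
//!     HostBuffer host(n, nvinfer1::DataType::kFLOAT);
//!     DeviceBuffer device(n, nvinfer1::DataType::kFLOAT);
//!     static_cast<float*>(host.data())[0] = 1.0f;
//!     CHECK(cudaMemcpy(device.data(), host.data(), host.nbBytes(), cudaMemcpyHostToDevice));
//!     CHECK(cudaMemcpy(host.data(), device.data(), device.nbBytes(), cudaMemcpyDeviceToHost));
//! \endcode
//!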
245 | class ManagedBuffer 246 | { 247 | public: 248 | DeviceBuffer deviceBuffer; 249 | HostBuffer hostBuffer; 250 | }; 251 | 252 | //! 253 | //! \brief The BufferManager class handles host and device buffer allocation and deallocation. 254 | //! 255 | //! \details This RAII class handles host and device buffer allocation and deallocation, 256 | //! memcpy between host and device buffers to aid with inference, 257 | //! and debugging dumps to validate inference. The BufferManager class is meant to be 258 | //! used to simplify buffer management and any interactions between buffers and the engine. 259 | //! 260 | class BufferManager 261 | { 262 | public: 263 | static const size_t kINVALID_SIZE_VALUE = ~size_t(0); 264 | 265 | //! 266 | //! \brief Create a BufferManager for handling buffer interactions with engine. 267 | //! 268 | BufferManager(std::shared_ptr engine, const int& batchSize, const nvinfer1::IExecutionContext* context = nullptr) 269 | : mEngine(engine) 270 | , mBatchSize(batchSize) 271 | { 272 | // Create host and device buffers 273 | for (int i = 0; i < mEngine->getNbBindings(); i++) 274 | { 275 | auto dims = context ? context->getBindingDimensions(i) : mEngine->getBindingDimensions(i); 276 | size_t vol = context ? 1 : static_cast(mBatchSize); 277 | nvinfer1::DataType type = mEngine->getBindingDataType(i); 278 | int vecDim = mEngine->getBindingVectorizedDim(i); 279 | if (-1 != vecDim) // i.e., 0 != lgScalarsPerVector 280 | { 281 | int scalarsPerVec = mEngine->getBindingComponentsPerElement(i); 282 | dims.d[vecDim] = divUp(dims.d[vecDim], scalarsPerVec); 283 | vol *= scalarsPerVec; 284 | } 285 | vol *= common::volume(dims); 286 | std::unique_ptr manBuf{new ManagedBuffer()}; 287 | manBuf->deviceBuffer = DeviceBuffer(vol, type); 288 | manBuf->hostBuffer = HostBuffer(vol, type); 289 | mDeviceBindings.emplace_back(manBuf->deviceBuffer.data()); 290 | mManagedBuffers.emplace_back(std::move(manBuf)); 291 | } 292 | } 293 | 294 | //! 295 | //! \brief Returns a vector of device buffers that you can use directly as 296 | //! bindings for the execute and enqueue methods of IExecutionContext. 297 | //! 298 | std::vector& getDeviceBindings() { return mDeviceBindings; } 299 | 300 | //! 301 | //! \brief Returns a vector of device buffers. 302 | //! 303 | const std::vector& getDeviceBindings() const { return mDeviceBindings; } 304 | 305 | //! 306 | //! \brief Returns the device buffer corresponding to tensorName. 307 | //! Returns nullptr if no such tensor can be found. 308 | //! 309 | void* getDeviceBuffer(const std::string& tensorName) const { return getBuffer(false, tensorName); } 310 | 311 | //! 312 | //! \brief Returns the host buffer corresponding to tensorName. 313 | //! Returns nullptr if no such tensor can be found. 314 | //! 315 | void* getHostBuffer(const std::string& tensorName) const { return getBuffer(true, tensorName); } 316 | 317 | //! 318 | //! \brief Returns the size of the host and device buffers that correspond to tensorName. 319 | //! Returns kINVALID_SIZE_VALUE if no such tensor can be found. 320 | //! 321 | size_t size(const std::string& tensorName) const 322 | { 323 | int index = mEngine->getBindingIndex(tensorName.c_str()); 324 | if (index == -1) 325 | return kINVALID_SIZE_VALUE; 326 | return mManagedBuffers[index]->hostBuffer.nbBytes(); 327 | } 328 | 329 | //! 330 | //! \brief Dump host buffer with specified tensorName to ostream. 331 | //! Prints error message to std::ostream if no such tensor can be found. 332 | //! 
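//!
//! The usual inference round trip with this class under the implicit-batch
//! API used here; engine and context are assumed already created, and
//! "input"/"output" are assumed tensor names:
//! \code
//!     common::BufferManager buffers(engine, batchSize);
//!     float* hostInput = static_cast<float*>(buffers.getHostBuffer("input"));
//!     // ... fill hostInput ...
//!     buffers.copyInputToDevice();
//!     context->execute(batchSize, buffers.getDeviceBindings().data());
//!     buffers.copyOutputToHost();
//!     const float* hostOutput = static_cast<const float*>(buffers.getHostBuffer("output"));
//! \endcode
//!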
333 | void dumpBuffer(std::ostream& os, const std::string& tensorName) 334 | { 335 | int index = mEngine->getBindingIndex(tensorName.c_str()); 336 | if (index == -1) 337 | { 338 | os << "Invalid tensor name" << std::endl; 339 | return; 340 | } 341 | void* buf = mManagedBuffers[index]->hostBuffer.data(); 342 | size_t bufSize = mManagedBuffers[index]->hostBuffer.nbBytes(); 343 | nvinfer1::Dims bufDims = mEngine->getBindingDimensions(index); 344 | size_t rowCount = static_cast(bufDims.nbDims >= 1 ? bufDims.d[bufDims.nbDims - 1] : mBatchSize); 345 | 346 | os << "[" << mBatchSize; 347 | for (int i = 0; i < bufDims.nbDims; i++) 348 | os << ", " << bufDims.d[i]; 349 | os << "]" << std::endl; 350 | switch (mEngine->getBindingDataType(index)) 351 | { 352 | case nvinfer1::DataType::kINT32: print(os, buf, bufSize, rowCount); break; 353 | case nvinfer1::DataType::kFLOAT: print(os, buf, bufSize, rowCount); break; 354 | case nvinfer1::DataType::kHALF: print(os, buf, bufSize, rowCount); break; 355 | case nvinfer1::DataType::kINT8: assert(0 && "Int8 network-level input and output is not supported"); break; 356 | } 357 | } 358 | 359 | //! 360 | //! \brief Templated print function that dumps buffers of arbitrary type to std::ostream. 361 | //! rowCount parameter controls how many elements are on each line. 362 | //! A rowCount of 1 means that there is only 1 element on each line. 363 | //! 364 | template 365 | void print(std::ostream& os, void* buf, size_t bufSize, size_t rowCount) 366 | { 367 | assert(rowCount != 0); 368 | assert(bufSize % sizeof(T) == 0); 369 | T* typedBuf = static_cast(buf); 370 | size_t numItems = bufSize / sizeof(T); 371 | for (int i = 0; i < static_cast(numItems); i++) 372 | { 373 | // Handle rowCount == 1 case 374 | if (rowCount == 1 && i != static_cast(numItems) - 1) 375 | os << typedBuf[i] << std::endl; 376 | else if (rowCount == 1) 377 | os << typedBuf[i]; 378 | // Handle rowCount > 1 case 379 | else if (i % rowCount == 0) 380 | os << typedBuf[i]; 381 | else if (i % rowCount == rowCount - 1) 382 | os << " " << typedBuf[i] << std::endl; 383 | else 384 | os << " " << typedBuf[i]; 385 | } 386 | } 387 | 388 | //! 389 | //! \brief Copy the contents of input host buffers to input device buffers synchronously. 390 | //! 391 | void copyInputToDevice() { memcpyBuffers(true, false, false); } 392 | 393 | //! 394 | //! \brief Copy the contents of output device buffers to output host buffers synchronously. 395 | //! 396 | void copyOutputToHost() { memcpyBuffers(false, true, false); } 397 | 398 | //! 399 | //! \brief Copy the contents of input host buffers to input device buffers asynchronously. 400 | //! 401 | void copyInputToDeviceAsync(const cudaStream_t& stream = 0) { memcpyBuffers(true, false, true, stream); } 402 | 403 | //! 404 | //! \brief Copy the contents of output device buffers to output host buffers asynchronously. 405 | //! 406 | void copyOutputToHostAsync(const cudaStream_t& stream = 0) { memcpyBuffers(false, true, true, stream); } 407 | 408 | ~BufferManager() = default; 409 | 410 | private: 411 | 412 | void* getBuffer(const bool isHost, const std::string& tensorName) const 413 | { 414 | int index = mEngine->getBindingIndex(tensorName.c_str()); 415 | if (index == -1) 416 | return nullptr; 417 | return (isHost ? 
mManagedBuffers[index]->hostBuffer.data() : mManagedBuffers[index]->deviceBuffer.data()); 418 | } 419 | 420 | void memcpyBuffers(const bool copyInput, const bool deviceToHost, const bool async, const cudaStream_t& stream = 0) 421 | { 422 | for (int i = 0; i < mEngine->getNbBindings(); i++) 423 | { 424 | void* dstPtr 425 | = deviceToHost ? mManagedBuffers[i]->hostBuffer.data() : mManagedBuffers[i]->deviceBuffer.data(); 426 | const void* srcPtr 427 | = deviceToHost ? mManagedBuffers[i]->deviceBuffer.data() : mManagedBuffers[i]->hostBuffer.data(); 428 | const size_t byteSize = mManagedBuffers[i]->hostBuffer.nbBytes(); 429 | const cudaMemcpyKind memcpyType = deviceToHost ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice; 430 | if ((copyInput && mEngine->bindingIsInput(i)) || (!copyInput && !mEngine->bindingIsInput(i))) 431 | { 432 | if (async) 433 | CHECK(cudaMemcpyAsync(dstPtr, srcPtr, byteSize, memcpyType, stream)); 434 | else 435 | CHECK(cudaMemcpy(dstPtr, srcPtr, byteSize, memcpyType)); 436 | } 437 | } 438 | } 439 | 440 | std::shared_ptr mEngine; //!< The pointer to the engine 441 | int mBatchSize; //!< The batch size 442 | std::vector> mManagedBuffers; //!< The vector of pointers to managed buffers 443 | std::vector mDeviceBindings; //!< The vector of device buffers needed for engine execution 444 | }; 445 | 446 | } // namespace common 447 | 448 | #endif // TENSORRT_BUFFERS_H 449 | -------------------------------------------------------------------------------- /common/common.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 1993-2019 NVIDIA Corporation. All rights reserved. 3 | * 4 | * NOTICE TO LICENSEE: 5 | * 6 | * This source code and/or documentation ("Licensed Deliverables") are 7 | * subject to NVIDIA intellectual property rights under U.S. and 8 | * international Copyright laws. 9 | * 10 | * These Licensed Deliverables contained herein is PROPRIETARY and 11 | * CONFIDENTIAL to NVIDIA and is being provided under the terms and 12 | * conditions of a form of NVIDIA software license agreement by and 13 | * between NVIDIA and Licensee ("License Agreement") or electronically 14 | * accepted by Licensee. Notwithstanding any terms or conditions to 15 | * the contrary in the License Agreement, reproduction or disclosure 16 | * of the Licensed Deliverables to any third party without the express 17 | * written consent of NVIDIA is prohibited. 18 | * 19 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 20 | * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE 21 | * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS 22 | * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. 23 | * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED 24 | * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, 25 | * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. 26 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 27 | * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY 28 | * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY 29 | * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 30 | * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS 31 | * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE 32 | * OF THESE LICENSED DELIVERABLES. 33 | * 34 | * U.S. Government End Users. These Licensed Deliverables are a 35 | * "commercial item" as that term is defined at 48 C.F.R. 
2.101 (OCT 36 | * 1995), consisting of "commercial computer software" and "commercial 37 | * computer software documentation" as such terms are used in 48 38 | * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government 39 | * only as a commercial end item. Consistent with 48 C.F.R.12.212 and 40 | * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all 41 | * U.S. Government End Users acquire the Licensed Deliverables with 42 | * only those rights set forth herein. 43 | * 44 | * Any use of the Licensed Deliverables in individual and commercial 45 | * software must include, in the user documentation and internal 46 | * comments to the code, the above Disclaimer and U.S. Government End 47 | * Users Notice. 48 | */ 49 | 50 | #ifndef TENSORRT_COMMON_H 51 | #define TENSORRT_COMMON_H 52 | 53 | // For loadLibrary 54 | #ifdef _MSC_VER 55 | // Needed so that the max/min definitions in windows.h do not conflict with std::max/min. 56 | #define NOMINMAX 57 | #include 58 | #undef NOMINMAX 59 | #else 60 | #include 61 | #endif 62 | 63 | #include "NvInfer.h" 64 | #include "NvInferPlugin.h" 65 | #include "logger.h" 66 | #include 67 | #include 68 | #include 69 | #include 70 | #include 71 | #include 72 | #include 73 | #include 74 | #include 75 | #include 76 | #include 77 | #include 78 | #include 79 | #include 80 | #include 81 | #include 82 | #include 83 | #include 84 | #include 85 | 86 | using namespace nvinfer1; 87 | using namespace plugin; 88 | 89 | #ifdef _MSC_VER 90 | #define FN_NAME __FUNCTION__ 91 | #else 92 | #define FN_NAME __func__ 93 | #endif 94 | 95 | #if (!defined(__ANDROID__) && defined(__aarch64__)) || defined(__QNX__) 96 | #define ENABLE_DLA_API 1 97 | #endif 98 | 99 | #define CHECK(status) \ 100 | do \ 101 | { \ 102 | auto ret = (status); \ 103 | if (ret != 0) \ 104 | { \ 105 | std::cerr << "Cuda failure: " << ret << std::endl; \ 106 | abort(); \ 107 | } \ 108 | } while (0) 109 | 110 | #define CHECK_RETURN_W_MSG(status, val, errMsg) \ 111 | do \ 112 | { \ 113 | if (!(status)) \ 114 | { \ 115 | std::cerr << errMsg << " Error in " << __FILE__ \ 116 | << ", function " << FN_NAME << "(), line " \ 117 | << __LINE__ << std::endl; \ 118 | return val; \ 119 | } \ 120 | } while (0) 121 | 122 | #define CHECK_RETURN(status, val) \ 123 | CHECK_RETURN_W_MSG(status, val, "") 124 | 125 | #define OBJ_GUARD(A) std::unique_ptr 126 | 127 | template 128 | OBJ_GUARD(T) makeObjGuard(T_* t) 129 | { 130 | CHECK(!(std::is_base_of::value || std::is_same::value)); 131 | auto deleter = [](T* t) { t->destroy(); }; 132 | return std::unique_ptr{static_cast(t), deleter}; 133 | } 134 | 135 | constexpr long double operator"" _GiB(long double val) 136 | { 137 | return val * (1 << 30); 138 | } 139 | constexpr long double operator"" _MiB(long double val) { return val * (1 << 20); } 140 | constexpr long double operator"" _KiB(long double val) { return val * (1 << 10); } 141 | 142 | // These is necessary if we want to be able to write 1_GiB instead of 1.0_GiB. 143 | // Since the return type is signed, -1_GiB will work as expected. 
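// With the integral overloads defined just below, builder workspace sizes
// read naturally; a sketch (an IBuilderConfig* named config is assumed to
// exist already):
//
//     config->setMaxWorkspaceSize(16_MiB); // 16 * (1 << 20) bytes
//     config->setMaxWorkspaceSize(1_GiB);  // 1 << 30 bytes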
144 | constexpr long long int operator"" _GiB(long long unsigned int val) { return val * (1 << 30); } 145 | constexpr long long int operator"" _MiB(long long unsigned int val) { return val * (1 << 20); } 146 | constexpr long long int operator"" _KiB(long long unsigned int val) { return val * (1 << 10); } 147 | 148 | struct SimpleProfiler : public nvinfer1::IProfiler 149 | { 150 | struct Record 151 | { 152 | float time{0}; 153 | int count{0}; 154 | }; 155 | 156 | virtual void reportLayerTime(const char* layerName, float ms) 157 | { 158 | mProfile[layerName].count++; 159 | mProfile[layerName].time += ms; 160 | if (std::find(mLayerNames.begin(), mLayerNames.end(), layerName) 161 | == mLayerNames.end()) 162 | { 163 | mLayerNames.push_back(layerName); 164 | } 165 | } 166 | 167 | SimpleProfiler( 168 | const char* name, 169 | const std::vector& srcProfilers = std::vector()) 170 | : mName(name) 171 | { 172 | for (const auto& srcProfiler : srcProfilers) 173 | { 174 | for (const auto& rec : srcProfiler.mProfile) 175 | { 176 | auto it = mProfile.find(rec.first); 177 | if (it == mProfile.end()) 178 | { 179 | mProfile.insert(rec); 180 | } 181 | else 182 | { 183 | it->second.time += rec.second.time; 184 | it->second.count += rec.second.count; 185 | } 186 | } 187 | } 188 | } 189 | 190 | friend std::ostream& operator<<(std::ostream& out, const SimpleProfiler& value) 191 | { 192 | out << "========== " << value.mName << " profile ==========" << std::endl; 193 | float totalTime = 0; 194 | std::string layerNameStr = "TensorRT layer name"; 195 | int maxLayerNameLength = std::max(static_cast(layerNameStr.size()), 70); 196 | for (const auto& elem : value.mProfile) 197 | { 198 | totalTime += elem.second.time; 199 | maxLayerNameLength = std::max(maxLayerNameLength, static_cast(elem.first.size())); 200 | } 201 | 202 | auto old_settings = out.flags(); 203 | auto old_precision = out.precision(); 204 | // Output header 205 | { 206 | out << std::setw(maxLayerNameLength) << layerNameStr << " "; 207 | out << std::setw(12) << "Runtime, " 208 | << "%" 209 | << " "; 210 | out << std::setw(12) << "Invocations" 211 | << " "; 212 | out << std::setw(12) << "Runtime, ms" << std::endl; 213 | } 214 | for (size_t i = 0; i < value.mLayerNames.size(); i++) 215 | { 216 | const std::string layerName = value.mLayerNames[i]; 217 | auto elem = value.mProfile.at(layerName); 218 | out << std::setw(maxLayerNameLength) << layerName << " "; 219 | out << std::setw(12) << std::fixed << std::setprecision(1) << (elem.time * 100.0F / totalTime) << "%" 220 | << " "; 221 | out << std::setw(12) << elem.count << " "; 222 | out << std::setw(12) << std::fixed << std::setprecision(2) << elem.time << std::endl; 223 | } 224 | out.flags(old_settings); 225 | out.precision(old_precision); 226 | out << "========== " << value.mName << " total runtime = " << totalTime << " ms ==========" << std::endl; 227 | 228 | return out; 229 | } 230 | 231 | private: 232 | std::string mName; 233 | std::vector mLayerNames; 234 | std::map mProfile; 235 | }; 236 | 237 | // Locate path to file, given its filename or filepath suffix and possible dirs it might lie in 238 | // Function will also walk back MAX_DEPTH dirs from CWD to check for such a file path 239 | inline std::string locateFile(const std::string& filepathSuffix, const std::vector& directories) 240 | { 241 | const int MAX_DEPTH{10}; 242 | bool found{false}; 243 | std::string filepath; 244 | 245 | for (auto& dir : directories) 246 | { 247 | if (!dir.empty() && dir.back() != '/') 248 | { 249 | #ifdef _MSC_VER 250 | 
filepath = dir + "\\" + filepathSuffix; 251 | #else 252 | filepath = dir + "/" + filepathSuffix; 253 | #endif 254 | } 255 | else 256 | filepath = dir + filepathSuffix; 257 | 258 | for (int i = 0; i < MAX_DEPTH && !found; i++) 259 | { 260 | std::ifstream checkFile(filepath); 261 | found = checkFile.is_open(); 262 | if (found) 263 | break; 264 | filepath = "../" + filepath; // Try again in parent dir 265 | } 266 | 267 | if (found) 268 | { 269 | break; 270 | } 271 | 272 | filepath.clear(); 273 | } 274 | 275 | if (filepath.empty()) 276 | { 277 | std::string directoryList = std::accumulate(directories.begin() + 1, directories.end(), directories.front(), 278 | [](const std::string& a, const std::string& b) { return a + "\n\t" + b; }); 279 | std::cout << "Could not find " << filepathSuffix << " in data directories:\n\t" << directoryList << std::endl; 280 | std::cout << "&&&& FAILED" << std::endl; 281 | exit(EXIT_FAILURE); 282 | } 283 | return filepath; 284 | } 285 | 286 | inline void readPGMFile(const std::string& fileName, uint8_t* buffer, int inH, int inW) 287 | { 288 | std::ifstream infile(fileName, std::ifstream::binary); 289 | assert(infile.is_open() && "Attempting to read from a file that is not open."); 290 | std::string magic, h, w, max; 291 | infile >> magic >> h >> w >> max; 292 | infile.seekg(1, infile.cur); 293 | infile.read(reinterpret_cast(buffer), inH * inW); 294 | } 295 | 296 | namespace common 297 | { 298 | 299 | class HostMemory : public IHostMemory 300 | { 301 | public: 302 | HostMemory() = delete; 303 | void* data() const noexcept override { return mData; } 304 | std::size_t size() const noexcept override { return mSize; } 305 | DataType type() const noexcept override { return mType; } 306 | protected: 307 | HostMemory(std::size_t size, DataType type) 308 | : mSize(size) 309 | , mType(type) 310 | { 311 | } 312 | void* mData; 313 | std::size_t mSize; 314 | DataType mType; 315 | 316 | }; 317 | 318 | template 319 | class TypedHostMemory : public HostMemory 320 | { 321 | public: 322 | TypedHostMemory(std::size_t size) 323 | : HostMemory(size, dataType) 324 | { 325 | mData = new ElemType[size]; 326 | }; 327 | void destroy() noexcept override 328 | { 329 | delete[](ElemType*) mData; 330 | delete this; 331 | } 332 | ElemType* raw() noexcept { return static_cast(data()); } 333 | }; 334 | 335 | using FloatMemory = TypedHostMemory; 336 | using HalfMemory = TypedHostMemory; 337 | using ByteMemory = TypedHostMemory; 338 | 339 | // Swaps endianness of an integral type. 340 | template ::value, int>::type = 0> 341 | inline T swapEndianness(const T& value) 342 | { 343 | uint8_t bytes[sizeof(T)]; 344 | for (int i = 0; i < static_cast(sizeof(T)); ++i) 345 | { 346 | bytes[sizeof(T) - 1 - i] = *(reinterpret_cast(&value) + i); 347 | } 348 | return *reinterpret_cast(bytes); 349 | } 350 | 351 | inline void* safeCudaMalloc(size_t memSize) 352 | { 353 | void* deviceMem; 354 | CHECK(cudaMalloc(&deviceMem, memSize)); 355 | if (deviceMem == nullptr) 356 | { 357 | std::cerr << "Out of memory" << std::endl; 358 | exit(1); 359 | } 360 | return deviceMem; 361 | } 362 | 363 | inline bool isDebug() 364 | { 365 | return (std::getenv("TENSORRT_DEBUG") ? 
true : false); 366 | } 367 | 368 | struct InferDeleter 369 | { 370 | template 371 | void operator()(T* obj) const 372 | { 373 | if (obj) 374 | { 375 | obj->destroy(); 376 | } 377 | } 378 | }; 379 | 380 | template 381 | inline std::shared_ptr infer_object(T* obj) 382 | { 383 | if (!obj) 384 | { 385 | throw std::runtime_error("Failed to create object"); 386 | } 387 | return std::shared_ptr(obj, InferDeleter()); 388 | } 389 | 390 | template 391 | inline std::vector argsort(Iter begin, Iter end, bool reverse = false) 392 | { 393 | std::vector inds(end - begin); 394 | std::iota(inds.begin(), inds.end(), 0); 395 | if (reverse) 396 | { 397 | std::sort(inds.begin(), inds.end(), [&begin](size_t i1, size_t i2) { 398 | return begin[i2] < begin[i1]; 399 | }); 400 | } 401 | else 402 | { 403 | std::sort(inds.begin(), inds.end(), [&begin](size_t i1, size_t i2) { 404 | return begin[i1] < begin[i2]; 405 | }); 406 | } 407 | return inds; 408 | } 409 | 410 | inline bool readReferenceFile(const std::string& fileName, std::vector& refVector) 411 | { 412 | std::ifstream infile(fileName); 413 | if (!infile.is_open()) 414 | { 415 | std::cout << "ERROR: readReferenceFile: Attempting to read from a file that is not open." << std::endl; 416 | return false; 417 | } 418 | std::string line; 419 | while (std::getline(infile, line)) 420 | { 421 | if (line.empty()) 422 | continue; 423 | refVector.push_back(line); 424 | } 425 | infile.close(); 426 | return true; 427 | } 428 | 429 | template 430 | inline std::vector classify(const std::vector& refVector, const result_vector_t& output, const size_t topK) 431 | { 432 | auto inds = common::argsort(output.cbegin(), output.cend(), true); 433 | std::vector result; 434 | for (size_t k = 0; k < topK; ++k) 435 | { 436 | result.push_back(refVector[inds[k]]); 437 | } 438 | return result; 439 | } 440 | 441 | //...LG returns top K indices, not values. 442 | template 443 | inline std::vector topK(const std::vector inp, const size_t k) 444 | { 445 | std::vector result; 446 | std::vector inds = common::argsort(inp.cbegin(), inp.cend(), true); 447 | result.assign(inds.begin(), inds.begin() + k); 448 | return result; 449 | } 450 | 451 | template 452 | inline bool readASCIIFile(const std::string& fileName, const size_t size, std::vector& out) 453 | { 454 | std::ifstream infile(fileName); 455 | if (!infile.is_open()) 456 | { 457 | std::cout << "ERROR readASCIIFile: Attempting to read from a file that is not open." << std::endl; 458 | return false; 459 | } 460 | out.clear(); 461 | out.reserve(size); 462 | out.assign(std::istream_iterator(infile), std::istream_iterator()); 463 | infile.close(); 464 | return true; 465 | } 466 | 467 | template 468 | inline bool writeASCIIFile(const std::string& fileName, const std::vector& in) 469 | { 470 | std::ofstream outfile(fileName); 471 | if (!outfile.is_open()) 472 | { 473 | std::cout << "ERROR: writeASCIIFile: Attempting to write to a file that is not open." << std::endl; 474 | return false; 475 | } 476 | for (auto fn : in) 477 | { 478 | outfile << fn << "\n"; 479 | } 480 | outfile.close(); 481 | return true; 482 | } 483 | 484 | inline void print_version() 485 | { 486 | std::cout << " TensorRT version: " 487 | << NV_TENSORRT_MAJOR << "." 488 | << NV_TENSORRT_MINOR << "." 489 | << NV_TENSORRT_PATCH << "." 
490 | << NV_TENSORRT_BUILD << std::endl; 491 | } 492 | 493 | inline std::string getFileType(const std::string& filepath) 494 | { 495 | return filepath.substr(filepath.find_last_of(".") + 1); 496 | } 497 | 498 | inline std::string toLower(const std::string& inp) 499 | { 500 | std::string out = inp; 501 | std::transform(out.begin(), out.end(), out.begin(), ::tolower); 502 | return out; 503 | } 504 | 505 | inline float getMaxValue(const float* buffer, int64_t size) 506 | { 507 | assert(buffer != nullptr); 508 | assert(size > 0); 509 | return *std::max_element(buffer, buffer + size); 510 | } 511 | 512 | // Ensures that every tensor used by a network has a scale. 513 | // 514 | // All tensors in a network must have a range specified if a calibrator is not used. 515 | // This function is just a utility to globally fill in missing scales for the entire network. 516 | // 517 | // If a tensor does not have a scale, it is assigned inScales or outScales as follows: 518 | // 519 | // * If the tensor is the input to a layer or output of a pooling node, its scale is assigned inScales. 520 | // * Otherwise its scale is assigned outScales. 521 | // 522 | // The default parameter values are intended to demonstrate, for final layers in the network, 523 | // cases where scaling factors are asymmetric. 524 | inline void setAllTensorScales(INetworkDefinition* network, float inScales = 2.0f, float outScales = 4.0f) 525 | { 526 | // Ensure that all layer inputs have a scale. 527 | for (int i = 0; i < network->getNbLayers(); i++) 528 | { 529 | auto layer = network->getLayer(i); 530 | for (int j = 0; j < layer->getNbInputs(); j++) 531 | { 532 | ITensor* input{layer->getInput(j)}; 533 | // Optional inputs are nullptr here and are from RNN layers. 534 | if (input != nullptr && !input->dynamicRangeIsSet()) 535 | { 536 | input->setDynamicRange(-inScales, inScales); 537 | } 538 | } 539 | } 540 | 541 | // Ensure that all layer outputs have a scale. 542 | // Tensors that are also inputs to layers are ingored here 543 | // since the previous loop nest assigned scales to them. 544 | for (int i = 0; i < network->getNbLayers(); i++) 545 | { 546 | auto layer = network->getLayer(i); 547 | for (int j = 0; j < layer->getNbOutputs(); j++) 548 | { 549 | ITensor* output{layer->getOutput(j)}; 550 | // Optional outputs are nullptr here and are from RNN layers. 551 | if (output != nullptr && !output->dynamicRangeIsSet()) 552 | { 553 | // Pooling must have the same input and output scales. 554 | if (layer->getType() == LayerType::kPOOLING) 555 | { 556 | output->setDynamicRange(-inScales, inScales); 557 | } 558 | else 559 | { 560 | output->setDynamicRange(-outScales, outScales); 561 | } 562 | } 563 | } 564 | } 565 | } 566 | 567 | inline void setDummyInt8Scales(const IBuilderConfig* c, INetworkDefinition* n) 568 | { 569 | // Set dummy tensor scales if Int8 mode is requested. 570 | if (c->getFlag(BuilderFlag::kINT8)) 571 | { 572 | gLogWarning << "Int8 calibrator not provided. Generating dummy per tensor scales. Int8 accuracy is not guaranteed." 
<< std::endl; 573 | setAllTensorScales(n); 574 | } 575 | } 576 | 577 | inline void enableDLA(IBuilder* builder, IBuilderConfig* config, int useDLACore, bool allowGPUFallback = true) 578 | { 579 | if (useDLACore >= 0) 580 | { 581 | if (builder->getNbDLACores() == 0) 582 | { 583 | std::cerr << "Trying to use DLA core " << useDLACore << " on a platform that doesn't have any DLA cores" << std::endl; 584 | assert("Error: use DLA core on a platfrom that doesn't have any DLA cores" && false); 585 | } 586 | if (allowGPUFallback) 587 | { 588 | config->setFlag(BuilderFlag::kGPU_FALLBACK); 589 | } 590 | if (!builder->getInt8Mode() && !config->getFlag(BuilderFlag::kINT8)) 591 | { 592 | // User has not requested INT8 Mode. 593 | // By default run in FP16 mode. FP32 mode is not permitted. 594 | builder->setFp16Mode(true); 595 | config->setFlag(BuilderFlag::kFP16); 596 | } 597 | config->setDefaultDeviceType(DeviceType::kDLA); 598 | config->setDLACore(useDLACore); 599 | config->setFlag(BuilderFlag::kSTRICT_TYPES); 600 | } 601 | } 602 | 603 | inline int parseDLA(int argc, char** argv) 604 | { 605 | for (int i = 1; i < argc; i++) 606 | { 607 | std::string arg(argv[i]); 608 | if (strncmp(argv[i], "--useDLACore=", 13) == 0) 609 | return std::stoi(argv[i] + 13); 610 | } 611 | return -1; 612 | } 613 | 614 | inline unsigned int getElementSize(nvinfer1::DataType t) 615 | { 616 | switch (t) 617 | { 618 | case nvinfer1::DataType::kINT32: return 4; 619 | case nvinfer1::DataType::kFLOAT: return 4; 620 | case nvinfer1::DataType::kHALF: return 2; 621 | case nvinfer1::DataType::kINT8: return 1; 622 | } 623 | throw std::runtime_error("Invalid DataType."); 624 | return 0; 625 | } 626 | 627 | inline int64_t volume(const nvinfer1::Dims& d) 628 | { 629 | return std::accumulate(d.d, d.d + d.nbDims, 1, std::multiplies()); 630 | } 631 | 632 | inline unsigned int elementSize(DataType t) 633 | { 634 | switch (t) 635 | { 636 | case DataType::kINT32: 637 | case DataType::kFLOAT: return 4; 638 | case DataType::kHALF: return 2; 639 | case DataType::kINT8: return 1; 640 | } 641 | return 0; 642 | } 643 | 644 | template 645 | inline A divUp(A x, B n) 646 | { 647 | return (x + n - 1) / n; 648 | } 649 | 650 | template 651 | struct PPM 652 | { 653 | std::string magic, fileName; 654 | int h, w, max; 655 | uint8_t buffer[C * H * W]; 656 | }; 657 | 658 | struct BBox 659 | { 660 | float x1, y1, x2, y2; 661 | }; 662 | 663 | template 664 | inline void readPPMFile(const std::string& filename, common::PPM& ppm) 665 | { 666 | ppm.fileName = filename; 667 | std::ifstream infile(filename, std::ifstream::binary); 668 | assert(infile.is_open() && "Attempting to read from a file that is not open."); 669 | infile >> ppm.magic >> ppm.w >> ppm.h >> ppm.max; 670 | infile.seekg(1, infile.cur); 671 | infile.read(reinterpret_cast(ppm.buffer), ppm.w * ppm.h * 3); 672 | } 673 | 674 | template 675 | inline void writePPMFileWithBBox(const std::string& filename, PPM& ppm, const BBox& bbox) 676 | { 677 | std::ofstream outfile("./" + filename, std::ofstream::binary); 678 | assert(!outfile.fail()); 679 | outfile << "P6" 680 | << "\n" 681 | << ppm.w << " " << ppm.h << "\n" 682 | << ppm.max << "\n"; 683 | auto round = [](float x) -> int { return int(std::floor(x + 0.5f)); }; 684 | const int x1 = std::min(std::max(0, round(int(bbox.x1))), W - 1); 685 | const int x2 = std::min(std::max(0, round(int(bbox.x2))), W - 1); 686 | const int y1 = std::min(std::max(0, round(int(bbox.y1))), H - 1); 687 | const int y2 = std::min(std::max(0, round(int(bbox.y2))), H - 1); 688 | for 
(int x = x1; x <= x2; ++x) 689 | { 690 | // bbox top border 691 | ppm.buffer[(y1 * ppm.w + x) * 3] = 255; 692 | ppm.buffer[(y1 * ppm.w + x) * 3 + 1] = 0; 693 | ppm.buffer[(y1 * ppm.w + x) * 3 + 2] = 0; 694 | // bbox bottom border 695 | ppm.buffer[(y2 * ppm.w + x) * 3] = 255; 696 | ppm.buffer[(y2 * ppm.w + x) * 3 + 1] = 0; 697 | ppm.buffer[(y2 * ppm.w + x) * 3 + 2] = 0; 698 | } 699 | for (int y = y1; y <= y2; ++y) 700 | { 701 | // bbox left border 702 | ppm.buffer[(y * ppm.w + x1) * 3] = 255; 703 | ppm.buffer[(y * ppm.w + x1) * 3 + 1] = 0; 704 | ppm.buffer[(y * ppm.w + x1) * 3 + 2] = 0; 705 | // bbox right border 706 | ppm.buffer[(y * ppm.w + x2) * 3] = 255; 707 | ppm.buffer[(y * ppm.w + x2) * 3 + 1] = 0; 708 | ppm.buffer[(y * ppm.w + x2) * 3 + 2] = 0; 709 | } 710 | outfile.write(reinterpret_cast(ppm.buffer), ppm.w * ppm.h * 3); 711 | } 712 | 713 | class TimerBase 714 | { 715 | public: 716 | virtual void start() {} 717 | virtual void stop() {} 718 | float microseconds() const noexcept { return mMs * 1000.f; } 719 | float milliseconds() const noexcept { return mMs; } 720 | float seconds() const noexcept { return mMs / 1000.f; } 721 | void reset() noexcept { mMs = 0.f; } 722 | 723 | protected: 724 | float mMs{0.0f}; 725 | }; 726 | 727 | class GpuTimer : public TimerBase 728 | { 729 | public: 730 | GpuTimer(cudaStream_t stream) 731 | : mStream(stream) 732 | { 733 | CHECK(cudaEventCreate(&mStart)); 734 | CHECK(cudaEventCreate(&mStop)); 735 | } 736 | ~GpuTimer() 737 | { 738 | CHECK(cudaEventDestroy(mStart)); 739 | CHECK(cudaEventDestroy(mStop)); 740 | } 741 | void start() { CHECK(cudaEventRecord(mStart, mStream)); } 742 | void stop() 743 | { 744 | CHECK(cudaEventRecord(mStop, mStream)); 745 | float ms{0.0f}; 746 | CHECK(cudaEventSynchronize(mStop)); 747 | CHECK(cudaEventElapsedTime(&ms, mStart, mStop)); 748 | mMs += ms; 749 | } 750 | 751 | private: 752 | cudaEvent_t mStart, mStop; 753 | cudaStream_t mStream; 754 | }; // class GpuTimer 755 | 756 | template 757 | class CpuTimer : public TimerBase 758 | { 759 | public: 760 | using clock_type = Clock; 761 | 762 | void start() { mStart = Clock::now(); } 763 | void stop() 764 | { 765 | mStop = Clock::now(); 766 | mMs += std::chrono::duration{mStop - mStart}.count(); 767 | } 768 | 769 | private: 770 | std::chrono::time_point mStart, mStop; 771 | }; // class CpuTimer 772 | 773 | using PreciseCpuTimer = CpuTimer; 774 | 775 | inline std::vector splitString(std::string str, char delimiter = ',') 776 | { 777 | std::vector splitVect; 778 | std::stringstream ss(str); 779 | std::string substr; 780 | 781 | while (ss.good()) 782 | { 783 | getline(ss, substr, delimiter); 784 | splitVect.emplace_back(std::move(substr)); 785 | } 786 | return splitVect; 787 | } 788 | 789 | // Return m rounded up to nearest multiple of n 790 | inline int roundUp(int m, int n) 791 | { 792 | return ((m + n - 1) / n) * n; 793 | } 794 | 795 | inline int getC(const Dims& d) 796 | { 797 | return d.nbDims >= 3 ? d.d[d.nbDims - 3] : 1; 798 | } 799 | 800 | inline int getH(const Dims& d) 801 | { 802 | return d.nbDims >= 2 ? d.d[d.nbDims - 2] : 1; 803 | } 804 | 805 | inline int getW(const Dims& d) 806 | { 807 | return d.nbDims >= 1 ? 
d.d[d.nbDims - 1] : 1; 808 | } 809 | 810 | inline void loadLibrary(const std::string& path) 811 | { 812 | #ifdef _MSC_VER 813 | void* handle = LoadLibrary(path.c_str()); 814 | #else 815 | void* handle = dlopen(path.c_str(), RTLD_LAZY); 816 | #endif 817 | if (handle == nullptr) 818 | { 819 | #ifdef _MSC_VER 820 | gLogError << "Could not load plugin library: " << path << std::endl; 821 | #else 822 | gLogError << "Could not load plugin library: " << path << ", due to: " << dlerror() << std::endl; 823 | #endif 824 | } 825 | } 826 | 827 | } // namespace common 828 | 829 | inline std::ostream& operator<<(std::ostream& os, const nvinfer1::Dims& dims) 830 | { 831 | os << "("; 832 | for (int i = 0; i < dims.nbDims; ++i) 833 | { 834 | os << (i ? ", " : "") << dims.d[i]; 835 | } 836 | return os << ")"; 837 | } 838 | 839 | #endif // TENSORRT_COMMON_H 840 | -------------------------------------------------------------------------------- /common/dumpTFWts.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 1993-2019 NVIDIA Corporation. All rights reserved. 4 | # 5 | # NOTICE TO LICENSEE: 6 | # 7 | # This source code and/or documentation ("Licensed Deliverables") are 8 | # subject to NVIDIA intellectual property rights under U.S. and 9 | # international Copyright laws. 10 | # 11 | # These Licensed Deliverables contained herein is PROPRIETARY and 12 | # CONFIDENTIAL to NVIDIA and is being provided under the terms and 13 | # conditions of a form of NVIDIA software license agreement by and 14 | # between NVIDIA and Licensee ("License Agreement") or electronically 15 | # accepted by Licensee. Notwithstanding any terms or conditions to 16 | # the contrary in the License Agreement, reproduction or disclosure 17 | # of the Licensed Deliverables to any third party without the express 18 | # written consent of NVIDIA is prohibited. 19 | # 20 | # NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 21 | # LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE 22 | # SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS 23 | # PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. 24 | # NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED 25 | # DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, 26 | # NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. 27 | # NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 28 | # LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY 29 | # SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY 30 | # DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 31 | # WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS 32 | # ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE 33 | # OF THESE LICENSED DELIVERABLES. 34 | # 35 | # U.S. Government End Users. These Licensed Deliverables are a 36 | # "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT 37 | # 1995), consisting of "commercial computer software" and "commercial 38 | # computer software documentation" as such terms are used in 48 39 | # C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government 40 | # only as a commercial end item. Consistent with 48 C.F.R.12.212 and 41 | # 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all 42 | # U.S. Government End Users acquire the Licensed Deliverables with 43 | # only those rights set forth herein. 
44 | # 45 | # Any use of the Licensed Deliverables in individual and commercial 46 | # software must include, in the user documentation and internal 47 | # comments to the code, the above Disclaimer and U.S. Government End 48 | # Users Notice. 49 | # 50 | 51 | 52 | # Script to dump TensorFlow weights in TRT v1 and v2 dump format. 53 | # The V1 format is for TensorRT 4.0. The V2 format is for TensorRT 4.0 and later. 54 | 55 | import sys 56 | import struct 57 | import argparse 58 | try: 59 | import tensorflow as tf 60 | from tensorflow.python import pywrap_tensorflow 61 | except ImportError as err: 62 | sys.stderr.write("""Error: Failed to import module ({})""".format(err)) 63 | sys.exit() 64 | 65 | parser = argparse.ArgumentParser(description='TensorFlow Weight Dumper') 66 | 67 | parser.add_argument('-m', '--model', required=True, help='The checkpoint file basename, example basename(model.ckpt-766908.data-00000-of-00001) -> model.ckpt-766908') 68 | parser.add_argument('-o', '--output', required=True, help='The weight file to dump all the weights to.') 69 | parser.add_argument('-1', '--wtsv1', required=False, default=False, type=bool, help='Dump the weights in the wts v1.') 70 | 71 | opt = parser.parse_args() 72 | 73 | if opt.wtsv1: 74 | print "Outputting the trained weights in TensorRT's wts v1 format. This format is documented as:" 75 | print "Line 0: " 76 | print "Line 1-Num: [buffer name] [buffer type] [buffer size] " 77 | else: 78 | print "Outputting the trained weights in TensorRT's wts v2 format. This format is documented as:" 79 | print "Line 0: " 80 | print "Line 1-Num: [buffer name] [buffer type] [(buffer shape{e.g. (1, 2, 3)}] " 81 | 82 | inputbase = opt.model 83 | outputbase = opt.output 84 | 85 | def float_to_hex(f): 86 | return hex(struct.unpack('<I', struct.pack('<f', f))[0]) -------------------------------------------------------------------------------- /common/getOptions.cpp: -------------------------------------------------------------------------------- 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | namespace nvinfer1 11 | { 12 | namespace utility 13 | { 14 | 15 | //! Matching for TRTOptions is defined as follows: 16 | //! 17 | //! If A and B both have longName set, A matches B if and only if A.longName == 18 | //! B.longName and (A.shortName == B.shortName if both have short name set). 19 | //! 20 | //! If A only has shortName set and B only has longName set, then A does not 21 | //! match B. It is assumed that when 2 TRTOptions are compared, one of them is 22 | //! the definition of a TRTOption in the input to getOptions. As such, if the 23 | //! definition only has shortName set, it will never be equal to a TRTOption 24 | //! that does not have shortName set (and same for longName). 25 | //! 26 | //! If A and B both have shortName set but B does not have longName set, A 27 | //! matches B if and only if A.shortName == B.shortName. 28 | //! 29 | //! If A has neither long or short name set, A matches B if and only if B has 30 | //! neither long or short name set. 31 | bool matches(const TRTOption& a, const TRTOption& b) 32 | { 33 | if (!a.longName.empty() && !b.longName.empty()) 34 | { 35 | if (a.shortName && b.shortName) 36 | { 37 | return (a.longName == b.longName) && (a.shortName == b.shortName); 38 | } 39 | return a.longName == b.longName; 40 | } 41 | 42 | // If only one of them is not set, this will return false anyway. 43 | return a.shortName == b.shortName; 44 | } 45 | 46 | //! getTRTOptionIndex returns the index of a TRTOption in a vector of 47 | //! TRTOptions, -1 if not found.
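//!
//! Concrete cases for the matching rules documented above; the braces follow
//! the TRTOption field order (shortName, longName, valueRequired, helpText):
//! \code
//!     TRTOption defFoo{'f', "foo", true, ""};           // definition with both names
//!     TRTOption defBar{0, "bar", false, ""};            // definition with long name only
//!     matches(TRTOption{0, "foo", false, ""}, defFoo);  // true: long names agree
//!     matches(TRTOption{'f', "", false, ""}, defBar);   // false: short-only vs long-only
//! \endcode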
48 | int getTRTOptionIndex(const std::vector& options, const TRTOption& opt) 49 | { 50 | for (size_t i = 0; i < options.size(); ++i) 51 | { 52 | if (matches(opt, options[i])) 53 | { 54 | return i; 55 | } 56 | } 57 | return -1; 58 | } 59 | 60 | //! validateTRTOption will return a string containing an error message if options 61 | //! contain non-numeric characters, or if there are duplicate option names found. 62 | //! Otherwise, returns the empty string. 63 | std::string validateTRTOption( 64 | const std::set& seenShortNames, const std::set& seenLongNames, const TRTOption& opt) 65 | { 66 | if (opt.shortName != 0) 67 | { 68 | if (!std::isalnum(opt.shortName)) 69 | { 70 | return "Short name '" + std::to_string(opt.shortName) + "' is non-alphanumeric"; 71 | } 72 | 73 | if (seenShortNames.find(opt.shortName) != seenShortNames.end()) 74 | { 75 | return "Short name '" + std::to_string(opt.shortName) + "' is a duplicate"; 76 | } 77 | } 78 | 79 | if (!opt.longName.empty()) 80 | { 81 | for (const char& c : opt.longName) 82 | { 83 | if (!std::isalnum(c) && c != '-' && c != '_') 84 | { 85 | return "Long name '" + opt.longName + "' contains characters that are not '-', '_', or alphanumeric"; 86 | } 87 | } 88 | 89 | if (seenLongNames.find(opt.longName) != seenLongNames.end()) 90 | { 91 | return "Long name '" + opt.longName + "' is a duplicate"; 92 | } 93 | } 94 | return ""; 95 | } 96 | 97 | //! validateTRTOptions will return a string containing an error message if any 98 | //! options contain non-numeric characters, or if there are duplicate option 99 | //! names found. Otherwise, returns the empty string. 100 | std::string validateTRTOptions(const std::vector& options) 101 | { 102 | std::set seenShortNames; 103 | std::set seenLongNames; 104 | for (size_t i = 0; i < options.size(); ++i) 105 | { 106 | const std::string errMsg = validateTRTOption(seenShortNames, seenLongNames, options[i]); 107 | if (!errMsg.empty()) 108 | { 109 | return "Error '" + errMsg + "' at TRTOption " + std::to_string(i); 110 | } 111 | 112 | seenShortNames.insert(options[i].shortName); 113 | seenLongNames.insert(options[i].longName); 114 | } 115 | return ""; 116 | } 117 | 118 | //! parseArgs parses an argument list and returns a TRTParsedArgs with the 119 | //! fields set accordingly. Assumes that options is validated. 120 | //! ErrMsg will be set if: 121 | //! - an argument is null 122 | //! - an argument is empty 123 | //! - an argument does not have option (i.e. "-" and "--") 124 | //! - a short argument has more than 1 character 125 | //! - the last argument in the list requires a value 126 | TRTParsedArgs parseArgs(int argc, const char* const* argv, const std::vector& options) 127 | { 128 | TRTParsedArgs parsedArgs; 129 | parsedArgs.values.resize(options.size()); 130 | 131 | for (int i = 1; i < argc; ++i) // index of current command-line argument 132 | { 133 | if (argv[i] == nullptr) 134 | { 135 | return TRTParsedArgs{"Null argument at index " + std::to_string(i)}; 136 | } 137 | 138 | const std::string argStr(argv[i]); 139 | if (argStr.empty()) 140 | { 141 | return TRTParsedArgs{"Empty argument at index " + std::to_string(i)}; 142 | } 143 | 144 | // No starting hyphen means it is a positional argument 145 | if (argStr[0] != '-') 146 | { 147 | parsedArgs.positionalArgs.push_back(argStr); 148 | continue; 149 | } 150 | 151 | if (argStr == "-" || argStr == "--") 152 | { 153 | return TRTParsedArgs{"Argument does not specify an option at index " + std::to_string(i)}; 154 | } 155 | 156 | // If only 1 hyphen, char after is the flag. 
157 | TRTOption opt; 158 | std::string value; 159 | if (argStr[1] != '-') 160 | { 161 | // Must only have 1 char after the hyphen 162 | if (argStr.size() > 2) 163 | { 164 | return TRTParsedArgs{"Short arg contains more than 1 character at index " + std::to_string(i)}; 165 | } 166 | opt.shortName = argStr[1]; 167 | } 168 | else 169 | { 170 | opt.longName = argStr.substr(2); 171 | 172 | // We need to support --foo=bar syntax, so look for '=' 173 | const size_t eqIndex = opt.longName.find('='); 174 | if (eqIndex < opt.longName.size()) 175 | { 176 | value = opt.longName.substr(eqIndex + 1); 177 | opt.longName = opt.longName.substr(0, eqIndex); 178 | } 179 | } 180 | 181 | const int idx = getTRTOptionIndex(options, opt); 182 | if (idx < 0) 183 | { 184 | continue; 185 | } 186 | 187 | if (options[idx].valueRequired) 188 | { 189 | if (!value.empty()) 190 | { 191 | parsedArgs.values[idx].second.push_back(value); 192 | parsedArgs.values[idx].first = parsedArgs.values[idx].second.size(); 193 | continue; 194 | } 195 | 196 | if (i + 1 >= argc) 197 | { 198 | return TRTParsedArgs{"Last argument requires value, but none given"}; 199 | } 200 | 201 | const std::string nextArg(argv[i + 1]); 202 | if (nextArg.size() >= 1 && nextArg[0] == '-') 203 | { 204 | gLogWarning << "Warning: Using '" << nextArg << "' as a value for '" << argStr 205 | << "', Should this be its own flag?" << std::endl; 206 | } 207 | 208 | parsedArgs.values[idx].second.push_back(nextArg); 209 | i += 1; // Next argument already consumed 210 | 211 | parsedArgs.values[idx].first = parsedArgs.values[idx].second.size(); 212 | } 213 | else 214 | { 215 | parsedArgs.values[idx].first += 1; 216 | } 217 | } 218 | return parsedArgs; 219 | } 220 | 221 | TRTParsedArgs getOptions(int argc, const char* const* argv, const std::vector& options) 222 | { 223 | const std::string errMsg = validateTRTOptions(options); 224 | if (!errMsg.empty()) 225 | { 226 | return TRTParsedArgs{errMsg}; 227 | } 228 | return parseArgs(argc, argv, options); 229 | } 230 | } // namespace utility 231 | } // namespace nvinfer1 232 | -------------------------------------------------------------------------------- /common/getOptions.h: -------------------------------------------------------------------------------- 1 | #ifndef TRT_GET_OPTIONS_H 2 | #define TRT_GET_OPTIONS_H 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace nvinfer1 9 | { 10 | namespace utility 11 | { 12 | 13 | //! TRTOption defines a command line option. At least 1 of shortName and longName 14 | //! must be defined. 15 | //! If bool initialization is undefined behavior on your system, valueRequired 16 | //! must also be explicitly defined. 17 | //! helpText is optional. 18 | struct TRTOption 19 | { 20 | char shortName; //!< Option name in short (single hyphen) form (i.e. -a, -b) 21 | std::string longName; //!< Option name in long (double hyphen) form (i.e. --foo, --bar) 22 | bool valueRequired; //!< True if a value is needed for an option (i.e. -N 4, --foo bar) 23 | std::string helpText; //!< Text to show when printing out the command usage 24 | }; 25 | 26 | //! TRTParsedArgs is returned by getOptions after it has parsed a command line 27 | //! argument list (argv). 28 | //! 29 | //! errMsg is a string containing an error message if any errors occurred. If it 30 | //! is empty, no errors occurred. 31 | //! 32 | //! values stores a vector of pairs for each option (ordered by order in the 33 | //! input). Each pair contains an int (the number of occurrences) and a vector 34 | //! of strings (a list of values). 
The user should know which of these to use, 35 | //! and which options required values. For non-value options, only occurrences is 36 | //! populated. For value-required options, occurrences == # of values. Values do 37 | //! not need to be unique. 38 | //! 39 | //! positionalArgs stores additional arguments that are passed in without an 40 | //! option (these must not start with a hyphen). 41 | struct TRTParsedArgs 42 | { 43 | std::string errMsg; 44 | std::vector>> values; 45 | std::vector positionalArgs; 46 | }; 47 | 48 | //! Parse the input arguments passed to main() and extract options as well as 49 | //! positional arguments. 50 | //! 51 | //! Options are supposed to be passed to main() with a preceding hyphen '-'. 52 | //! 53 | //! If there is a single preceding hyphen, there should be exactly 1 character 54 | //! after the hyphen, which is interpreted as the option. 55 | //! 56 | //! If there are 2 preceding hyphens, the entire argument (without the hyphens) 57 | //! is interpreted as the option. 58 | //! 59 | //! If the option requires a value, the next argument is used as the value. 60 | //! 61 | //! Positional arguments must not start with a hyphen. 62 | //! 63 | //! If an argument requires a value, the next argument is interpreted as the 64 | //! value, even if it is the form of a valid option (i.e. --foo --bar will store 65 | //! "--bar" as a value for option "foo" if "foo" requires a value). 66 | //! We also support --name=value syntax. In this case, 'value' would be used as 67 | //! the value, NOT the next argument. 68 | //! 69 | //! For options: 70 | //! { { 'a', "", false }, 71 | //! { 'b', "", false }, 72 | //! { 0, "cee", false }, 73 | //! { 'd', "", true }, 74 | //! { 'e', "", true }, 75 | //! { 'f', "foo", true } } 76 | //! 77 | //! ./main hello world -a -a --cee -d 12 -f 34 78 | //! and 79 | //! ./main hello world -a -a --cee -d 12 --foo 34 80 | //! 81 | //! will result in: 82 | //! 83 | //! TRTParsedArgs { 84 | //! errMsg: "", 85 | //! values: { { 2, {} }, 86 | //! { 0, {} }, 87 | //! { 1, {} }, 88 | //! { 1, {"12"} }, 89 | //! { 0, {} }, 90 | //! { 1, {"34"} } } 91 | //! positionalArgs: {"hello", "world"}, 92 | //! } 93 | //! 94 | //! Non-POSIX behavior: 95 | //! - Does not support "-abcde" as a shorthand for "-a -b -c -d -e". Each 96 | //! option must have its own hyphen prefix. 97 | //! - Does not support -e12 as a shorthand for "-e 12". Values MUST be 98 | //! whitespace-separated from the option it is for. 99 | //! 100 | //! @param[in] argc The number of arguments passed to main (including the 101 | //! file name, which is disregarded) 102 | //! @param[in] argv The arguments passed to main (including the file name, 103 | //! which is disregarded) 104 | //! @param[in] options List of TRTOptions to parse 105 | //! @return TRTParsedArgs. See TRTParsedArgs documentation for descriptions of 106 | //! the fields. 107 | TRTParsedArgs getOptions(int argc, const char* const* argv, const std::vector& options); 108 | } // namespace utility 109 | } // namespace nvinfer1 110 | 111 | #endif // TRT_GET_OPTIONS_H 112 | -------------------------------------------------------------------------------- /common/logger.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 1993-2019 NVIDIA Corporation. All rights reserved. 3 | * 4 | * NOTICE TO LICENSEE: 5 | * 6 | * This source code and/or documentation ("Licensed Deliverables") are 7 | * subject to NVIDIA intellectual property rights under U.S. 
and 8 | * international Copyright laws. 9 | * 10 | * These Licensed Deliverables contained herein is PROPRIETARY and 11 | * CONFIDENTIAL to NVIDIA and is being provided under the terms and 12 | * conditions of a form of NVIDIA software license agreement by and 13 | * between NVIDIA and Licensee ("License Agreement") or electronically 14 | * accepted by Licensee. Notwithstanding any terms or conditions to 15 | * the contrary in the License Agreement, reproduction or disclosure 16 | * of the Licensed Deliverables to any third party without the express 17 | * written consent of NVIDIA is prohibited. 18 | * 19 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 20 | * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE 21 | * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS 22 | * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. 23 | * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED 24 | * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, 25 | * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. 26 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 27 | * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY 28 | * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY 29 | * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 30 | * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS 31 | * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE 32 | * OF THESE LICENSED DELIVERABLES. 33 | * 34 | * U.S. Government End Users. These Licensed Deliverables are a 35 | * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT 36 | * 1995), consisting of "commercial computer software" and "commercial 37 | * computer software documentation" as such terms are used in 48 38 | * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government 39 | * only as a commercial end item. Consistent with 48 C.F.R.12.212 and 40 | * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all 41 | * U.S. Government End Users acquire the Licensed Deliverables with 42 | * only those rights set forth herein. 43 | * 44 | * Any use of the Licensed Deliverables in individual and commercial 45 | * software must include, in the user documentation and internal 46 | * comments to the code, the above Disclaimer and U.S. Government End 47 | * Users Notice. 48 | */ 49 | 50 | #include "logger.h" 51 | #include "logging.h" 52 | 53 | Logger gLogger{Logger::Severity::kINFO}; 54 | LogStreamConsumer gLogVerbose{LOG_VERBOSE(gLogger)}; 55 | LogStreamConsumer gLogInfo{LOG_INFO(gLogger)}; 56 | LogStreamConsumer gLogWarning{LOG_WARN(gLogger)}; 57 | LogStreamConsumer gLogError{LOG_ERROR(gLogger)}; 58 | LogStreamConsumer gLogFatal{LOG_FATAL(gLogger)}; 59 | 60 | void setReportableSeverity(Logger::Severity severity) 61 | { 62 | gLogger.setReportableSeverity(severity); 63 | gLogVerbose.setReportableSeverity(severity); 64 | gLogInfo.setReportableSeverity(severity); 65 | gLogWarning.setReportableSeverity(severity); 66 | gLogError.setReportableSeverity(severity); 67 | gLogFatal.setReportableSeverity(severity); 68 | } 69 | -------------------------------------------------------------------------------- /common/logger.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 1993-2019 NVIDIA Corporation. All rights reserved. 
3 | * 4 | * NOTICE TO LICENSEE: 5 | * 6 | * This source code and/or documentation ("Licensed Deliverables") are 7 | * subject to NVIDIA intellectual property rights under U.S. and 8 | * international Copyright laws. 9 | * 10 | * These Licensed Deliverables contained herein is PROPRIETARY and 11 | * CONFIDENTIAL to NVIDIA and is being provided under the terms and 12 | * conditions of a form of NVIDIA software license agreement by and 13 | * between NVIDIA and Licensee ("License Agreement") or electronically 14 | * accepted by Licensee. Notwithstanding any terms or conditions to 15 | * the contrary in the License Agreement, reproduction or disclosure 16 | * of the Licensed Deliverables to any third party without the express 17 | * written consent of NVIDIA is prohibited. 18 | * 19 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 20 | * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE 21 | * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS 22 | * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. 23 | * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED 24 | * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, 25 | * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. 26 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 27 | * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY 28 | * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY 29 | * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 30 | * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS 31 | * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE 32 | * OF THESE LICENSED DELIVERABLES. 33 | * 34 | * U.S. Government End Users. These Licensed Deliverables are a 35 | * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT 36 | * 1995), consisting of "commercial computer software" and "commercial 37 | * computer software documentation" as such terms are used in 48 38 | * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government 39 | * only as a commercial end item. Consistent with 48 C.F.R.12.212 and 40 | * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all 41 | * U.S. Government End Users acquire the Licensed Deliverables with 42 | * only those rights set forth herein. 43 | * 44 | * Any use of the Licensed Deliverables in individual and commercial 45 | * software must include, in the user documentation and internal 46 | * comments to the code, the above Disclaimer and U.S. Government End 47 | * Users Notice. 48 | */ 49 | 50 | #ifndef LOGGER_H 51 | #define LOGGER_H 52 | 53 | #include "logging.h" 54 | 55 | extern Logger gLogger; 56 | extern LogStreamConsumer gLogVerbose; 57 | extern LogStreamConsumer gLogInfo; 58 | extern LogStreamConsumer gLogWarning; 59 | extern LogStreamConsumer gLogError; 60 | extern LogStreamConsumer gLogFatal; 61 | 62 | void setReportableSeverity(Logger::Severity severity); 63 | 64 | #endif // LOGGER_H 65 | 66 | -------------------------------------------------------------------------------- /common/logging.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 1993-2019 NVIDIA Corporation. All rights reserved. 3 | * 4 | * NOTICE TO LICENSEE: 5 | * 6 | * This source code and/or documentation ("Licensed Deliverables") are 7 | * subject to NVIDIA intellectual property rights under U.S. and 8 | * international Copyright laws. 
9 | * 10 | * These Licensed Deliverables contained herein is PROPRIETARY and 11 | * CONFIDENTIAL to NVIDIA and is being provided under the terms and 12 | * conditions of a form of NVIDIA software license agreement by and 13 | * between NVIDIA and Licensee ("License Agreement") or electronically 14 | * accepted by Licensee. Notwithstanding any terms or conditions to 15 | * the contrary in the License Agreement, reproduction or disclosure 16 | * of the Licensed Deliverables to any third party without the express 17 | * written consent of NVIDIA is prohibited. 18 | * 19 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 20 | * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE 21 | * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS 22 | * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. 23 | * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED 24 | * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, 25 | * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. 26 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 27 | * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY 28 | * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY 29 | * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 30 | * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS 31 | * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE 32 | * OF THESE LICENSED DELIVERABLES. 33 | * 34 | * U.S. Government End Users. These Licensed Deliverables are a 35 | * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT 36 | * 1995), consisting of "commercial computer software" and "commercial 37 | * computer software documentation" as such terms are used in 48 38 | * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government 39 | * only as a commercial end item. Consistent with 48 C.F.R.12.212 and 40 | * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all 41 | * U.S. Government End Users acquire the Licensed Deliverables with 42 | * only those rights set forth herein. 43 | * 44 | * Any use of the Licensed Deliverables in individual and commercial 45 | * software must include, in the user documentation and internal 46 | * comments to the code, the above Disclaimer and U.S. Government End 47 | * Users Notice. 
48 | */
49 | 
50 | #ifndef TENSORRT_LOGGING_H
51 | #define TENSORRT_LOGGING_H
52 | 
53 | #include "NvInferRuntimeCommon.h"
54 | #include <cassert>
55 | #include <ctime>
56 | #include <iomanip>
57 | #include <iostream>
58 | #include <ostream>
59 | #include <sstream>
60 | #include <string>
61 | 
62 | using Severity = nvinfer1::ILogger::Severity;
63 | 
64 | class LogStreamConsumerBuffer : public std::stringbuf
65 | {
66 | public:
67 |     LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
68 |         : mOutput(stream)
69 |         , mPrefix(prefix)
70 |         , mShouldLog(shouldLog)
71 |     {
72 |     }
73 | 
74 |     LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) // carry over the prefix and logging flag as well
75 |         : mOutput(other.mOutput), mPrefix(other.mPrefix), mShouldLog(other.mShouldLog)
76 |     {
77 |     }
78 | 
79 |     ~LogStreamConsumerBuffer()
80 |     {
81 |         // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
82 |         // std::streambuf::pptr() gives a pointer to the current position of the output sequence
83 |         // if the pointer to the beginning is not equal to the pointer to the current position,
84 |         // call putOutput() to log the output to the stream
85 |         if (pbase() != pptr())
86 |         {
87 |             putOutput();
88 |         }
89 |     }
90 | 
91 |     // synchronizes the stream buffer and returns 0 on success
92 |     // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
93 |     // resetting the buffer and flushing the stream
94 |     virtual int sync()
95 |     {
96 |         putOutput();
97 |         return 0;
98 |     }
99 | 
100 |     void putOutput()
101 |     {
102 |         if (mShouldLog)
103 |         {
104 |             // prepend a timestamp, written to the same stream as the message
105 |             std::time_t timestamp = std::time(nullptr);
106 |             tm* tm_local = std::localtime(&timestamp);
107 |             mOutput << "[";
108 |             mOutput << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/"; // tm_mon is zero-based
109 |             mOutput << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/";
110 |             mOutput << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-";
111 |             mOutput << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":";
112 |             mOutput << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":";
113 |             mOutput << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] ";
114 |             // std::stringbuf::str() gets the string contents of the buffer
115 |             // insert the buffer contents, prefixed with the appropriate severity tag, into the stream
116 |             mOutput << mPrefix << str();
117 |             // set the buffer to empty
118 |             str("");
119 |             // flush the stream
120 |             mOutput.flush();
121 |         }
122 |     }
123 | 
124 |     void setShouldLog(bool shouldLog)
125 |     {
126 |         mShouldLog = shouldLog;
127 |     }
128 | 
129 | private:
130 |     std::ostream& mOutput;
131 |     std::string mPrefix;
132 |     bool mShouldLog;
133 | };
134 | 
135 | //!
136 | //! \class LogStreamConsumerBase
137 | //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
138 | //!
139 | class LogStreamConsumerBase
140 | {
141 | public:
142 |     LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
143 |         : mBuffer(stream, prefix, shouldLog)
144 |     {
145 |     }
146 | 
147 | protected:
148 |     LogStreamConsumerBuffer mBuffer;
149 | };
150 | 
151 | //!
152 | //! \class LogStreamConsumer
153 | //! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
154 | //! Order of base classes is LogStreamConsumerBase and then std::ostream.
155 | //! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
156 | //! in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
157 | //!
This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. 158 | //! Please do not change the order of the parent classes. 159 | //! 160 | class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream 161 | { 162 | public: 163 | //! \brief Creates a LogStreamConsumer which logs messages with level severity. 164 | //! Reportable severity determines if the messages are severe enough to be logged. 165 | LogStreamConsumer(Severity reportableSeverity, Severity severity) 166 | : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) 167 | , std::ostream(&mBuffer) // links the stream buffer with the stream 168 | , mShouldLog(severity <= reportableSeverity) 169 | , mSeverity(severity) 170 | { 171 | } 172 | 173 | LogStreamConsumer(LogStreamConsumer&& other) 174 | : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) 175 | , std::ostream(&mBuffer) // links the stream buffer with the stream 176 | , mShouldLog(other.mShouldLog) 177 | , mSeverity(other.mSeverity) 178 | { 179 | } 180 | 181 | void setReportableSeverity(Severity reportableSeverity) 182 | { 183 | mShouldLog = mSeverity <= reportableSeverity; 184 | mBuffer.setShouldLog(mShouldLog); 185 | } 186 | 187 | private: 188 | static std::ostream& severityOstream(Severity severity) 189 | { 190 | return severity >= Severity::kINFO ? std::cout : std::cerr; 191 | } 192 | 193 | static std::string severityPrefix(Severity severity) 194 | { 195 | switch (severity) 196 | { 197 | case Severity::kINTERNAL_ERROR: return "[F] "; 198 | case Severity::kERROR: return "[E] "; 199 | case Severity::kWARNING: return "[W] "; 200 | case Severity::kINFO: return "[I] "; 201 | case Severity::kVERBOSE: return "[V] "; 202 | default: assert(0); return ""; 203 | } 204 | } 205 | 206 | bool mShouldLog; 207 | Severity mSeverity; 208 | }; 209 | 210 | //! \class Logger 211 | //! 212 | //! \brief Class which manages logging of TensorRT tools and samples 213 | //! 214 | //! \details This class provides a common interface for TensorRT tools and samples to log information to the console, and 215 | //! supports logging two types of messages: 216 | //! 217 | //! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) 218 | //! - Test pass/fail messages 219 | //! 220 | //! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is that 221 | //! the logic for controlling the verbosity and formatting of sample output is centralized in one location. 222 | //! 223 | //! In the future, this class could be extended to support dumping test results to a file in some standard format 224 | //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). 225 | //! 226 | //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger interface, 227 | //! which is problematic since there isn't a clean separation between messages coming from the TensorRT library and messages coming 228 | //! from the sample. 229 | //! 230 | //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the class 231 | //! to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger object. 
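//!
//! A minimal usage sketch (assumes the gLogger/gLogInfo globals and the
//! setReportableSeverity() helper declared in logger.h):
//!
//!     setReportableSeverity(Logger::Severity::kVERBOSE);                   // applies to every global log stream
//!     auto builder = nvinfer1::createInferBuilder(gLogger.getTRTLogger()); // hand the ILogger to TensorRT
//!     gLogInfo << "builder created" << std::endl;                          // sample-side message, same formatting
//!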
232 | 233 | class Logger : public nvinfer1::ILogger 234 | { 235 | public: 236 | Logger(Severity severity = Severity::kWARNING) 237 | : mReportableSeverity(severity) 238 | { 239 | } 240 | 241 | //! 242 | //! \enum TestResult 243 | //! \brief Represents the state of a given test 244 | //! 245 | enum class TestResult 246 | { 247 | kRUNNING, //!< The test is running 248 | kPASSED, //!< The test passed 249 | kFAILED, //!< The test failed 250 | kWAIVED //!< The test was waived 251 | }; 252 | 253 | //! 254 | //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger 255 | //! \return The nvinfer1::ILogger associated with this Logger 256 | //! 257 | //! TODO Once all samples are updated to use this method to register the logger with TensorRT, 258 | //! we can eliminate the inheritance of Logger from ILogger 259 | //! 260 | nvinfer1::ILogger& getTRTLogger() 261 | { 262 | return *this; 263 | } 264 | 265 | //! 266 | //! \brief Implementation of the nvinfer1::ILogger::log() virtual method 267 | //! 268 | //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the inheritance from 269 | //! nvinfer1::ILogger 270 | //! 271 | void log(Severity severity, const char* msg) override 272 | { 273 | LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; 274 | } 275 | 276 | //! 277 | //! \brief Method for controlling the verbosity of logging output 278 | //! 279 | //! \param severity The logger will only emit messages that have severity of this level or higher. 280 | //! 281 | void setReportableSeverity(Severity severity) 282 | { 283 | mReportableSeverity = severity; 284 | } 285 | 286 | //! 287 | //! \brief Opaque handle that holds logging information for a particular test 288 | //! 289 | //! This object is an opaque handle to information used by the Logger to print test results. 290 | //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used 291 | //! with Logger::reportTest{Start,End}(). 292 | //! 293 | class TestAtom 294 | { 295 | public: 296 | TestAtom(TestAtom&&) = default; 297 | 298 | private: 299 | friend class Logger; 300 | 301 | TestAtom(bool started, const std::string& name, const std::string& cmdline) 302 | : mStarted(started) 303 | , mName(name) 304 | , mCmdline(cmdline) 305 | { 306 | } 307 | 308 | bool mStarted; 309 | std::string mName; 310 | std::string mCmdline; 311 | }; 312 | 313 | //! 314 | //! \brief Define a test for logging 315 | //! 316 | //! \param[in] name The name of the test. This should be a string starting with 317 | //! "TensorRT" and containing dot-separated strings containing 318 | //! the characters [A-Za-z0-9_]. 319 | //! For example, "TensorRT.sample_googlenet" 320 | //! \param[in] cmdline The command line used to reproduce the test 321 | // 322 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 323 | //! 324 | static TestAtom defineTest(const std::string& name, const std::string& cmdline) 325 | { 326 | return TestAtom(false, name, cmdline); 327 | } 328 | 329 | //! 330 | //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments 331 | //! as input 332 | //! 333 | //! \param[in] name The name of the test 334 | //! \param[in] argc The number of command-line arguments 335 | //! \param[in] argv The array of command-line arguments (given as C strings) 336 | //! 337 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 
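//!
//! Typical lifecycle (a sketch; runSample() stands in for a hypothetical sample entry point):
//!
//!     auto atom = Logger::defineTest("TensorRT.my_sample", argc, argv);
//!     Logger::reportTestStart(atom);
//!     bool pass = runSample();
//!     return Logger::reportTest(atom, pass); // emits "&&&& PASSED/FAILED ..." and returns an exit code
//!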
338 | static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) 339 | { 340 | auto cmdline = genCmdlineString(argc, argv); 341 | return defineTest(name, cmdline); 342 | } 343 | 344 | //! 345 | //! \brief Report that a test has started. 346 | //! 347 | //! \pre reportTestStart() has not been called yet for the given testAtom 348 | //! 349 | //! \param[in] testAtom The handle to the test that has started 350 | //! 351 | static void reportTestStart(TestAtom& testAtom) 352 | { 353 | reportTestResult(testAtom, TestResult::kRUNNING); 354 | assert(!testAtom.mStarted); 355 | testAtom.mStarted = true; 356 | } 357 | 358 | //! 359 | //! \brief Report that a test has ended. 360 | //! 361 | //! \pre reportTestStart() has been called for the given testAtom 362 | //! 363 | //! \param[in] testAtom The handle to the test that has ended 364 | //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, 365 | //! TestResult::kFAILED, TestResult::kWAIVED 366 | //! 367 | static void reportTestEnd(const TestAtom& testAtom, TestResult result) 368 | { 369 | assert(result != TestResult::kRUNNING); 370 | assert(testAtom.mStarted); 371 | reportTestResult(testAtom, result); 372 | } 373 | 374 | static int reportPass(const TestAtom& testAtom) 375 | { 376 | reportTestEnd(testAtom, TestResult::kPASSED); 377 | return EXIT_SUCCESS; 378 | } 379 | 380 | static int reportFail(const TestAtom& testAtom) 381 | { 382 | reportTestEnd(testAtom, TestResult::kFAILED); 383 | return EXIT_FAILURE; 384 | } 385 | 386 | static int reportWaive(const TestAtom& testAtom) 387 | { 388 | reportTestEnd(testAtom, TestResult::kWAIVED); 389 | return EXIT_SUCCESS; 390 | } 391 | 392 | static int reportTest(const TestAtom& testAtom, bool pass) 393 | { 394 | return pass ? reportPass(testAtom) : reportFail(testAtom); 395 | } 396 | 397 | Severity getReportableSeverity() const 398 | { 399 | return mReportableSeverity; 400 | } 401 | 402 | private: 403 | //! 404 | //! \brief returns an appropriate string for prefixing a log message with the given severity 405 | //! 406 | static const char* severityPrefix(Severity severity) 407 | { 408 | switch (severity) 409 | { 410 | case Severity::kINTERNAL_ERROR: return "[F] "; 411 | case Severity::kERROR: return "[E] "; 412 | case Severity::kWARNING: return "[W] "; 413 | case Severity::kINFO: return "[I] "; 414 | case Severity::kVERBOSE: return "[V] "; 415 | default: assert(0); return ""; 416 | } 417 | } 418 | 419 | //! 420 | //! \brief returns an appropriate string for prefixing a test result message with the given result 421 | //! 422 | static const char* testResultString(TestResult result) 423 | { 424 | switch (result) 425 | { 426 | case TestResult::kRUNNING: return "RUNNING"; 427 | case TestResult::kPASSED: return "PASSED"; 428 | case TestResult::kFAILED: return "FAILED"; 429 | case TestResult::kWAIVED: return "WAIVED"; 430 | default: assert(0); return ""; 431 | } 432 | } 433 | 434 | //! 435 | //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity 436 | //! 437 | static std::ostream& severityOstream(Severity severity) 438 | { 439 | return severity >= Severity::kINFO ? std::cout : std::cerr; 440 | } 441 | 442 | //! 443 | //! \brief method that implements logging test results 444 | //! 
445 | static void reportTestResult(const TestAtom& testAtom, TestResult result) 446 | { 447 | severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) 448 | << " " << testAtom.mName << " # " << testAtom.mCmdline 449 | << std::endl; 450 | } 451 | 452 | //! 453 | //! \brief generate a command line string from the given (argc, argv) values 454 | //! 455 | static std::string genCmdlineString(int argc, char const* const* argv) 456 | { 457 | std::stringstream ss; 458 | for (int i = 0; i < argc; i++) 459 | { 460 | if (i > 0) 461 | ss << " "; 462 | ss << argv[i]; 463 | } 464 | return ss.str(); 465 | } 466 | 467 | Severity mReportableSeverity; 468 | }; 469 | 470 | namespace 471 | { 472 | 473 | //! 474 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE 475 | //! 476 | //! Example usage: 477 | //! 478 | //! LOG_VERBOSE(logger) << "hello world" << std::endl; 479 | //! 480 | inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) 481 | { 482 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); 483 | } 484 | 485 | //! 486 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO 487 | //! 488 | //! Example usage: 489 | //! 490 | //! LOG_INFO(logger) << "hello world" << std::endl; 491 | //! 492 | inline LogStreamConsumer LOG_INFO(const Logger& logger) 493 | { 494 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); 495 | } 496 | 497 | //! 498 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING 499 | //! 500 | //! Example usage: 501 | //! 502 | //! LOG_WARN(logger) << "hello world" << std::endl; 503 | //! 504 | inline LogStreamConsumer LOG_WARN(const Logger& logger) 505 | { 506 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); 507 | } 508 | 509 | //! 510 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR 511 | //! 512 | //! Example usage: 513 | //! 514 | //! LOG_ERROR(logger) << "hello world" << std::endl; 515 | //! 516 | inline LogStreamConsumer LOG_ERROR(const Logger& logger) 517 | { 518 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); 519 | } 520 | 521 | //! 522 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR 523 | // ("fatal" severity) 524 | //! 525 | //! Example usage: 526 | //! 527 | //! LOG_FATAL(logger) << "hello world" << std::endl; 528 | //! 529 | inline LogStreamConsumer LOG_FATAL(const Logger& logger) 530 | { 531 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); 532 | } 533 | 534 | } // anonymous namespace 535 | 536 | #endif // TENSORRT_LOGGING_H 537 | -------------------------------------------------------------------------------- /common/parserOnnxConfig.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 1993-2019 NVIDIA Corporation. All rights reserved. 3 | * 4 | * NOTICE TO LICENSEE: 5 | * 6 | * This source code and/or documentation ("Licensed Deliverables") are 7 | * subject to NVIDIA intellectual property rights under U.S. and 8 | * international Copyright laws. 
9 | * 10 | * These Licensed Deliverables contained herein is PROPRIETARY and 11 | * CONFIDENTIAL to NVIDIA and is being provided under the terms and 12 | * conditions of a form of NVIDIA software license agreement by and 13 | * between NVIDIA and Licensee ("License Agreement") or electronically 14 | * accepted by Licensee. Notwithstanding any terms or conditions to 15 | * the contrary in the License Agreement, reproduction or disclosure 16 | * of the Licensed Deliverables to any third party without the express 17 | * written consent of NVIDIA is prohibited. 18 | * 19 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 20 | * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE 21 | * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS 22 | * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. 23 | * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED 24 | * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, 25 | * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. 26 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 27 | * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY 28 | * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY 29 | * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 30 | * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS 31 | * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE 32 | * OF THESE LICENSED DELIVERABLES. 33 | * 34 | * U.S. Government End Users. These Licensed Deliverables are a 35 | * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT 36 | * 1995), consisting of "commercial computer software" and "commercial 37 | * computer software documentation" as such terms are used in 48 38 | * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government 39 | * only as a commercial end item. Consistent with 48 C.F.R.12.212 and 40 | * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all 41 | * U.S. Government End Users acquire the Licensed Deliverables with 42 | * only those rights set forth herein. 43 | * 44 | * Any use of the Licensed Deliverables in individual and commercial 45 | * software must include, in the user documentation and internal 46 | * comments to the code, the above Disclaimer and U.S. Government End 47 | * Users Notice. 
48 | */ 49 | 50 | #ifndef PARSER_ONNX_CONFIG_H 51 | #define PARSER_ONNX_CONFIG_H 52 | 53 | #include 54 | #include 55 | #include 56 | 57 | #include "NvInfer.h" 58 | #include "NvOnnxConfig.h" 59 | #include "NvOnnxParser.h" 60 | 61 | #define ONNX_DEBUG 1 62 | 63 | /** 64 | * \class ParserOnnxConfig 65 | * \brief Configuration Manager Class Concrete Implementation 66 | * 67 | * \note: 68 | * 69 | */ 70 | 71 | using namespace std; 72 | 73 | class ParserOnnxConfig : public nvonnxparser::IOnnxConfig 74 | { 75 | 76 | protected: 77 | string mModelFilename{}; 78 | string mTextFilename{}; 79 | string mFullTextFilename{}; 80 | nvinfer1::DataType mModelDtype; 81 | nvonnxparser::IOnnxConfig::Verbosity mVerbosity; 82 | bool mPrintLayercInfo; 83 | 84 | public: 85 | ParserOnnxConfig() 86 | : mModelDtype(nvinfer1::DataType::kFLOAT) 87 | , mVerbosity(static_cast(nvinfer1::ILogger::Severity::kWARNING)) 88 | , mPrintLayercInfo(false) 89 | { 90 | #ifdef ONNX_DEBUG 91 | if (isDebug()) 92 | { 93 | std::cout << " ParserOnnxConfig::ctor(): " 94 | << this << "\t" 95 | << std::endl; 96 | } 97 | #endif 98 | } 99 | 100 | protected: 101 | ~ParserOnnxConfig() 102 | { 103 | #ifdef ONNX_DEBUG 104 | if (isDebug()) 105 | { 106 | std::cout << "ParserOnnxConfig::dtor(): " << this << std::endl; 107 | } 108 | #endif 109 | } 110 | 111 | public: 112 | virtual void setModelDtype(const nvinfer1::DataType modelDtype) { mModelDtype = modelDtype; } 113 | 114 | virtual nvinfer1::DataType getModelDtype() const 115 | { 116 | return mModelDtype; 117 | } 118 | 119 | virtual const char* getModelFileName() const { return mModelFilename.c_str(); } 120 | virtual void setModelFileName(const char* onnxFilename) 121 | { 122 | mModelFilename = string(onnxFilename); 123 | } 124 | virtual nvonnxparser::IOnnxConfig::Verbosity getVerbosityLevel() const { return mVerbosity; } 125 | virtual void addVerbosity() { ++mVerbosity; } 126 | virtual void reduceVerbosity() { --mVerbosity; } 127 | virtual void setVerbosityLevel(nvonnxparser::IOnnxConfig::Verbosity verbosity) { mVerbosity = verbosity; } 128 | 129 | virtual const char* getTextFileName() const { return mTextFilename.c_str(); } 130 | virtual void setTextFileName(const char* textFilename) 131 | { 132 | mTextFilename = string(textFilename); 133 | } 134 | virtual const char* getFullTextFileName() const { return mFullTextFilename.c_str(); } 135 | virtual void setFullTextFileName(const char* fullTextFilename) 136 | { 137 | mFullTextFilename = string(fullTextFilename); 138 | } 139 | virtual bool getPrintLayerInfo() const { return mPrintLayercInfo; } 140 | virtual void setPrintLayerInfo(bool src) { mPrintLayercInfo = src; } //!< get the boolean variable corresponding to the Layer Info, see getPrintLayerInfo() 141 | 142 | virtual bool isDebug() const 143 | { 144 | #if ONNX_DEBUG 145 | return (std::getenv("ONNX_DEBUG") ? true : false); 146 | #else 147 | return false; 148 | #endif 149 | } 150 | 151 | virtual void destroy() { delete this; } 152 | 153 | }; // class ParserOnnxConfig 154 | 155 | #endif 156 | -------------------------------------------------------------------------------- /common/sampleConfig.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright 1993-2019 NVIDIA Corporation. All rights reserved. 3 | * 4 | * NOTICE TO LICENSEE: 5 | * 6 | * This source code and/or documentation ("Licensed Deliverables") are 7 | * subject to NVIDIA intellectual property rights under U.S. and 8 | * international Copyright laws. 
9 | * 10 | * These Licensed Deliverables contained herein is PROPRIETARY and 11 | * CONFIDENTIAL to NVIDIA and is being provided under the terms and 12 | * conditions of a form of NVIDIA software license agreement by and 13 | * between NVIDIA and Licensee ("License Agreement") or electronically 14 | * accepted by Licensee. Notwithstanding any terms or conditions to 15 | * the contrary in the License Agreement, reproduction or disclosure 16 | * of the Licensed Deliverables to any third party without the express 17 | * written consent of NVIDIA is prohibited. 18 | * 19 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 20 | * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE 21 | * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS 22 | * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. 23 | * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED 24 | * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, 25 | * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. 26 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 27 | * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY 28 | * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY 29 | * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 30 | * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS 31 | * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE 32 | * OF THESE LICENSED DELIVERABLES. 33 | * 34 | * U.S. Government End Users. These Licensed Deliverables are a 35 | * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT 36 | * 1995), consisting of "commercial computer software" and "commercial 37 | * computer software documentation" as such terms are used in 48 38 | * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government 39 | * only as a commercial end item. Consistent with 48 C.F.R.12.212 and 40 | * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all 41 | * U.S. Government End Users acquire the Licensed Deliverables with 42 | * only those rights set forth herein. 43 | * 44 | * Any use of the Licensed Deliverables in individual and commercial 45 | * software must include, in the user documentation and internal 46 | * comments to the code, the above Disclaimer and U.S. Government End 47 | * Users Notice. 
48 | */ 49 | 50 | #ifndef SampleConfig_H 51 | #define SampleConfig_H 52 | 53 | #include 54 | #include 55 | #include 56 | 57 | #include "NvInfer.h" 58 | #include "NvOnnxConfig.h" 59 | class SampleConfig : public nvonnxparser::IOnnxConfig 60 | { 61 | public: 62 | enum class InputDataFormat : int 63 | { 64 | kASCII = 0, 65 | kPPM = 1 66 | }; 67 | 68 | private: 69 | std::string mModelFilename; 70 | std::string mEngineFilename; 71 | std::string mTextFilename; 72 | std::string mFullTextFilename; 73 | std::string mImageFilename; 74 | std::string mReferenceFilename; 75 | std::string mOutputFilename; 76 | std::string mCalibrationFilename; 77 | int64_t mMaxBatchSize{32}; 78 | int64_t mMaxWorkspaceSize{1 * 1024 * 1024 * 1024}; 79 | int64_t mCalibBatchSize{0}; 80 | int64_t mMaxNCalibBatch{0}; 81 | int64_t mFirstCalibBatch{0}; 82 | int64_t mUseDLACore{-1}; 83 | nvinfer1::DataType mModelDtype{nvinfer1::DataType::kFLOAT}; 84 | Verbosity mVerbosity{static_cast(nvinfer1::ILogger::Severity::kWARNING)}; 85 | bool mPrintLayercInfo{false}; 86 | bool mDebugBuilder{false}; 87 | InputDataFormat mInputDataFormat{InputDataFormat::kASCII}; 88 | uint64_t mTopK{0}; 89 | float mFailurePercentage{-1.0f}; 90 | 91 | public: 92 | SampleConfig() 93 | { 94 | #ifdef ONNX_DEBUG 95 | if (isDebug()) 96 | { 97 | std::cout << " SampleConfig::ctor(): " 98 | << this << "\t" 99 | << std::endl; 100 | } 101 | #endif 102 | } 103 | 104 | protected: 105 | ~SampleConfig() 106 | { 107 | #ifdef ONNX_DEBUG 108 | if (isDebug()) 109 | { 110 | std::cout << "SampleConfig::dtor(): " << this << std::endl; 111 | } 112 | #endif 113 | } 114 | 115 | public: 116 | void setModelDtype(const nvinfer1::DataType mdt) { mModelDtype = mdt; } 117 | 118 | nvinfer1::DataType getModelDtype() const 119 | { 120 | return mModelDtype; 121 | } 122 | 123 | const char* getModelFileName() const { return mModelFilename.c_str(); } 124 | 125 | void setModelFileName(const char* onnxFilename) 126 | { 127 | mModelFilename = string(onnxFilename); 128 | } 129 | Verbosity getVerbosityLevel() const { return mVerbosity; } 130 | void addVerbosity() { ++mVerbosity; } 131 | void reduceVerbosity() { --mVerbosity; } 132 | virtual void setVerbosityLevel(Verbosity v) { mVerbosity = v; } 133 | const char* getEngineFileName() const { return mEngineFilename.c_str(); } 134 | void setEngineFileName(const char* engineFilename) 135 | { 136 | mEngineFilename = string(engineFilename); 137 | } 138 | const char* getTextFileName() const { return mTextFilename.c_str(); } 139 | void setTextFileName(const char* textFilename) 140 | { 141 | mTextFilename = string(textFilename); 142 | } 143 | const char* getFullTextFileName() const { return mFullTextFilename.c_str(); } 144 | void setFullTextFileName(const char* fullTextFilename) 145 | { 146 | mFullTextFilename = string(fullTextFilename); 147 | } 148 | bool getPrintLayerInfo() const { return mPrintLayercInfo; } 149 | void setPrintLayerInfo(bool b) { mPrintLayercInfo = b; } //!< get the boolean variable corresponding to the Layer Info, see getPrintLayerInfo() 150 | 151 | void setMaxBatchSize(int64_t maxBatchSize) { mMaxBatchSize = maxBatchSize; } //!< set the Max Batch Size 152 | int64_t getMaxBatchSize() const { return mMaxBatchSize; } //!< get the Max Batch Size 153 | 154 | void setMaxWorkSpaceSize(int64_t maxWorkSpaceSize) { mMaxWorkspaceSize = maxWorkSpaceSize; } //!< set the Max Work Space size 155 | int64_t getMaxWorkSpaceSize() const { return mMaxWorkspaceSize; } //!< get the Max Work Space size 156 | 157 | void setCalibBatchSize(int64_t 
CalibBatchSize) { mCalibBatchSize = CalibBatchSize; } //!< set the calibration batch size
158 |     int64_t getCalibBatchSize() const { return mCalibBatchSize; } //!< get the calibration batch size
159 | 
160 |     void setMaxNCalibBatch(int64_t MaxNCalibBatch) { mMaxNCalibBatch = MaxNCalibBatch; } //!< set the max number of calibration batches
161 |     int64_t getMaxNCalibBatch() const { return mMaxNCalibBatch; } //!< get the max number of calibration batches
162 | 
163 |     void setFirstCalibBatch(int64_t FirstCalibBatch) { mFirstCalibBatch = FirstCalibBatch; } //!< set the first calibration batch
164 |     int64_t getFirstCalibBatch() const { return mFirstCalibBatch; } //!< get the first calibration batch
165 | 
166 |     void setUseDLACore(int64_t UseDLACore) { mUseDLACore = UseDLACore; } //!< set the DLA core to use
167 |     int64_t getUseDLACore() const { return mUseDLACore; } //!< get the DLA core to use
168 | 
169 |     void setDebugBuilder() { mDebugBuilder = true; } //!< enable debug info while building the engine
170 |     bool getDebugBuilder() const { return mDebugBuilder; } //!< get the debug-builder flag
171 | 
172 |     const char* getImageFileName() const { return mImageFilename.c_str(); } //!< get the image file name (PPM or ASCII)
173 |     void setImageFileName(const char* imageFilename) //!< set the image file name
174 |     {
175 |         mImageFilename = string(imageFilename);
176 |     }
177 |     const char* getReferenceFileName() const { return mReferenceFilename.c_str(); } //!< get the reference file name
178 |     void setReferenceFileName(const char* referenceFilename) //!< set the reference file name
179 |     {
180 |         mReferenceFilename = string(referenceFilename);
181 |     }
182 | 
183 |     void setInputDataFormat(InputDataFormat idt) { mInputDataFormat = idt; } //!< set the expected data format of the image file (PPM or ASCII)
184 |     InputDataFormat getInputDataFormat() const { return mInputDataFormat; } //!< get the expected data format of the image file
185 | 
186 |     const char* getOutputFileName() const { return mOutputFilename.c_str(); } //!< get the file in which to save the results
187 |     void setOutputFileName(const char* outputFilename) //!< set the output file name
188 |     {
189 |         mOutputFilename = string(outputFilename);
190 |     }
191 | 
192 |     const char* getCalibrationFileName() const { return mCalibrationFilename.c_str(); } //!< get the file containing the list of image files for INT8 calibration
193 |     void setCalibrationFileName(const char* calibrationFilename) //!< set the INT8 calibration image-list file name
194 |     {
195 |         mCalibrationFilename = string(calibrationFilename);
196 |     }
197 | 
198 |     uint64_t getTopK() const { return mTopK; }
199 |     void setTopK(uint64_t topK) { mTopK = topK; } //!< If this option is specified, return the K top probabilities.
200 | 
201 |     float getFailurePercentage() const { return mFailurePercentage; }
202 |     void setFailurePercentage(float f) { mFailurePercentage = f; }
203 | 
204 |     bool isDebug() const
205 |     {
206 | #if ONNX_DEBUG
207 |         return (std::getenv("ONNX_DEBUG") ? true : false);
208 | #else
209 |         return false;
210 | #endif
211 |     }
212 | 
213 |     void destroy() { delete this; }
214 | 
215 | }; // class SampleConfig
216 | 
217 | #endif
218 | 
--------------------------------------------------------------------------------
/common/sampleEngines.cpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
3 | * 4 | * NOTICE TO LICENSEE: 5 | * 6 | * This source code and/or documentation ("Licensed Deliverables") are 7 | * subject to NVIDIA intellectual property rights under U.S. and 8 | * international Copyright laws. 9 | * 10 | * These Licensed Deliverables contained herein is PROPRIETARY and 11 | * CONFIDENTIAL to NVIDIA and is being provided under the terms and 12 | * conditions of a form of NVIDIA software license agreement by and 13 | * between NVIDIA and Licensee ("License Agreement") or electronically 14 | * accepted by Licensee. Notwithstanding any terms or conditions to 15 | * the contrary in the License Agreement, reproduction or disclosure 16 | * of the Licensed Deliverables to any third party without the express 17 | * written consent of NVIDIA is prohibited. 18 | * 19 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 20 | * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE 21 | * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS 22 | * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. 23 | * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED 24 | * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, 25 | * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. 26 | * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE 27 | * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY 28 | * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY 29 | * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 30 | * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS 31 | * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE 32 | * OF THESE LICENSED DELIVERABLES. 33 | * 34 | * U.S. Government End Users. These Licensed Deliverables are a 35 | * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT 36 | * 1995), consisting of "commercial computer software" and "commercial 37 | * computer software documentation" as such terms are used in 48 38 | * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government 39 | * only as a commercial end item. Consistent with 48 C.F.R.12.212 and 40 | * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all 41 | * U.S. Government End Users acquire the Licensed Deliverables with 42 | * only those rights set forth herein. 43 | * 44 | * Any use of the Licensed Deliverables in individual and commercial 45 | * software must include, in the user documentation and internal 46 | * comments to the code, the above Disclaimer and U.S. Government End 47 | * Users Notice. 
48 | */
49 | 
50 | #include <algorithm>
51 | #include <fstream>
52 | #include <iostream>
53 | #include <iterator>
54 | #include <map>
55 | #include <random>
56 | #include <string>
57 | #include <vector>
58 | 
59 | #include "NvInfer.h"
60 | #include "NvCaffeParser.h"
61 | #include "NvOnnxParser.h"
62 | #include "NvUffParser.h"
63 | 
64 | #include "logger.h"
65 | #include "sampleUtils.h"
66 | #include "sampleOptions.h"
67 | #include "sampleEngines.h"
68 | 
69 | using namespace nvinfer1;
70 | 
71 | namespace sample
72 | {
73 | 
74 | namespace
75 | {
76 | 
77 | struct CaffeBufferShutter
78 | {
79 |     ~CaffeBufferShutter() { nvcaffeparser1::shutdownProtobufLibrary(); }
80 | };
81 | 
82 | struct UffBufferShutter
83 | {
84 |     ~UffBufferShutter() { nvuffparser::shutdownProtobufLibrary(); }
85 | };
86 | 
87 | }
88 | 
89 | Parser modelToNetwork(const ModelOptions& model, nvinfer1::INetworkDefinition& network, std::ostream& err)
90 | {
91 |     Parser parser;
92 |     const std::string& modelName = model.baseModel.model;
93 |     switch (model.baseModel.format)
94 |     {
95 |     case ModelFormat::kCAFFE:
96 |     {
97 |         using namespace nvcaffeparser1;
98 |         parser.caffeParser.reset(createCaffeParser());
99 |         CaffeBufferShutter bufferShutter;
100 |         const auto blobNameToTensor = parser.caffeParser->parse(model.prototxt.c_str(), modelName.empty() ? nullptr : modelName.c_str(), network, DataType::kFLOAT);
101 |         if (!blobNameToTensor)
102 |         {
103 |             err << "Failed to parse caffe model or prototxt, tensors blob not found" << std::endl;
104 |             parser.caffeParser.reset();
105 |             break;
106 |         }
107 | 
108 |         for (const auto& s : model.outputs)
109 |         {
110 |             if (blobNameToTensor->find(s.c_str()) == nullptr)
111 |             {
112 |                 err << "Could not find output blob " << s << std::endl;
113 |                 parser.caffeParser.reset();
114 |                 break;
115 |             }
116 |             network.markOutput(*blobNameToTensor->find(s.c_str()));
117 |         }
118 |         break;
119 |     }
120 |     case ModelFormat::kUFF:
121 |     {
122 |         using namespace nvuffparser;
123 |         parser.uffParser.reset(createUffParser());
124 |         UffBufferShutter bufferShutter;
125 |         for (const auto& s : model.uffInputs.inputs)
126 |         {
127 |             if (!parser.uffParser->registerInput(s.first.c_str(), s.second, model.uffInputs.NHWC ? UffInputOrder::kNHWC : UffInputOrder::kNCHW))
128 |             {
129 |                 err << "Failed to register input " << s.first << std::endl;
130 |                 parser.uffParser.reset();
131 |                 break;
132 |             }
133 |         }
134 | 
135 |         for (const auto& s : model.outputs)
136 |         {
137 |             if (!parser.uffParser->registerOutput(s.c_str()))
138 |             {
139 |                 err << "Failed to register output " << s << std::endl;
140 |                 parser.uffParser.reset();
141 |                 break;
142 |             }
143 |         }
144 | 
145 |         if (!parser.uffParser->parse(model.baseModel.model.c_str(), network))
146 |         {
147 |             err << "Failed to parse uff file" << std::endl;
148 |             parser.uffParser.reset();
149 |             break;
150 |         }
151 |         break;
152 |     }
153 |     case ModelFormat::kONNX:
154 |     {
155 |         using namespace nvonnxparser;
156 |         parser.onnxParser.reset(createParser(network, gLogger.getTRTLogger()));
157 |         if (!parser.onnxParser->parseFromFile(model.baseModel.model.c_str(), static_cast<int>(gLogger.getReportableSeverity())))
158 |         {
159 |             err << "Failed to parse onnx file" << std::endl;
160 |             parser.onnxParser.reset();
161 |         }
162 |         break;
163 |     }
164 |     case ModelFormat::kANY:
165 |         break;
166 |     }
167 | 
168 |     return parser;
169 | 
170 | }
171 | 
172 | namespace
173 | {
174 | 
175 | class RndInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2
176 | {
177 | public:
178 |     RndInt8Calibrator(int batches, const std::string& cacheFile, const nvinfer1::INetworkDefinition& network, std::ostream& err);
179 | 
180 |     ~RndInt8Calibrator()
181 |     {
182 |         for (auto& elem : mInputDeviceBuffers)
183 |         {
184 |             cudaCheck(cudaFree(elem.second), mErr);
185 |         }
186 |     }
187 | 
188 |     bool getBatch(void* bindings[], const char* names[], int nbBindings) override;
189 | 
190 |     int getBatchSize() const override { return 1; }
191 | 
192 |     const void* readCalibrationCache(size_t& length) override;
193 | 
194 |     virtual void writeCalibrationCache(const void*, size_t) override {}
195 | 
196 | private:
197 |     int mBatches{};
198 |     int mCurrentBatch{};
199 |     std::string mCacheFile;
200 |     std::map<std::string, void*> mInputDeviceBuffers;
201 |     std::vector<char> mCalibrationCache;
202 |     std::ostream& mErr;
203 | };
204 | 
205 | RndInt8Calibrator::RndInt8Calibrator(int batches, const std::string& cacheFile, const INetworkDefinition& network, std::ostream& err)
206 |     : mBatches(batches), mCurrentBatch(0), mCacheFile(cacheFile), mErr(err)
207 | {
208 |     std::default_random_engine generator;
209 |     std::uniform_real_distribution<float> distribution(-1.0F, 1.0F);
210 |     auto gen = [&generator, &distribution]() { return distribution(generator); };
211 | 
212 |     for (int i = 0; i < network.getNbInputs(); i++)
213 |     {
214 |         auto input = network.getInput(i);
215 |         int elemCount = volume(input->getDimensions());
216 |         std::vector<float> rnd_data(elemCount);
217 |         std::generate_n(rnd_data.begin(), elemCount, gen);
218 | 
219 |         void* data;
220 |         cudaCheck(cudaMalloc(&data, elemCount * sizeof(float)), mErr);
221 |         cudaCheck(cudaMemcpy(data, rnd_data.data(), elemCount * sizeof(float), cudaMemcpyHostToDevice), mErr);
222 | 
223 |         mInputDeviceBuffers.insert(std::make_pair(input->getName(), data));
224 |     }
225 | }
226 | 
227 | bool RndInt8Calibrator::getBatch(void* bindings[], const char* names[], int nbBindings)
228 | {
229 |     if (mCurrentBatch >= mBatches)
230 |     {
231 |         return false;
232 |     }
233 | 
234 |     for (int i = 0; i < nbBindings; ++i)
235 |     {
236 |         bindings[i] = mInputDeviceBuffers[names[i]];
237 |     }
238 | 
239 |     ++mCurrentBatch;
240 | 
241 |     return true;
242 | }
243 | 
244 | const void* RndInt8Calibrator::readCalibrationCache(size_t& length)
245 | {
246 |     mCalibrationCache.clear();
247 |     std::ifstream
input(mCacheFile, std::ios::binary);
248 |     input >> std::noskipws;
249 |     if (input.good())
250 |     {
251 |         std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
252 |             std::back_inserter(mCalibrationCache));
253 |     }
254 |     length = mCalibrationCache.size(); // report the cache size through the out-parameter
255 |     return mCalibrationCache.size() ? mCalibrationCache.data() : nullptr;
256 | }
257 | 
258 | void setTensorScales(const INetworkDefinition& network, float inScales = 2.0f, float outScales = 4.0f)
259 | {
260 |     // Ensure that all layer inputs have a scale.
261 |     for (int l = 0; l < network.getNbLayers(); l++)
262 |     {
263 |         auto layer = network.getLayer(l);
264 |         for (int i = 0; i < layer->getNbInputs(); i++)
265 |         {
266 |             ITensor* input{layer->getInput(i)};
267 |             // Optional inputs are nullptr here and are from RNN layers.
268 |             if (input && !input->dynamicRangeIsSet())
269 |             {
270 |                 input->setDynamicRange(-inScales, inScales);
271 |             }
272 |         }
273 |         for (int o = 0; o < layer->getNbOutputs(); o++)
274 |         {
275 |             ITensor* output{layer->getOutput(o)};
276 |             // Optional outputs are nullptr here and are from RNN layers.
277 |             if (output && !output->dynamicRangeIsSet())
278 |             {
279 |                 // Pooling must have the same input and output scales.
280 |                 if (layer->getType() == LayerType::kPOOLING)
281 |                 {
282 |                     output->setDynamicRange(-inScales, inScales);
283 |                 }
284 |                 else
285 |                 {
286 |                     output->setDynamicRange(-outScales, outScales);
287 |                 }
288 |             }
289 |         }
290 |     }
291 | }
292 | 
293 | }
294 | 
295 | ICudaEngine* networkToEngine(const BuildOptions& build, const SystemOptions& sys, IBuilder& builder, INetworkDefinition& network, std::ostream& err)
296 | {
297 |     unique_ptr<IBuilderConfig> config{builder.createBuilderConfig()};
298 | 
299 |     IOptimizationProfile* profile{nullptr};
300 |     if (build.maxBatch)
301 |     {
302 |         builder.setMaxBatchSize(build.maxBatch);
303 |     }
304 |     else
305 |     {
306 |         if (!build.shapes.empty())
307 |         {
308 |             profile = builder.createOptimizationProfile();
309 |         }
310 |     }
311 | 
312 |     for (unsigned int i = 0, n = network.getNbInputs(); i < n; i++)
313 |     {
314 |         // Set formats and data types of inputs
315 |         auto input = network.getInput(i);
316 |         if (!build.inputFormats.empty())
317 |         {
318 |             input->setType(build.inputFormats[i].first);
319 |             input->setAllowedFormats(build.inputFormats[i].second);
320 |         }
321 |         else
322 |         {
323 |             input->setType(DataType::kFLOAT);
324 |             input->setAllowedFormats(1U << static_cast<int>(TensorFormat::kLINEAR));
325 |         }
326 | 
327 |         if (profile)
328 |         {
329 |             Dims dims = input->getDimensions();
330 |             if (std::any_of(dims.d + 1, dims.d + dims.nbDims, [](int dim){ return dim == -1; }))
331 |             {
332 |                 err << "Only dynamic batch dimension is currently supported, other dimensions must be static" << std::endl;
333 |                 return nullptr;
334 |             }
335 |             dims.d[0] = -1;
336 |             Dims profileDims = dims;
337 |             auto shape = build.shapes.find(input->getName());
338 |             if (shape == build.shapes.end())
339 |             {
340 |                 err << "Dynamic dimensions required for input " << input->getName() << std::endl;
341 |                 return nullptr;
342 |             }
343 |             profileDims.d[0] = shape->second[static_cast<size_t>(OptProfileSelector::kMIN)].d[0];
344 |             profile->setDimensions(input->getName(), OptProfileSelector::kMIN, profileDims);
345 |             profileDims.d[0] = shape->second[static_cast<size_t>(OptProfileSelector::kOPT)].d[0];
346 |             profile->setDimensions(input->getName(), OptProfileSelector::kOPT, profileDims);
347 |             profileDims.d[0] = shape->second[static_cast<size_t>(OptProfileSelector::kMAX)].d[0];
348 |             profile->setDimensions(input->getName(), OptProfileSelector::kMAX, profileDims);
349 | 
350 |             input->setDimensions(dims);
351 |         }
352 |     }
353 | 
354 |     if (profile)
355 |     {
356 |         if (!profile->isValid())
357 |         {
358 |             err << "Required optimization profile is invalid" << std::endl;
359 |             return nullptr;
360 |         }
361 |         config->addOptimizationProfile(profile);
362 |     }
363 | 
364 |     for (unsigned int i = 0, n = network.getNbOutputs(); i < n; i++)
365 |     {
366 |         // Set formats and data types of outputs
367 |         auto output = network.getOutput(i);
368 |         if (!build.outputFormats.empty())
369 |         {
370 |             output->setType(build.outputFormats[i].first);
371 |             output->setAllowedFormats(build.outputFormats[i].second);
372 |         }
373 |         else
374 |         {
375 |             output->setType(DataType::kFLOAT);
376 |             output->setAllowedFormats(1U << static_cast<int>(TensorFormat::kLINEAR));
377 |         }
378 |     }
379 | 
380 |     config->setMaxWorkspaceSize(static_cast<size_t>(build.workspace) << 20);
381 | 
382 |     if (build.fp16)
383 |     {
384 |         config->setFlag(BuilderFlag::kFP16);
385 |     }
386 | 
387 |     if (build.int8)
388 |     {
389 |         config->setFlag(BuilderFlag::kINT8);
390 |     }
391 | 
392 |     auto isInt8 = [](const IOFormat& format){ return format.first == DataType::kINT8; };
393 |     auto int8IO = std::count_if(build.inputFormats.begin(), build.inputFormats.end(), isInt8) +
394 |         std::count_if(build.outputFormats.begin(), build.outputFormats.end(), isInt8);
395 | 
396 |     if ((build.int8 && build.calibration.empty()) || int8IO)
397 |     {
398 |         // Explicitly set int8 scales if no calibrator is provided and if I/O tensors use int8,
399 |         // because auto calibration does not support this case.
400 |         setTensorScales(network);
401 |     }
402 |     else if (build.int8)
403 |     {
404 |         config->setInt8Calibrator(new RndInt8Calibrator(1, build.calibration, network, err));
405 |     }
406 | 
407 |     if (build.safe)
408 |     {
409 |         config->setEngineCapability(sys.DLACore != -1 ? EngineCapability::kSAFE_DLA : EngineCapability::kSAFE_GPU);
410 |     }
411 | 
412 |     if (sys.DLACore != -1)
413 |     {
414 |         if (sys.DLACore < builder.getNbDLACores())
415 |         {
416 |             config->setDefaultDeviceType(DeviceType::kDLA);
417 |             config->setDLACore(sys.DLACore);
418 |             config->setFlag(BuilderFlag::kSTRICT_TYPES);
419 | 
420 |             if (sys.fallback)
421 |             {
422 |                 config->setFlag(BuilderFlag::kGPU_FALLBACK);
423 |             }
424 |             if (!build.int8)
425 |             {
426 |                 config->setFlag(BuilderFlag::kFP16);
427 |             }
428 |         }
429 |         else
430 |         {
431 |             err << "Cannot create DLA engine, " << sys.DLACore << " not available" << std::endl;
432 |             return nullptr;
433 |         }
434 |     }
435 | 
436 |     return builder.buildEngineWithConfig(network, *config);
437 | }
438 | 
439 | ICudaEngine* modelToEngine(const ModelOptions& model, const BuildOptions& build, const SystemOptions& sys, std::ostream& err)
440 | {
441 |     unique_ptr<IBuilder> builder{createInferBuilder(gLogger.getTRTLogger())};
442 |     if (builder == nullptr)
443 |     {
444 |         err << "Builder creation failed" << std::endl;
445 |         return nullptr;
446 |     }
447 |     auto batchFlag = (build.maxBatch ?
0U : 1U) << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
448 |     unique_ptr<INetworkDefinition> network{builder->createNetworkV2(batchFlag)};
449 |     if (!network)
450 |     {
451 |         err << "Network creation failed" << std::endl;
452 |         return nullptr;
453 |     }
454 |     Parser parser = modelToNetwork(model, *network, err);
455 |     if (!parser)
456 |     {
457 |         err << "Parsing model failed" << std::endl;
458 |         return nullptr;
459 |     }
460 | 
461 |     return networkToEngine(build, sys, *builder, *network, err);
462 | }
463 | 
464 | ICudaEngine* loadEngine(const std::string& engine, int DLACore, std::ostream& err)
465 | {
466 |     std::ifstream engineFile(engine, std::ios::binary);
467 |     if (!engineFile)
468 |     {
469 |         err << "Error opening engine file: " << engine << std::endl;
470 |         return nullptr;
471 |     }
472 | 
473 |     engineFile.seekg(0, engineFile.end);
474 |     long int fsize = engineFile.tellg();
475 |     engineFile.seekg(0, engineFile.beg);
476 | 
477 |     std::vector<char> engineData(fsize);
478 |     engineFile.read(engineData.data(), fsize);
479 |     if (!engineFile)
480 |     {
481 |         err << "Error loading engine file: " << engine << std::endl;
482 |         return nullptr;
483 |     }
484 | 
485 |     unique_ptr<IRuntime> runtime{createInferRuntime(gLogger.getTRTLogger())};
486 |     if (DLACore != -1)
487 |     {
488 |         runtime->setDLACore(DLACore);
489 |     }
490 | 
491 |     return runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr);
492 | }
493 | 
494 | bool saveEngine(const ICudaEngine& engine, const std::string& fileName, std::ostream& err)
495 | {
496 |     std::ofstream engineFile(fileName, std::ios::binary);
497 |     if (!engineFile)
498 |     {
499 |         err << "Cannot open engine file: " << fileName << std::endl;
500 |         return false;
501 |     }
502 | 
503 |     unique_ptr<IHostMemory> serializedEngine{engine.serialize()};
504 |     if (serializedEngine == nullptr)
505 |     {
506 |         err << "Engine serialization failed" << std::endl;
507 |         return false;
508 |     }
509 | 
510 |     engineFile.write(static_cast<char*>(serializedEngine->data()), serializedEngine->size());
511 |     return !engineFile.fail();
512 | }
513 | 
514 | } // namespace sample
515 | 
--------------------------------------------------------------------------------
/common/sampleEngines.h:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
3 |  *
4 |  * NOTICE TO LICENSEE:
5 |  *
6 |  * This source code and/or documentation ("Licensed Deliverables") are
7 |  * subject to NVIDIA intellectual property rights under U.S. and
8 |  * international Copyright laws.
9 |  *
10 |  * These Licensed Deliverables contained herein is PROPRIETARY and
11 |  * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12 |  * conditions of a form of NVIDIA software license agreement by and
13 |  * between NVIDIA and Licensee ("License Agreement") or electronically
14 |  * accepted by Licensee. Notwithstanding any terms or conditions to
15 |  * the contrary in the License Agreement, reproduction or disclosure
16 |  * of the Licensed Deliverables to any third party without the express
17 |  * written consent of NVIDIA is prohibited.
18 |  *
19 |  * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20 |  * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21 |  * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22 |  * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
/common/sampleEngines.h:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
3 |  *
4 |  * NOTICE TO LICENSEE:
5 |  *
6 |  * This source code and/or documentation ("Licensed Deliverables") are
7 |  * subject to NVIDIA intellectual property rights under U.S. and
8 |  * international Copyright laws.
9 |  *
10 |  * These Licensed Deliverables contained herein is PROPRIETARY and
11 |  * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12 |  * conditions of a form of NVIDIA software license agreement by and
13 |  * between NVIDIA and Licensee ("License Agreement") or electronically
14 |  * accepted by Licensee. Notwithstanding any terms or conditions to
15 |  * the contrary in the License Agreement, reproduction or disclosure
16 |  * of the Licensed Deliverables to any third party without the express
17 |  * written consent of NVIDIA is prohibited.
18 |  *
19 |  * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20 |  * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21 |  * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22 |  * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23 |  * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24 |  * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25 |  * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26 |  * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27 |  * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28 |  * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29 |  * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30 |  * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31 |  * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32 |  * OF THESE LICENSED DELIVERABLES.
33 |  *
34 |  * U.S. Government End Users. These Licensed Deliverables are a
35 |  * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36 |  * 1995), consisting of "commercial computer software" and "commercial
37 |  * computer software documentation" as such terms are used in 48
38 |  * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39 |  * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40 |  * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41 |  * U.S. Government End Users acquire the Licensed Deliverables with
42 |  * only those rights set forth herein.
43 |  *
44 |  * Any use of the Licensed Deliverables in individual and commercial
45 |  * software must include, in the user documentation and internal
46 |  * comments to the code, the above Disclaimer and U.S. Government End
47 |  * Users Notice.
48 |  */
49 |
50 | #ifndef TRT_SAMPLE_ENGINES_H
51 | #define TRT_SAMPLE_ENGINES_H
52 |
53 | #include <iostream>
54 |
55 | #include "NvInfer.h"
56 | #include "NvCaffeParser.h"
57 | #include "NvOnnxParser.h"
58 | #include "NvUffParser.h"
59 |
60 | #include "sampleUtils.h"
61 |
62 | namespace sample
63 | {
64 |
65 | struct Parser
66 | {
67 |     unique_ptr<nvcaffeparser1::ICaffeParser> caffeParser;
68 |     unique_ptr<nvuffparser::IUffParser> uffParser;
69 |     unique_ptr<nvonnxparser::IParser> onnxParser;
70 |
71 |     operator bool() const { return caffeParser || uffParser || onnxParser; }
72 | };
73 |
74 | //!
75 | //! \brief Generate a network definition for a given model
76 | //!
77 | //! \return Parser The parser used to initialize the network and that holds the weights for the network, or an invalid parser (the returned parser converts to false if tested)
78 | //!
79 | //! \see Parser::operator bool()
80 | //!
81 | Parser modelToNetwork(const ModelOptions& model, nvinfer1::INetworkDefinition& network, std::ostream& err);
82 |
83 | //!
84 | //! \brief Create an engine for a network definition
85 | //!
86 | //! \return Pointer to the engine created or nullptr if the creation failed
87 | //!
88 | nvinfer1::ICudaEngine* networkToEngine(const BuildOptions& build, const SystemOptions& sys, nvinfer1::IBuilder& builder, nvinfer1::INetworkDefinition& network, std::ostream& err);
89 |
90 | //!
91 | //! \brief Create an engine for a given model
92 | //!
93 | //! \return Pointer to the engine created or nullptr if the creation failed
94 | //!
95 | nvinfer1::ICudaEngine* modelToEngine(const ModelOptions& model, const BuildOptions& build, const SystemOptions& sys, std::ostream& err);
96 |
97 | //!
98 | //! \brief Load a serialized engine
99 | //!
100 | //! \return Pointer to the engine loaded or nullptr if the operation failed
101 | //!
102 | nvinfer1::ICudaEngine* loadEngine(const std::string& engine, int DLACore, std::ostream& err);
103 |
104 | //!
105 | //! \brief Save an engine into a file
106 | //!
107 | //! \return true if the engine was saved successfully
108 | //!
109 | bool saveEngine(const nvinfer1::ICudaEngine& engine, const std::string& fileName, std::ostream& err);
110 |
111 | } // namespace sample
112 |
113 | #endif // TRT_SAMPLE_ENGINES_H
114 |
--------------------------------------------------------------------------------
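One point from the declarations above deserves emphasis: the Parser returned by modelToNetwork holds the weights that the network definition still references, so it must stay alive until the engine has been built. A hedged sketch of the intended call pattern, mirroring modelToEngine in sampleEngines.cpp:

// Sketch: keep the Parser in scope until networkToEngine has finished.
nvinfer1::ICudaEngine* parseAndBuild(const sample::ModelOptions& model,
                                     const sample::BuildOptions& build,
                                     const sample::SystemOptions& sys,
                                     nvinfer1::IBuilder& builder,
                                     nvinfer1::INetworkDefinition& network)
{
    sample::Parser parser = sample::modelToNetwork(model, network, std::cerr);
    if (!parser) // Parser::operator bool(): no parser could be created
    {
        return nullptr;
    }
    // The weights remain valid here because `parser` is still in scope.
    return sample::networkToEngine(build, sys, builder, network, std::cerr);
}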
/common/sampleOptions.h:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
3 |  *
4 |  * NOTICE TO LICENSEE:
5 |  *
6 |  * This source code and/or documentation ("Licensed Deliverables") are
7 |  * subject to NVIDIA intellectual property rights under U.S. and
8 |  * international Copyright laws.
9 |  *
10 |  * These Licensed Deliverables contained herein is PROPRIETARY and
11 |  * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12 |  * conditions of a form of NVIDIA software license agreement by and
13 |  * between NVIDIA and Licensee ("License Agreement") or electronically
14 |  * accepted by Licensee. Notwithstanding any terms or conditions to
15 |  * the contrary in the License Agreement, reproduction or disclosure
16 |  * of the Licensed Deliverables to any third party without the express
17 |  * written consent of NVIDIA is prohibited.
18 |  *
19 |  * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20 |  * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21 |  * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22 |  * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23 |  * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24 |  * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25 |  * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26 |  * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27 |  * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28 |  * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29 |  * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30 |  * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31 |  * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32 |  * OF THESE LICENSED DELIVERABLES.
33 |  *
34 |  * U.S. Government End Users. These Licensed Deliverables are a
35 |  * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36 |  * 1995), consisting of "commercial computer software" and "commercial
37 |  * computer software documentation" as such terms are used in 48
38 |  * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39 |  * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40 |  * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41 |  * U.S. Government End Users acquire the Licensed Deliverables with
42 |  * only those rights set forth herein.
43 |  *
44 |  * Any use of the Licensed Deliverables in individual and commercial
45 |  * software must include, in the user documentation and internal
46 |  * comments to the code, the above Disclaimer and U.S. Government End
47 |  * Users Notice.
48 |  */
49 |
50 | #ifndef TRT_SAMPLE_OPTIONS_H
51 | #define TRT_SAMPLE_OPTIONS_H
52 |
53 | #include <algorithm>
54 | #include <array>
55 | #include <iostream>
56 | #include <stdexcept>
57 | #include <string>
58 | #include <unordered_map>
59 | #include <utility>
60 | #include <vector>
61 |
62 | #include "NvInfer.h"
63 |
64 | namespace sample
65 | {
66 |
67 | // Build default params
68 | constexpr int defaultMaxBatch{1};
69 | constexpr int defaultWorkspace{16};
70 | constexpr int defaultMinTiming{1};
71 | constexpr int defaultAvgTiming{8};
72 |
73 | // System default params
74 | constexpr int defaultDevice{0};
75 |
76 | // Inference default params
77 | constexpr int defaultBatch{1};
78 | constexpr int defaultStreams{1};
79 | constexpr int defaultIterations{10};
80 | constexpr int defaultWarmUp{200};
81 | constexpr int defaultDuration{10};
82 | constexpr int defaultSleep{0};
83 |
84 | // Reporting default params
85 | constexpr int defaultAvgRuns{10};
86 | constexpr float defaultPercentile{99};
87 |
88 | enum class ModelFormat {kANY, kCAFFE, kONNX, kUFF};
89 |
90 | using Arguments = std::unordered_multimap<std::string, std::string>;
91 |
92 | using IOFormat = std::pair<nvinfer1::DataType, nvinfer1::TensorFormats>;
93 |
94 | using ShapeRange = std::array<nvinfer1::Dims, nvinfer1::EnumMax<nvinfer1::OptProfileSelector>()>;
95 |
96 | struct Options
97 | {
98 |     virtual void parse(Arguments& arguments) = 0;
99 | };
100 |
101 | struct BaseModelOptions: public Options
102 | {
103 |     ModelFormat format{ModelFormat::kANY};
104 |     std::string model;
105 |
106 |     void parse(Arguments& arguments) override;
107 |
108 |     static void help(std::ostream& out);
109 | };
110 |
111 | struct UffInput: public Options
112 | {
113 |     std::vector<std::pair<std::string, nvinfer1::Dims>> inputs;
114 |     bool NHWC{false};
115 |
116 |     void parse(Arguments& arguments) override;
117 |
118 |     static void help(std::ostream& out);
119 | };
120 |
121 | struct ModelOptions: public Options
122 | {
123 |     BaseModelOptions baseModel;
124 |     std::string prototxt;
125 |     std::vector<std::string> outputs;
126 |     UffInput uffInputs;
127 |
128 |     void parse(Arguments& arguments) override;
129 |
130 |     static void help(std::ostream& out);
131 | };
132 |
133 | struct BuildOptions: public Options
134 | {
135 |     //bool explicitBatch{false};
136 |     int maxBatch{defaultMaxBatch}; // Parsing sets maxBatch to 0 if explicitBatch is true
137 |     int workspace{defaultWorkspace};
138 |     int minTiming{defaultMinTiming};
139 |     int avgTiming{defaultAvgTiming};
140 |     bool fp16{false};
141 |     bool int8{false};
142 |     bool safe{false};
143 |     bool save{false};
144 |     bool load{false};
145 |     std::string engine;
146 |     std::string calibration;
147 |     std::unordered_map<std::string, ShapeRange> shapes;
148 |     std::vector<IOFormat> inputFormats;
149 |     std::vector<IOFormat> outputFormats;
150 |
151 |     void parse(Arguments& arguments) override;
152 |
153 |     static void help(std::ostream& out);
154 | };
155 |
156 | struct SystemOptions: public Options
157 | {
158 |     int device{defaultDevice};
159 |     int DLACore{-1};
160 |     bool fallback{false};
161 |     std::vector<std::string> plugins;
162 |
163 |     void parse(Arguments& arguments) override;
164 |
165 |     static void help(std::ostream& out);
166 | };
167 |
168 | struct InferenceOptions: public Options
169 | {
170 |     int batch{defaultBatch}; // Parsing sets batch to 0 if shapes is not empty
171 |     int iterations{defaultIterations};
172 |     int warmup{defaultWarmUp};
173 |     int duration{defaultDuration};
174 |     int sleep{defaultSleep};
175 |     int streams{defaultStreams};
176 |     bool spin{false};
177 |     bool threads{true};
178 |     bool graph{false};
179 |     bool skip{false};
180 |     std::unordered_map<std::string, nvinfer1::Dims> shapes;
181 |
182 |     void parse(Arguments& arguments) override;
183 |
184 |     static void help(std::ostream& out);
185 | };
186 |
187 | struct ReportingOptions: public Options
188 | {
189 |     bool verbose{false};
190 |     int avgs{defaultAvgRuns};
191 |     float percentile{defaultPercentile};
192 |     bool output{false};
193 |     bool profile{false};
194 |     std::string exportTimes{};
195 |     std::string exportProfile{};
196 |
197 |     void parse(Arguments& arguments) override;
198 |
199 |     static void help(std::ostream& out);
200 | };
201 |
202 | struct AllOptions: public Options
203 | {
204 |     ModelOptions model;
205 |     BuildOptions build;
206 |     SystemOptions system;
207 |     InferenceOptions inference;
208 |     ReportingOptions reporting;
209 |     bool helps{false};
210 |
211 |     void parse(Arguments& arguments) override;
212 |
213 |     static void help(std::ostream& out);
214 | };
215 |
216 | Arguments argsToArgumentsMap(int argc, char* argv[]);
217 |
218 | bool parseHelp(Arguments& arguments);
219 |
220 | void helpHelp(std::ostream& out);
221 |
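The structs above are filled from a trtexec-style argument multimap; each parse call consumes (erases via the checkEraseOption helpers further down) the entries it understands, leaving unrecognized entries behind. A hedged sketch of that flow, assuming this header is included (the printing operators used here are declared just below):

int main(int argc, char* argv[])
{
    sample::Arguments args = sample::argsToArgumentsMap(argc, argv);
    if (sample::parseHelp(args))
    {
        sample::AllOptions::help(std::cout);
        return 0;
    }
    sample::AllOptions options;
    options.parse(args); // fills model, build, system, inference, reporting
    if (!args.empty())   // leftover entries matched no Options::parse
    {
        std::cerr << "Unrecognized options on the command line" << std::endl;
        return 1;
    }
    std::cout << options << std::endl;
    return 0;
}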
222 | // Functions to print options
223 |
224 | std::ostream& operator<<(std::ostream& os, const BaseModelOptions& options);
225 |
226 | std::ostream& operator<<(std::ostream& os, const UffInput& input);
227 |
228 | std::ostream& operator<<(std::ostream& os, const IOFormat& format);
229 |
230 | std::ostream& operator<<(std::ostream& os, const nvinfer1::Dims& dims);
231 |
232 | std::ostream& operator<<(std::ostream& os, const ShapeRange& dims);
233 |
234 | std::ostream& operator<<(std::ostream& os, const ModelOptions& options);
235 |
236 | std::ostream& operator<<(std::ostream& os, const BuildOptions& options);
237 |
238 | std::ostream& operator<<(std::ostream& os, const SystemOptions& options);
239 |
240 | std::ostream& operator<<(std::ostream& os, const InferenceOptions& options);
241 |
242 | std::ostream& operator<<(std::ostream& os, const ReportingOptions& options);
243 |
244 | std::ostream& operator<<(std::ostream& os, const AllOptions& options);
245 |
246 | // Utils to extract options
247 |
248 | inline std::vector<std::string> splitToStringVec(const std::string& option, char separator)
249 | {
250 |     std::vector<std::string> options;
251 |
252 |     for (size_t start = 0; start < option.length(); )
253 |     {
254 |         size_t separatorIndex = option.find(separator, start);
255 |         if (separatorIndex == std::string::npos)
256 |         {
257 |             separatorIndex = option.length();
258 |         }
259 |         options.emplace_back(option.substr(start, separatorIndex - start));
260 |         start = separatorIndex + 1;
261 |     }
262 |
263 |     return options;
264 | }
265 |
266 | template <typename T>
267 | inline T stringToValue(const std::string& option)
268 | {
269 |     return T{option};
270 | }
271 |
272 | template <>
273 | inline int stringToValue<int>(const std::string& option)
274 | {
275 |     return std::stoi(option);
276 | }
277 |
278 | template <>
279 | inline float stringToValue<float>(const std::string& option)
280 | {
281 |     return std::stof(option);
282 | }
283 |
284 | template <>
285 | inline bool stringToValue<bool>(const std::string& option)
286 | {
287 |     return true;
288 | }
289 |
290 | template <>
291 | inline nvinfer1::Dims stringToValue<nvinfer1::Dims>(const std::string& option)
292 | {
293 |     nvinfer1::Dims dims;
294 |     dims.nbDims = 0;
295 |     std::vector<std::string> dimsStrings = splitToStringVec(option, 'x');
296 |     for (const auto& d : dimsStrings)
297 |     {
298 |         if (d == "*")
299 |         {
300 |             break;
301 |         }
302 |         dims.d[dims.nbDims] = stringToValue<int>(d);
303 |         ++dims.nbDims;
304 |     }
305 |     return dims;
306 | }
307 |
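Concretely, here is what the specializations above yield for typical option strings; note that the bool specialization ignores the value entirely, since for flag options mere presence means true:

#include <cassert>

#include "sampleOptions.h"

void stringToValueExamples()
{
    // "1x3x256x512" splits on 'x' into four dimensions.
    nvinfer1::Dims d = sample::stringToValue<nvinfer1::Dims>("1x3x256x512");
    assert(d.nbDims == 4 && d.d[0] == 1 && d.d[1] == 3);

    assert(sample::stringToValue<int>("8") == 8);

    // Flag semantics: any string, even an empty one, parses as true.
    assert(sample::stringToValue<bool>(""));
}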
308 | template <>
309 | inline nvinfer1::DataType stringToValue<nvinfer1::DataType>(const std::string& option)
310 | {
311 |     const std::unordered_map<std::string, nvinfer1::DataType> strToDT{{"fp32", nvinfer1::DataType::kFLOAT}, {"fp16", nvinfer1::DataType::kHALF},
312 |         {"int8", nvinfer1::DataType::kINT8}, {"int32", nvinfer1::DataType::kINT32}};
313 |     auto dt = strToDT.find(option);
314 |     if (dt == strToDT.end())
315 |     {
316 |         throw std::invalid_argument("Invalid DataType " + option);
317 |     }
318 |     return dt->second;
319 | }
320 |
321 | template <>
322 | inline nvinfer1::TensorFormats stringToValue<nvinfer1::TensorFormats>(const std::string& option)
323 | {
324 |     std::vector<std::string> optionStrings = splitToStringVec(option, '+');
325 |     const std::unordered_map<std::string, nvinfer1::TensorFormat> strToFmt{{"chw", nvinfer1::TensorFormat::kLINEAR}, {"chw2", nvinfer1::TensorFormat::kCHW2},
326 |         {"chw4", nvinfer1::TensorFormat::kCHW4}, {"hwc8", nvinfer1::TensorFormat::kHWC8},
327 |         {"chw16", nvinfer1::TensorFormat::kCHW16}, {"chw32", nvinfer1::TensorFormat::kCHW32}};
328 |     nvinfer1::TensorFormats formats{};
329 |     for (auto f : optionStrings)
330 |     {
331 |         auto tf = strToFmt.find(f);
332 |         if (tf == strToFmt.end())
333 |         {
334 |             throw std::invalid_argument(std::string("Invalid TensorFormat ") + f);
335 |         }
336 |         formats |= 1U << int(tf->second);
337 |     }
338 |
339 |     return formats;
340 | }
341 |
342 | template <>
343 | inline IOFormat stringToValue<IOFormat>(const std::string& option)
344 | {
345 |     IOFormat ioFormat{};
346 |     size_t colon = option.find(':');
347 |
348 |     if (colon == std::string::npos)
349 |     {
350 |         throw std::invalid_argument(std::string("Invalid IOFormat ") + option);
351 |     }
352 |     ioFormat.first = stringToValue<nvinfer1::DataType>(option.substr(0, colon));
353 |     ioFormat.second = stringToValue<nvinfer1::TensorFormats>(option.substr(colon + 1));
354 |
355 |     return ioFormat;
356 | }
357 |
358 | inline const char* boolToEnabled(bool enable)
359 | {
360 |     return enable ? "Enabled" : "Disabled";
361 | }
362 |
363 | template <typename T>
364 | inline bool checkEraseOption(Arguments& arguments, const std::string& option, T& value)
365 | {
366 |     auto match = arguments.find(option);
367 |     if (match != arguments.end())
368 |     {
369 |         value = stringToValue<T>(match->second);
370 |         arguments.erase(match);
371 |         return true;
372 |     }
373 |
374 |     return false;
375 | }
376 |
377 | template <typename T>
378 | inline bool checkEraseRepeatedOption(Arguments& arguments, const std::string& option, std::vector<T>& values)
379 | {
380 |     auto match = arguments.equal_range(option);
381 |     if (match.first == match.second)
382 |     {
383 |         return false;
384 |     }
385 |     auto addValue = [&values](Arguments::value_type& value) { values.emplace_back(stringToValue<T>(value.second)); };
386 |     std::for_each(match.first, match.second, addValue);
387 |     arguments.erase(match.first, match.second);
388 |     return true;
389 | }
390 |
391 | } // namespace sample
392 |
393 | #endif // TRT_SAMPLE_OPTIONS_H
394 |
--------------------------------------------------------------------------------
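The IOFormat specialization ties this header back to networkToEngine in sampleEngines.cpp: a string such as "int8:chw4" splits at the colon into a DataType and a TensorFormats bitmask, and any int8 entry among the build's input or output formats forces explicit tensor scales, since auto calibration cannot produce scales for int8 I/O. A small sketch:

// Sketch: "int8:chw4" -> {DataType::kINT8, 1U << int(TensorFormat::kCHW4)}
sample::IOFormat fmt = sample::stringToValue<sample::IOFormat>("int8:chw4");
// Placed in BuildOptions::inputFormats, an entry like this makes
// networkToEngine call setTensorScales(network) instead of attaching
// an int8 calibrator.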
/common/sampleUtils.h:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright 1993-2019 NVIDIA Corporation. All rights reserved.
3 |  *
4 |  * NOTICE TO LICENSEE:
5 |  *
6 |  * This source code and/or documentation ("Licensed Deliverables") are
7 |  * subject to NVIDIA intellectual property rights under U.S. and
8 |  * international Copyright laws.
9 |  *
10 |  * These Licensed Deliverables contained herein is PROPRIETARY and
11 |  * CONFIDENTIAL to NVIDIA and is being provided under the terms and
12 |  * conditions of a form of NVIDIA software license agreement by and
13 |  * between NVIDIA and Licensee ("License Agreement") or electronically
14 |  * accepted by Licensee. Notwithstanding any terms or conditions to
15 |  * the contrary in the License Agreement, reproduction or disclosure
16 |  * of the Licensed Deliverables to any third party without the express
17 |  * written consent of NVIDIA is prohibited.
18 |  *
19 |  * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
20 |  * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
21 |  * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
22 |  * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
23 |  * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
24 |  * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
25 |  * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
26 |  * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
27 |  * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
28 |  * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
29 |  * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
30 |  * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
31 |  * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
32 |  * OF THESE LICENSED DELIVERABLES.
33 |  *
34 |  * U.S. Government End Users. These Licensed Deliverables are a
35 |  * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
36 |  * 1995), consisting of "commercial computer software" and "commercial
37 |  * computer software documentation" as such terms are used in 48
38 |  * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
39 |  * only as a commercial end item. Consistent with 48 C.F.R.12.212 and
40 |  * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
41 |  * U.S. Government End Users acquire the Licensed Deliverables with
42 |  * only those rights set forth herein.
43 |  *
44 |  * Any use of the Licensed Deliverables in individual and commercial
45 |  * software must include, in the user documentation and internal
46 |  * comments to the code, the above Disclaimer and U.S. Government End
47 |  * Users Notice.
48 |  */
49 |
50 | #ifndef TRT_SAMPLE_UTILS_H
51 | #define TRT_SAMPLE_UTILS_H
52 |
53 | #include <functional>
54 | #include <iostream>
55 | #include <memory>
56 | #include <numeric>
57 |
58 | #include "NvInfer.h"
59 |
60 | namespace sample
61 | {
62 |
63 | inline void cudaCheck(cudaError_t ret, std::ostream& err = std::cerr)
64 | {
65 |     if (ret != cudaSuccess)
66 |     {
67 |         err << "Cuda failure: " << ret << std::endl;
68 |         abort();
69 |     }
70 | }
71 |
72 | template <typename T>
73 | struct destroyer
74 | {
75 |     void operator()(T* t) { t->destroy(); }
76 | };
77 |
78 | template <typename T> using unique_ptr = std::unique_ptr<T, destroyer<T>>;
79 |
80 | inline int64_t volume(const nvinfer1::Dims& d)
81 | {
82 |     return std::accumulate(d.d, d.d + d.nbDims, int64_t{1}, std::multiplies<int64_t>());
83 | }
84 |
85 | } // namespace sample
86 |
87 | #endif // TRT_SAMPLE_UTILS_H
88 |
--------------------------------------------------------------------------------
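The destroyer deleter lets TensorRT objects, which must be released through destroy() rather than delete, live inside a standard smart pointer, and volume() turns binding dimensions into an element count. A hedged sketch of typical use; it assumes binding 0 is a kFLOAT input tensor of an already-deserialized engine:

#include <cuda_runtime_api.h>

#include "NvInfer.h"
#include "sampleUtils.h"

// Sketch: allocate a device buffer sized from an engine binding.
void* allocInputBuffer(const nvinfer1::ICudaEngine& engine)
{
    const int64_t elems = sample::volume(engine.getBindingDimensions(0));
    void* deviceBuffer{nullptr};
    sample::cudaCheck(cudaMalloc(&deviceBuffer, elems * sizeof(float)));
    return deviceBuffer;
}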
/data/1492638000682869180/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/1.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/10.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/11.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/12.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/13.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/13.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/14.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/14.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/15.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/15.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/16.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/17.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/17.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/18.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/18.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/19.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/2.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/20.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/20.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/3.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/4.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/5.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/6.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/7.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/8.jpg
--------------------------------------------------------------------------------
/data/1492638000682869180/9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thoughtworks-hpc/PINetTensorrt/fae4d1d8bac53a0197d79b2cf569f6b7b4c0eba4/data/1492638000682869180/9.jpg
--------------------------------------------------------------------------------