├── .github └── ISSUE_TEMPLATE │ └── bug_report.md ├── .gitignore ├── .gitmodules ├── CMakeLists.txt ├── ConditionalHelpers.cpp ├── ConditionalHelpers.hpp ├── ImporterContext.cpp ├── ImporterContext.hpp ├── LICENSE ├── LoopHelpers.cpp ├── LoopHelpers.hpp ├── ModelImporter.cpp ├── ModelImporter.hpp ├── ModelRefitter.cpp ├── ModelRefitter.hpp ├── NvOnnxParser.cpp ├── NvOnnxParser.h ├── OnnxAttrs.cpp ├── OnnxAttrs.hpp ├── README.md ├── RNNHelpers.cpp ├── RNNHelpers.hpp ├── ShapeTensor.cpp ├── ShapeTensor.hpp ├── ShapedWeights.cpp ├── ShapedWeights.hpp ├── Status.hpp ├── TensorOrWeights.cpp ├── TensorOrWeights.hpp ├── WeightsContext.cpp ├── WeightsContext.hpp ├── bfloat16.cpp ├── bfloat16.hpp ├── common.hpp ├── docs ├── Changelog.md ├── contributing.md ├── faq.md └── operators.md ├── errorHelpers.cpp ├── errorHelpers.hpp ├── getSupportedAPITest.cpp ├── half.h ├── ieee_half.h ├── importerUtils.cpp ├── importerUtils.hpp ├── libnvonnxparser.version ├── onnx2trt_common.hpp ├── onnx2trt_runtime.hpp ├── onnxErrorRecorder.cpp ├── onnxErrorRecorder.hpp ├── onnxOpCheckers.cpp ├── onnxOpCheckers.hpp ├── onnxOpImporters.cpp ├── onnxOpImporters.hpp ├── onnxProtoUtils.cpp ├── onnxProtoUtils.hpp ├── onnx_backend_test.py ├── onnx_tensorrt ├── __init__.py ├── backend.py └── tensorrt_engine.py ├── onnx_trt_backend.cpp ├── setup.py ├── toposort.hpp ├── weightUtils.cpp └── weightUtils.hpp /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve TRT 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Description 11 | 12 | 13 | 14 | 15 | ## Environment 16 | 17 | **TensorRT Version**: 18 | **ONNX-TensorRT Version / Branch**: 19 | **GPU Type**: 20 | **Nvidia Driver Version**: 21 | **CUDA Version**: 22 | **CUDNN Version**: 23 | **Operating System + Version**: 24 | **Python Version (if applicable)**: 25 | **TensorFlow + TF2ONNX Version (if applicable)**: 26 | **PyTorch Version (if applicable)**: 27 | **Baremetal or Container (if container which image + tag)**: 28 | 29 | 30 | ## Relevant Files 31 | 32 | 33 | 34 | 35 | ## Steps To Reproduce 36 | 37 | 45 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | onnx2trt 2 | 3 | # Compiled files 4 | *.so 5 | *.o* 6 | *.lo 7 | *.la 8 | *.a 9 | .deps/ 10 | 11 | # Backup files 12 | *~ 13 | *.bak* 14 | 15 | # Log files 16 | *.log 17 | *.prof 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | 23 | # Build 24 | build -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/onnx"] 2 | path = third_party/onnx 3 | url = https://github.com/onnx/onnx.git 4 | branch = v1.17.0 5 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | cmake_minimum_required(VERSION 3.13) 4 | project(onnx2trt LANGUAGES CXX C) 5 | 6 | set(ONNX2TRT_ROOT ${PROJECT_SOURCE_DIR}) 7 | # Set C++17 as standard for the whole project, as required by ONNX 1.16 8 | set(CMAKE_CXX_STANDARD 17) 9 | 10 | # Enable compiler warnings 11 | if (CMAKE_COMPILER_IS_GNUCC) 12 | set(CMAKE_CXX_FLAGS 
"${CMAKE_CXX_FLAGS} -Wall -Wno-deprecated-declarations -Wno-unused-function") 13 | endif() 14 | if (MSVC) 15 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4") 16 | endif() 17 | 18 | # Build the libraries with -fPIC 19 | set(CMAKE_POSITION_INDEPENDENT_CODE ON) 20 | 21 | set(PARSER_LINKER_SCRIPT ${ONNX2TRT_ROOT}/libnvonnxparser.version) 22 | 23 | # Find length of source directory used to pad filename in Status.hpp 24 | string(LENGTH "${CMAKE_SOURCE_DIR}/" SOURCE_LENGTH) 25 | add_definitions("-DSOURCE_LENGTH=${SOURCE_LENGTH}") 26 | 27 | #-------------------------------------------------- 28 | # Version information 29 | #-------------------------------------------------- 30 | set(ONNX2TRT_MAJOR 10) 31 | set(ONNX2TRT_MINOR 11) 32 | set(ONNX2TRT_PATCH 0) 33 | set(ONNX2TRT_VERSION "${ONNX2TRT_MAJOR}.${ONNX2TRT_MINOR}.${ONNX2TRT_PATCH}" CACHE STRING "ONNX2TRT version") 34 | 35 | #-------------------------------------------------- 36 | # Build configurations, global to all projects 37 | #-------------------------------------------------- 38 | 39 | set(IMPORTER_SOURCES 40 | NvOnnxParser.cpp 41 | ModelImporter.cpp 42 | ModelRefitter.cpp 43 | onnxOpImporters.cpp 44 | ImporterContext.cpp 45 | importerUtils.cpp 46 | ShapedWeights.cpp 47 | ShapeTensor.cpp 48 | LoopHelpers.cpp 49 | RNNHelpers.cpp 50 | OnnxAttrs.cpp 51 | onnxErrorRecorder.cpp 52 | ConditionalHelpers.cpp 53 | bfloat16.cpp 54 | onnxOpCheckers.cpp 55 | onnxProtoUtils.cpp 56 | weightUtils.cpp 57 | WeightsContext.cpp 58 | TensorOrWeights.cpp 59 | errorHelpers.cpp 60 | ) 61 | 62 | if (BUILD_ONNXIFI) 63 | set(ONNXIFI_SOURCES onnx_trt_backend.cpp) 64 | endif() 65 | 66 | set(API_TESTS_SOURCES 67 | getSupportedAPITest.cpp 68 | ModelImporter.cpp 69 | ) 70 | 71 | # Find protobuf if it's not a target. 72 | if (NOT TARGET protobuf::libprotobuf) 73 | FIND_PACKAGE(Protobuf REQUIRED) 74 | endif() 75 | 76 | # Set protobuf libraries between full / lite. 77 | if (ONNX_USE_LITE_PROTO) 78 | add_definitions("-DUSE_LITE_PROTOBUF=1") 79 | set(PROTOBUF_LIBRARY "protobuf::libprotobuf-lite") 80 | else() 81 | set(PROTOBUF_LIBRARY "protobuf::libprotobuf") 82 | endif() 83 | 84 | if(NOT TARGET onnx_proto) 85 | # Note: This avoids libprotobuf.so complaining about name collisions at runtime 86 | if(NOT ONNX_NAMESPACE) 87 | set(ONNX_NAMESPACE "onnx2trt_onnx") 88 | endif() 89 | add_definitions("-DONNX_NAMESPACE=${ONNX_NAMESPACE}") 90 | add_subdirectory(third_party/onnx EXCLUDE_FROM_ALL) 91 | endif() 92 | 93 | # CUDA 94 | if (NOT CUDA_TOOLKIT_ROOT_DIR) 95 | set(CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda) 96 | endif() 97 | find_path(CUDA_INCLUDE_DIR cuda_runtime.h 98 | HINTS ${CUDA_TOOLKIT_ROOT_DIR} 99 | PATH_SUFFIXES include 100 | ) 101 | MESSAGE(STATUS "Found CUDA headers at ${CUDA_INCLUDE_DIR}") 102 | 103 | # TensorRT 104 | find_path(TENSORRT_INCLUDE_DIR NvInfer.h 105 | HINTS ${TENSORRT_ROOT} ${CUDA_TOOLKIT_ROOT_DIR} 106 | PATH_SUFFIXES include) 107 | MESSAGE(STATUS "Found TensorRT headers at ${TENSORRT_INCLUDE_DIR}") 108 | 109 | # TensorRT Python Headers 110 | find_path(TENSORRT_PYTHON_INCLUDE_DIR NvInferPythonPlugin.h 111 | HINTS ${TENSORRT_ROOT} 112 | PATH_SUFFIXES python/include/impl) 113 | 114 | # If header is not found, download it from open source release. 
115 | if(NOT TENSORRT_PYTHON_INCLUDE_DIR) 116 | set(PLUGIN_URL "https://raw.githubusercontent.com/NVIDIA/TensorRT/refs/heads/release/${ONNX2TRT_MAJOR}.${ONNX2TRT_MINOR}/python/include/impl/NvInferPythonPlugin.h") 117 | set(FILE_DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/NvInferPythonPlugin.h") 118 | 119 | message(NOTICE "Required header NvInferPythonPlugin.h not found. Downloading from ${PLUGIN_URL} to ${FILE_DESTINATION}") 120 | 121 | file(DOWNLOAD ${PLUGIN_URL} ${FILE_DESTINATION} 122 | SHOW_PROGRESS 123 | STATUS DOWNLOAD_STATUS 124 | LOG DOWNLOAD_LOG 125 | ) 126 | 127 | list(GET DOWNLOAD_STATUS 0 STATUS_CODE) 128 | list(GET DOWNLOAD_STATUS 1 ERROR_MESSAGE) 129 | 130 | if(NOT STATUS_CODE EQUAL 0) 131 | message(FATAL_ERROR "Error downloading file: ${ERROR_MESSAGE}") 132 | endif() 133 | 134 | set(TENSORRT_PYTHON_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}) 135 | endif() 136 | 137 | message(NOTICE "Found TensorRT Python headers at ${TENSORRT_PYTHON_INCLUDE_DIR}") 138 | 139 | # Output dynamic library names depends on platform: 140 | if (MSVC) 141 | set(nvonnxparser_lib_name "nvonnxparser_${ONNX2TRT_MAJOR}") 142 | else() 143 | set(nvonnxparser_lib_name "nvonnxparser") 144 | endif() 145 | # Output static library name is the same cross-platform. 146 | set(nvonnxparser_lib_name_static "nvonnxparser_static") 147 | 148 | # -------------------------------- 149 | # Importer library 150 | # -------------------------------- 151 | add_library(${nvonnxparser_lib_name} SHARED ${IMPORTER_SOURCES}) 152 | target_include_directories(${nvonnxparser_lib_name} PUBLIC ${ONNX_INCLUDE_DIRS} ${TENSORRT_INCLUDE_DIR} ${TENSORRT_PYTHON_INCLUDE_DIR} ${CUDA_INCLUDE_DIR}) 153 | target_link_libraries(${nvonnxparser_lib_name} PUBLIC onnx_proto ${PROTOBUF_LIBRARY}) 154 | set_target_properties(${nvonnxparser_lib_name} PROPERTIES 155 | VERSION ${ONNX2TRT_VERSION} 156 | SOVERSION ${ONNX2TRT_MAJOR} 157 | LINK_DEPENDS ${PARSER_LINKER_SCRIPT} 158 | LINK_FLAGS "-Wl,--version-script=${PARSER_LINKER_SCRIPT}" 159 | ARCHIVE_OUTPUT_DIRECTORY "${TRT_OUT_DIR}" 160 | LIBRARY_OUTPUT_DIRECTORY "${TRT_OUT_DIR}" 161 | RUNTIME_OUTPUT_DIRECTORY "${TRT_OUT_DIR}" 162 | ) 163 | add_library(${nvonnxparser_lib_name_static} STATIC ${IMPORTER_SOURCES}) 164 | target_include_directories(${nvonnxparser_lib_name_static} PUBLIC ${ONNX_INCLUDE_DIRS} ${TENSORRT_INCLUDE_DIR} ${TENSORRT_PYTHON_INCLUDE_DIR} ${CUDA_INCLUDE_DIR}) 165 | target_link_libraries(${nvonnxparser_lib_name_static} PUBLIC onnx_proto ${PROTOBUF_LIBRARY}) 166 | set_target_properties(${nvonnxparser_lib_name_static} PROPERTIES 167 | ARCHIVE_OUTPUT_DIRECTORY "${TRT_OUT_DIR}" 168 | LIBRARY_OUTPUT_DIRECTORY "${TRT_OUT_DIR}" 169 | RUNTIME_OUTPUT_DIRECTORY "${TRT_OUT_DIR}" 170 | ) 171 | # -------------------------------- 172 | # Onnxifi library 173 | # -------------------------------- 174 | if(BUILD_ONNXIFI) 175 | add_library(trt_onnxify SHARED ${ONNXIFI_SOURCES}) 176 | target_include_directories(trt_onnxify PUBLIC ${CUDA_INCLUDE_DIR} ${ONNX_INCLUDE_DIRS} ${TENSORRT_INCLUDE_DIR} ${TENSORRT_PYTHON_INCLUDE_DIR}) 177 | target_link_libraries(trt_onnxify PUBLIC ${nvonnxparser_lib_name_static} ${CMAKE_THREAD_LIBS_INIT} ${CMAKE_DL_LIBS}) 178 | endif() 179 | 180 | # -------------------------------- 181 | # API Tests 182 | # -------------------------------- 183 | if (BUILD_API_TEST) 184 | add_executable(getSupportedAPITest ${API_TESTS_SOURCES}) 185 | target_include_directories(getSupportedAPITest PUBLIC ${ONNX_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIR}) 186 | target_link_libraries(getSupportedAPITest PUBLIC 
${PROTOBUF_LIB} ${nvonnxparser_lib_name_static} ${CMAKE_THREAD_LIBS_INIT} ${CMAKE_DL_LIBS}) 187 | endif() 188 | 189 | # -------------------------------- 190 | # Installation 191 | # -------------------------------- 192 | install(TARGETS 193 | ${nvonnxparser_lib_name} 194 | ${nvonnxparser_lib_name_static} 195 | LIBRARY DESTINATION lib 196 | ARCHIVE DESTINATION lib 197 | ) 198 | 199 | install(FILES ${HEADERS} 200 | DESTINATION include 201 | ) 202 | 203 | SET(CPACK_GENERATOR "DEB") 204 | SET(CPACK_DEBIAN_PACKAGE_MAINTAINER "NVIDIA") #required 205 | SET(CPACK_PACKAGE_NAME "onnx-trt-dev") 206 | SET(CPACK_PACKAGE_VERSION "0.5.9") 207 | SET(CPACK_PACKAGE_VERSION_MAJOR "0") 208 | SET(CPACK_PACKAGE_VERSION_MINOR "5") 209 | SET(CPACK_PACKAGE_VERSION_PATCH "9") 210 | 211 | INCLUDE(CPack) 212 | -------------------------------------------------------------------------------- /ConditionalHelpers.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "ConditionalHelpers.hpp" 6 | #include "ModelImporter.hpp" 7 | #include "importerUtils.hpp" 8 | #include "toposort.hpp" 9 | 10 | namespace onnx2trt 11 | { 12 | 13 | // Search for a network Layer name in a SubgraphPortsMap. 14 | SubgraphPortsMap::const_iterator findLayer(const SubgraphPortsMap& inputs, const std::string layerName) 15 | { 16 | return std::find_if( 17 | inputs.begin(), inputs.end(), [&](const auto& item) { return layerName == item.first->getName(); }); 18 | } 19 | 20 | // Add an ConditionalInputLayer between `layer` and its inputs. 21 | // I.e. input[inIdx] -> layer ==> input[inIdx] -> ConditionalInputLayer -> layer. 22 | void addConditionalInputLayer(ImporterContext* ctx, nvinfer1::IIfConditional* conditional, InputsMap& inputsMap, 23 | nvinfer1::ILayer& layer, int32_t inIdx, ::ONNX_NAMESPACE::NodeProto const* node) 24 | { 25 | auto input = layer.getInput(inIdx); 26 | if (input == nullptr) 27 | { 28 | // Phantom input (an input that is really constant weights). 29 | return; 30 | } 31 | 32 | if (layer.getType() == nvinfer1::LayerType::kCONDITIONAL_OUTPUT) 33 | { 34 | return; 35 | } 36 | 37 | auto const name = input->getName(); 38 | auto it = inputsMap.find(name); 39 | nvinfer1::IIfConditionalInputLayer* inputLayer = nullptr; 40 | if (it == inputsMap.end()) 41 | { 42 | inputLayer = N_CHECK(conditional->addInput(*input)); 43 | inputsMap[name] = inputLayer; 44 | const std::string inputLayerName(name); 45 | ctx->registerLayer(inputLayer, inputLayerName + "_InputLayer", node); 46 | // Note: Since multiple conditionals may use the same external tensor, check unique names for output tensors of 47 | // IfConditionalInputLayers to avoid tensor name duplication. 48 | ctx->registerTensor( 49 | TensorOrWeights{N_CHECK(inputLayer->getOutput(0))}, inputLayerName + "_InputLayer_output", /*checkUniqueName*/ true); 50 | } 51 | else 52 | { 53 | // An InputLayer may in the inputsMap if it has several consumers. 54 | inputLayer = it->second; 55 | } 56 | auto ifOutput = N_CHECK(inputLayer->getOutput(0)); 57 | layer.setInput(inIdx, *ifOutput); 58 | }; 59 | 60 | // Take a snapshot of the network before and after parsing the subgraph and return a list 61 | // of newly added network layers. 
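// A minimal usage sketch follows (illustrative only; the real call sites are the If/Loop/Scan op
// importers and may differ). `thenGraph`, `conditional`, `inputsMap` and `node` are placeholder
// names, and the vector element types are inferred from how they are used in this file:
//
//   std::vector<nvinfer1::ILayer*> branchLayers;
//   std::vector<TensorOrWeights> branchOutputs;
//   importSubgraph(ctx, thenGraph, branchLayers, branchOutputs);        // parse one branch subgraph
//   addIfInputLayers(ctx, conditional, inputsMap, branchLayers, &node); // wire up its external inputs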
62 | void importSubgraph(ImporterContext* ctx, ::ONNX_NAMESPACE::GraphProto const& subgraph, 63 | std::vector& newLayers, std::vector& subgraphTensors) 64 | { 65 | auto net = ctx->network(); 66 | int32_t beforeSubgraph = net->getNbLayers(); 67 | 68 | // Establish scope for names local to the subgraph. 69 | NameScope nameScope(*ctx); 70 | 71 | std::vector errors{}; 72 | onnx2trt::parseGraph(ctx, subgraph, errors); 73 | 74 | for (int32_t i = 0; i < subgraph.output_size(); ++i) 75 | { 76 | std::string name = subgraph.output(i).name(); 77 | subgraphTensors.push_back(ctx->tensors().at(name)); 78 | } 79 | 80 | for (int32_t i = beforeSubgraph; i < net->getNbLayers(); i++) 81 | { 82 | newLayers.push_back(net->getLayer(i)); 83 | } 84 | } 85 | 86 | // Add an IConditionalInputLayer to `layer`'s inputs, if they don't already exist. 87 | void addConditionalInputIfNeeded(ImporterContext* ctx, nvinfer1::IIfConditional* conditional, InputsMap& inputsMap, 88 | nvinfer1::ILayer& layer, SubgraphPortsMap subgraphInputsMap, ::ONNX_NAMESPACE::NodeProto const* node) 89 | { 90 | // Return all of the layer's inputs that are external to the subgraph that 91 | // that the layer belongs to. 92 | auto getLayerExternalInputs = [&](std::string const& layerName) { 93 | std::set inIndices; 94 | auto iter = findLayer(subgraphInputsMap, layerName); 95 | if (iter != subgraphInputsMap.end()) 96 | { 97 | const auto& indicesSet = iter->second; 98 | inIndices.insert(indicesSet.begin(), indicesSet.end()); 99 | } 100 | 101 | return inIndices; 102 | }; 103 | 104 | const auto inIndices = getLayerExternalInputs(layer.getName()); 105 | for (auto inIdx : inIndices) 106 | { 107 | LOG_VERBOSE("Adding Input layer for " << layer.getName()); 108 | addConditionalInputLayer(ctx, conditional, inputsMap, layer, inIdx, node); 109 | } 110 | } 111 | 112 | // Add IConditionalInputLayers to `layer`'s inputs. 113 | void addIfInputLayers(ImporterContext* ctx, nvinfer1::IIfConditional* conditional, InputsMap& inputsMap, 114 | const std::vector& newLayers, ::ONNX_NAMESPACE::NodeProto const* node) 115 | { 116 | // Find all of the tensors entering the subgraph. 117 | SubgraphPortsMap externalInputs; 118 | getSubgraphInputs(newLayers, externalInputs); 119 | 120 | // Add a ConditionalInputLayer in front of each input that is external to the subgraph. 121 | for (const auto& layer : newLayers) 122 | { 123 | addConditionalInputIfNeeded(ctx, conditional, inputsMap, *layer, externalInputs, node); 124 | } 125 | } 126 | 127 | // Given a subgraph, find all of its external inputs (tensors entering the subgraph). 128 | void getSubgraphInputs(const std::vector& newLayers, SubgraphPortsMap& externalInputs) 129 | { 130 | using PortIndex = int32_t; 131 | using TensorsSet = std::unordered_set; 132 | TensorsSet outputTensors; 133 | TensorsSet inputTensors; 134 | 135 | // To determine which tensors are entering or exiting the given graph, we first collect the sets of all input and 136 | // output tensors. Then we categorize the tensors according to this logic: 137 | // Entering tensors := {inputs} - {outputs} 138 | // Exiting tensors := {outputs} - {inputs} 139 | 140 | // Collect all input and output tensors belonging to nodes in the graph. 141 | 142 | auto getTensors = [](nvinfer1::ILayer const* l, bool const input, auto inserter) { 143 | auto const count = input ? l->getNbInputs() : l->getNbOutputs(); 144 | for (int32_t i = 0; i < count; i++) 145 | { 146 | inserter(input ? 
l->getInput(i) : l->getOutput(i)); 147 | } 148 | }; 149 | 150 | for (const auto& l : newLayers) 151 | { 152 | getTensors(l, false, [&](nvinfer1::ITensor* t) { outputTensors.insert(t); }); 153 | getTensors(l, true, [&](nvinfer1::ITensor* t) { inputTensors.insert(t); }); 154 | } 155 | 156 | using TensorsVec = std::vector; 157 | auto getInputs = [&](nvinfer1::ILayer const* l, TensorsVec& res) { 158 | getTensors(l, true, [&](nvinfer1::ITensor* t) { res.emplace_back(t); }); 159 | }; 160 | 161 | // Retrieve the list of tensors either exiting or entering the subgraph. 162 | auto filterTensors = [&](TensorsSet const& tensors, auto getNodeAccessor) { 163 | for (nvinfer1::ILayer const* l : newLayers) 164 | { 165 | PortIndex i = 0; 166 | 167 | TensorsVec nodeAccessor; 168 | getNodeAccessor(l, nodeAccessor); 169 | for (const auto& tensor : nodeAccessor) 170 | { 171 | if (tensor == nullptr) 172 | { 173 | continue; 174 | } 175 | if (tensors.count(tensor) == 0) 176 | { 177 | externalInputs[l].insert(i); 178 | } 179 | i++; 180 | } 181 | } 182 | }; 183 | 184 | filterTensors(outputTensors, getInputs); 185 | } 186 | 187 | } // namespace onnx2trt 188 | -------------------------------------------------------------------------------- /ConditionalHelpers.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | * 4 | * Helper functions used for importing the ONNX If-operator follow below. 5 | * 6 | */ 7 | 8 | #pragma once 9 | 10 | #include "ImporterContext.hpp" 11 | #include "Status.hpp" 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | namespace onnx2trt 19 | { 20 | 21 | using NodeName = std::string; 22 | using LayerName = std::string; 23 | using InputIndex = int32_t; 24 | 25 | // A SubgraphPortsMap maps inputs' ports of each layer in an ONNX graph. 26 | using SubgraphPortsMap = std::unordered_map>; 27 | 28 | // Given a subgraph, find all of its external inputs (tensors entering the subgraph). 29 | void getSubgraphInputs(const std::vector& newLayers, SubgraphPortsMap& externalInputs); 30 | 31 | // Take a snapshot of the network before and after parsing the subgraph and return a list 32 | // of newly added network layers. 33 | void importSubgraph(ImporterContext* ctx, ::ONNX_NAMESPACE::GraphProto const& subgraph, 34 | std::vector& newLayers, std::vector& subgraphTensors); 35 | 36 | // An InputsMap tracks which IIfConditionalInputLayer we've added to a layer's inputs, 37 | // so that we can reuse them if needed. 38 | using InputsMap = std::unordered_map; 39 | 40 | // Add IIfConditionalInputLayers to the inputs of the subgraph indicated by `subgraph`. 
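// Note: the subgraph is identified by `newLayers`, the list of network layers created while
// importing it with importSubgraph above; an IIfConditionalInputLayer is inserted in front of
// every tensor that enters those layers from outside the subgraph.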
41 | void addIfInputLayers(ImporterContext* ctx, nvinfer1::IIfConditional* conditional, InputsMap& inputsMap, 42 | const std::vector& newLayers, ::ONNX_NAMESPACE::NodeProto const* node); 43 | 44 | } // namespace onnx2trt 45 | -------------------------------------------------------------------------------- /ImporterContext.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "ImporterContext.hpp" 6 | #include "NvInferVersion.h" 7 | #include "importerUtils.hpp" 8 | #include "weightUtils.hpp" 9 | #include 10 | 11 | #if !defined(_WIN32) 12 | #include 13 | #if defined(__linux__) 14 | #include 15 | #endif 16 | #else // defined(_WIN32) 17 | #include 18 | #endif // !defined(_WIN32) 19 | 20 | #define RT_ASSERT(cond) \ 21 | do \ 22 | { \ 23 | if (!(cond)) \ 24 | { \ 25 | throw std::runtime_error("Assertion " #cond " failed!"); \ 26 | } \ 27 | } while (0) 28 | 29 | namespace onnx2trt 30 | { 31 | 32 | void ImporterContext::pushBaseNameScope() 33 | { 34 | mBaseNameScopeStack.push_back({}); 35 | } 36 | 37 | void ImporterContext::popBaseNameScope() 38 | { 39 | auto& tensorMap = tensors(); 40 | for (auto& binding : mBaseNameScopeStack.back()) 41 | { 42 | if (binding.second.first) 43 | { 44 | tensorMap.erase(binding.first); 45 | } 46 | else 47 | { 48 | tensorMap.at(binding.first) = std::move(binding.second.second); 49 | } 50 | } 51 | mBaseNameScopeStack.pop_back(); 52 | } 53 | 54 | void ImporterContext::registerTensor(TensorOrWeights tensor, std::string const& basename, bool const checkUniqueName) 55 | { 56 | // TRT requires unique tensor names. 57 | std::string const& uniqueName = generateUniqueName(mTensorNames, mSuffixCounter, basename); 58 | 59 | if (tensor) 60 | { 61 | if (tensor.is_tensor()) 62 | { 63 | tensor.tensor().setName(uniqueName.c_str()); 64 | // Logging macro refers to ctx. 65 | auto* ctx = this; 66 | LOG_VERBOSE("Registering tensor: " << uniqueName << " for ONNX tensor: " << basename); 67 | } 68 | else if (tensor.is_weights()) 69 | { 70 | // It may be possible for nested subgraphs to have different values for the same initializer. 71 | // For multiple name scopes - use unique name to keep track of weights. 72 | if (!mBaseNameScopeStack.empty()) 73 | { 74 | tensor.weights().setName(uniqueName.c_str()); 75 | } 76 | else 77 | { 78 | tensor.weights().setName(basename.c_str()); 79 | } 80 | } 81 | } 82 | 83 | std::string const& nameToCheck = checkUniqueName ? uniqueName : basename; 84 | 85 | auto const p = this->tensors().emplace(nameToCheck, TensorOrWeights{}); 86 | bool nameIsDuplicate = false; 87 | if (!mBaseNameScopeStack.empty()) 88 | { 89 | // Remember original binding so it can be restored when scope is popped. 90 | auto const q 91 | = mBaseNameScopeStack.back().emplace(nameToCheck, std::make_pair(p.second, std::move(p.first->second))); 92 | // Check that scope did not already have a binding for basename. 93 | nameIsDuplicate = !q.second; 94 | } 95 | else 96 | { 97 | // The condition here accounts for ModelImporter::importModel reserving 98 | // output names by registering null tensors. 
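// A reserved output name maps to an empty TensorOrWeights, which isNullTensor() reports as
// null, so registering the real tensor under that name later is not treated as a duplicate.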
99 | nameIsDuplicate = !p.second && !p.first->second.isNullTensor(); 100 | } 101 | if (nameIsDuplicate) 102 | { 103 | throw std::runtime_error("ONNX graph has duplicate tensor name: " + nameToCheck); 104 | } 105 | p.first->second = std::move(tensor); 106 | } 107 | 108 | void ImporterContext::registerLayer(nvinfer1::ILayer* layer, std::string const& basename, ::ONNX_NAMESPACE::NodeProto const* node) 109 | { 110 | // No layer will be added for Constant nodes in ONNX. 111 | if (layer) 112 | { 113 | std::string const name = basename.empty() ? layer->getName() : basename; 114 | std::string const& uniqueName = generateUniqueName(mLayerNames, mSuffixCounter, basename); 115 | 116 | auto* ctx = this; // To enable logging. 117 | if (node != nullptr) 118 | { 119 | LOG_VERBOSE("Registering layer: " << uniqueName << " for ONNX node: " << basename); 120 | } 121 | else 122 | { 123 | LOG_VERBOSE("Registering layer: " << uniqueName << " required by ONNX-TRT"); 124 | } 125 | 126 | layer->setName(uniqueName.c_str()); 127 | if (layer->getType() == nvinfer1::LayerType::kCONSTANT) 128 | { 129 | if (basename != uniqueName && mConstantLayers.find(uniqueName) != mConstantLayers.end()) 130 | { 131 | LOG_ERROR("Constant layer: " << uniqueName << " can be a duplicate of: " << basename); 132 | assert(!"Internal error: duplicate constant layers for the same weights"); 133 | } 134 | mConstantLayers.insert({uniqueName, static_cast(layer)}); 135 | } 136 | } 137 | // Set metadata only if the layer is associated with an ONNX node. 138 | // Skip constant layers because constants are represented as initializers in ONNX and should not be associated 139 | // with any ONNX node. 140 | if (node != nullptr && layer != nullptr && layer->getType() != nvinfer1::LayerType::kCONSTANT) 141 | { 142 | processMetadata(this, *node, layer); 143 | } 144 | } 145 | 146 | void ImporterContext::registerLayer(nvinfer1::ILayer* layer, ::ONNX_NAMESPACE::NodeProto const& node) 147 | { 148 | std::string const& basename = getNodeName(node); 149 | registerLayer(layer, basename, &node); 150 | } 151 | 152 | namespace 153 | { 154 | 155 | //! Translates a "logical" library name into an OS-dependent DSO or DLL name 156 | std::string getOSLibraryName(char const* logicalName) 157 | { 158 | std::stringstream libName; 159 | #if defined(_WIN32) 160 | libName << logicalName << ".dll"; 161 | #else 162 | libName << "lib" << logicalName << ".so." << NV_TENSORRT_MAJOR; 163 | #endif 164 | return libName.str(); 165 | } 166 | 167 | //! Platform-agnostic wrapper around dynamic libraries. 168 | class DynamicLibrary 169 | { 170 | public: 171 | explicit DynamicLibrary(std::string const& name) 172 | : mLibName{name} 173 | { 174 | #if defined(_WIN32) 175 | mHandle = LoadLibraryA(name.c_str()); 176 | #else // defined(_WIN32) 177 | int32_t flags{RTLD_LAZY}; 178 | mHandle = dlopen(name.c_str(), flags); 179 | #endif // defined(_WIN32) 180 | 181 | if (mHandle == nullptr) 182 | { 183 | std::string errorStr{}; 184 | #if !defined(_WIN32) 185 | errorStr = std::string{" due to "} + std::string{dlerror()}; 186 | #endif 187 | throw std::runtime_error("Unable to open library: " + name + errorStr); 188 | } 189 | } 190 | 191 | DynamicLibrary(DynamicLibrary const&) = delete; 192 | DynamicLibrary(DynamicLibrary const&&) = delete; 193 | 194 | ~DynamicLibrary() 195 | { 196 | try 197 | { 198 | #if defined(_WIN32) 199 | RT_ASSERT(static_cast(FreeLibrary(static_cast(mHandle)))); 200 | #else 201 | RT_ASSERT(dlclose(mHandle) == 0); 202 | #endif 203 | } 204 | catch (...) 
205 | { 206 | std::cerr << "Unable to close library: " << mLibName << std::endl; 207 | } 208 | } 209 | 210 | std::string getFullPath() const 211 | { 212 | RT_ASSERT(mHandle != nullptr); 213 | #if defined(__linux__) 214 | link_map* linkMap = nullptr; 215 | auto const err = dlinfo(mHandle, RTLD_DI_LINKMAP, &linkMap); 216 | RT_ASSERT(err == 0 && linkMap != nullptr && linkMap->l_name != nullptr); 217 | return std::string{linkMap->l_name}; 218 | #elif defined(_WIN32) 219 | constexpr int32_t kMAX_PATH_LEN{4096}; 220 | std::string path(kMAX_PATH_LEN, '\0'); // since C++11, std::string storage is guaranteed to be contiguous 221 | auto const pathLen = GetModuleFileNameA(static_cast(mHandle), &path[0], kMAX_PATH_LEN); 222 | RT_ASSERT(GetLastError() == ERROR_SUCCESS); 223 | path.resize(pathLen); 224 | path.shrink_to_fit(); 225 | return path; 226 | #else 227 | RT_ASSERT(!"Unsupported operation: getFullPath()"); 228 | #endif 229 | } 230 | 231 | private: 232 | std::string mLibName{}; //!< Name of the DynamicLibrary 233 | void* mHandle{}; //!< Handle to the DynamicLibrary 234 | }; 235 | 236 | //! Translates an OS-dependent DSO/DLL name into a path on the filesystem 237 | std::string getOSLibraryPath(std::string const& osLibName) 238 | { 239 | DynamicLibrary lib{osLibName}; 240 | return lib.getFullPath(); 241 | } 242 | 243 | } // namespace 244 | 245 | void ImporterContext::addUsedVCPluginLibrary( 246 | ::ONNX_NAMESPACE::NodeProto const& node, char const* pluginName, char const* pluginLib) 247 | { 248 | auto* ctx = this; // For logging 249 | auto osPluginLibName = getOSLibraryName(pluginLib); 250 | LOG_VERBOSE("Node " << getNodeName(node) << " requires plugin " << pluginName << " which is provided by " 251 | << osPluginLibName); 252 | mLogicalVCPluginLibraries.insert(osPluginLibName); 253 | } 254 | 255 | std::vector ImporterContext::getUsedVCPluginLibraries() 256 | { 257 | auto* ctx = this; // For logging 258 | #if defined(_WIN32) || defined(__linux__) 259 | std::vector ret; 260 | ret.reserve(mLogicalVCPluginLibraries.size()); 261 | for (auto const& l : mLogicalVCPluginLibraries) 262 | { 263 | auto osLibPath = getOSLibraryPath(l); 264 | LOG_VERBOSE("Library " << l << " located on filesystem as " << osLibPath); 265 | ret.emplace_back(std::move(osLibPath)); 266 | } 267 | return ret; 268 | #else 269 | LOG_WARNING("getUsedVCPluginLibraries not implemented on platform!"); 270 | return {}; 271 | #endif 272 | } 273 | 274 | } // namespace onnx2trt 275 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2023 NVIDIA Corporation 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | 204 | 205 | PORTIONS LICENSED AS FOLLOWS 206 | 207 | > ieee_half.h 208 | > half.h 209 | 210 | The MIT License 211 | 212 | Copyright (c) 2012-2017 Christian Rau 213 | 214 | Permission is hereby granted, free of charge, to any person obtaining a 215 | copy of this software and associated documentation files (the "Software"), 216 | to deal in the Software without restriction, including without limitation 217 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 218 | and/or sell copies of the Software, and to permit persons to whom the 219 | Software is furnished to do so, subject to the following conditions: 220 | 221 | The above copyright notice and this permission notice shall be included 222 | in all copies or substantial portions of the Software. 223 | 224 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 225 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 226 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 227 | THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 228 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 229 | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 230 | DEALINGS IN THE SOFTWARE. 231 | 232 | 233 | -------------------------------------------------------------------------------- /LoopHelpers.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "LoopHelpers.hpp" 6 | #include "importerUtils.hpp" 7 | 8 | namespace onnx2trt 9 | { 10 | 11 | nvinfer1::ITensor* addLoopCounter(ImporterContext* ctx, nvinfer1::ILoop* loop, int64_t initial) 12 | { 13 | nvinfer1::ITensor* initialTensor 14 | = addConstantScalar(ctx, initial, ::ONNX_NAMESPACE::TensorProto::INT64, nvinfer1::Dims{1, {1}})->getOutput(0); 15 | nvinfer1::ITensor* one = addConstantScalar(ctx, static_cast(1), ::ONNX_NAMESPACE::TensorProto::INT64, 16 | nvinfer1::Dims{1, {1}})->getOutput(0); 17 | 18 | auto counter = N_CHECK(loop->addRecurrence(*initialTensor)); 19 | nvinfer1::ITensor* addOne = getElementWiseResult(ctx, *N_CHECK(counter->getOutput(0)), *one, nvinfer1::ElementWiseOperation::kSUM); 20 | counter->setInput(1, *addOne); 21 | return N_CHECK(counter->getOutput(0)); 22 | } 23 | 24 | } // namespace onnx2trt 25 | -------------------------------------------------------------------------------- /LoopHelpers.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | #include "ImporterContext.hpp" 10 | 11 | namespace onnx2trt 12 | { 13 | 14 | nvinfer1::ITensor* addLoopCounter(ImporterContext* ctx, nvinfer1::ILoop* loop, int64_t initial = 0); 15 | 16 | } // namespace onnx2trt 17 | -------------------------------------------------------------------------------- /ModelImporter.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * 
SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "ImporterContext.hpp" 8 | #include "NvInferPlugin.h" 9 | #include "NvOnnxParser.h" 10 | #include "errorHelpers.hpp" 11 | #include "onnxOpCheckers.hpp" 12 | #include "onnxOpImporters.hpp" 13 | #include 14 | 15 | namespace onnx2trt 16 | { 17 | 18 | void parseNode(ImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, size_t const nodeIdx, 19 | bool deserializingINetwork = false); 20 | 21 | void parseNodeStaticCheck( 22 | ImporterContext* ctx, ::ONNX_NAMESPACE::NodeProto const& node, std::vector& errors, size_t const nodeIndex); 23 | 24 | void parseGraph(ImporterContext* ctx, ::ONNX_NAMESPACE::GraphProto const& graph, std::vector& errors, 25 | bool deserializingINetwork = false, int32_t* currentNode = nullptr); 26 | 27 | class ModelImporter : public nvonnxparser::IParser 28 | { 29 | using SubGraphSupport_t = std::pair, bool>; 30 | using SubGraphSupportVector_t = std::vector; 31 | 32 | protected: 33 | StringMap _op_importers; 34 | virtual void importModel(::ONNX_NAMESPACE::ModelProto const& model); 35 | 36 | private: 37 | ImporterContext mImporterCtx; 38 | std::vector mPluginLibraryList; // Array of strings containing plugin libs 39 | std::vector 40 | mPluginLibraryListCStr; // Array of C-strings corresponding to the strings in mPluginLibraryList 41 | std::list<::ONNX_NAMESPACE::ModelProto> mONNXModels; // Needed for ownership of weights 42 | SubGraphSupportVector_t mSubGraphSupportVector; 43 | int mCurrentNode; 44 | mutable std::vector mErrors; // Marked as mutable so that errors could be reported from const functions 45 | nvonnxparser::OnnxParserFlags mOnnxParserFlags{ 46 | 1U << static_cast( 47 | nvonnxparser::OnnxParserFlag::kNATIVE_INSTANCENORM)}; // kNATIVE_INSTANCENORM is ON by default. 
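// mOnnxParserFlags is a plain bitmask with one bit per nvonnxparser::OnnxParserFlag value;
// setFlag()/clearFlag()/getFlag() below shift 1U by the flag's enum value and OR/AND/test that bit.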
48 | std::pair doSupportsModel( 49 | void const* serialized_onnx_model, size_t serialized_onnx_model_size, char const* model_path = nullptr); 50 | 51 | public: 52 | ModelImporter(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger) noexcept 53 | : _op_importers(getBuiltinOpImporterMap()) 54 | , mImporterCtx(network, logger) 55 | { 56 | } 57 | bool parseWithWeightDescriptors( 58 | void const* serialized_onnx_model, size_t serialized_onnx_model_size) noexcept override; 59 | bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size, 60 | const char* model_path = nullptr) noexcept override; 61 | 62 | bool supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size, 63 | SubGraphCollection_t& sub_graph_collection, const char* model_path = nullptr) noexcept override; 64 | bool supportsModelV2(void const* serialized_onnx_model, size_t serialized_onnx_model_size, 65 | char const* model_path = nullptr) noexcept override; 66 | 67 | int64_t getNbSubgraphs() noexcept override; 68 | bool isSubgraphSupported(int64_t const index) noexcept override; 69 | int64_t* getSubgraphNodes(int64_t const index, int64_t& subgraphLength) noexcept override; 70 | 71 | bool supportsOperator(const char* op_name) const noexcept override; 72 | 73 | void setFlags(nvonnxparser::OnnxParserFlags onnxParserFlags) noexcept override 74 | { 75 | mOnnxParserFlags = onnxParserFlags; 76 | } 77 | nvonnxparser::OnnxParserFlags getFlags() const noexcept override 78 | { 79 | return mOnnxParserFlags; 80 | } 81 | 82 | void clearFlag(nvonnxparser::OnnxParserFlag onnxParserFlag) noexcept override 83 | { 84 | ONNXTRT_TRY 85 | { 86 | mOnnxParserFlags &= ~(1U << static_cast(onnxParserFlag)); 87 | } 88 | ONNXTRT_CATCH_RECORD 89 | } 90 | 91 | void setFlag(nvonnxparser::OnnxParserFlag onnxParserFlag) noexcept override 92 | { 93 | ONNXTRT_TRY 94 | { 95 | mOnnxParserFlags |= 1U << static_cast(onnxParserFlag); 96 | } 97 | ONNXTRT_CATCH_RECORD 98 | } 99 | 100 | bool getFlag(nvonnxparser::OnnxParserFlag onnxParserFlag) const noexcept override 101 | { 102 | ONNXTRT_TRY 103 | { 104 | auto flag = 1U << static_cast(onnxParserFlag); 105 | return static_cast(mOnnxParserFlags & flag); 106 | } 107 | ONNXTRT_CATCH_RECORD 108 | return false; 109 | } 110 | 111 | int32_t getNbErrors() const noexcept override 112 | { 113 | return mErrors.size(); 114 | } 115 | nvonnxparser::IParserError const* getError(int32_t index) const noexcept override 116 | { 117 | ONNXTRT_TRY 118 | { 119 | return &mErrors.at(index); 120 | } 121 | ONNXTRT_CATCH_RECORD 122 | return nullptr; 123 | } 124 | void clearErrors() noexcept override 125 | { 126 | mErrors.clear(); 127 | } 128 | 129 | nvinfer1::ITensor const* getLayerOutputTensor(char const* name, int64_t i) noexcept override 130 | { 131 | ONNXTRT_TRY 132 | { 133 | if (!name) 134 | { 135 | throw std::invalid_argument("name is a nullptr"); 136 | } 137 | return mImporterCtx.findLayerOutputTensor(name, i); 138 | } 139 | ONNXTRT_CATCH_RECORD 140 | return nullptr; 141 | } 142 | 143 | bool parseFromFile(char const* onnxModelFile, int32_t verbosity) noexcept override; 144 | 145 | virtual char const* const* getUsedVCPluginLibraries(int64_t& nbPluginLibs) const noexcept override; 146 | }; 147 | 148 | } // namespace onnx2trt 149 | -------------------------------------------------------------------------------- /ModelRefitter.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | 
#include "NvInferRuntime.h" 8 | #include "Status.hpp" 9 | #include "WeightsContext.hpp" 10 | #include "errorHelpers.hpp" 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | // Logging macros 17 | #define LOG_REFITTER(msg, severity) \ 18 | do \ 19 | { \ 20 | std::ostringstream ss{}; \ 21 | if (severity <= nvinfer1::ILogger::Severity::kWARNING) \ 22 | ss << __FILENAME__ << ":" << __LINE__ << ": "; \ 23 | ss << msg; \ 24 | mLogger->log(severity, ss.str().c_str()); \ 25 | } while (0) 26 | 27 | #define LOG_REFITTER_WARNING(msg) LOG_REFITTER(msg, nvinfer1::ILogger::Severity::kWARNING) 28 | 29 | namespace onnx2trt 30 | { 31 | class ModelRefitter : public nvonnxparser::IParserRefitter 32 | { 33 | private: 34 | nvinfer1::IRefitter* mRefitter; 35 | nvinfer1::ILogger* mLogger; 36 | 37 | //! WeightsContext object to hold ownership of ONNX weights and any temporary weights created by the refitter. 38 | WeightsContext mWeightsContext; 39 | 40 | //! ONNX ModelProto object to hold ownership of ONNX weights whenever a data type conversion is not needed. 41 | ::ONNX_NAMESPACE::ModelProto onnx_model; 42 | 43 | //! Counter to limit the recursion depth to a set amount for nodes containing subgraphs. 44 | size_t nestedDepth{0}; 45 | 46 | //! Set to keep track of how many times a batch norm weight name shows up, to avoid duplicate naming in TRT. 47 | std::set mBatchNormWeightNames; 48 | //! An increasing suffix counter used to uniquify batch norm weight names. 49 | int64_t mBatchNormWeightSuffixCounter{0}; 50 | 51 | size_t successfullyRefittedWeights{}; 52 | std::unordered_set refittableWeights; 53 | std::unordered_set refittedWeights; 54 | 55 | mutable std::vector mErrors; 56 | 57 | std::unordered_set getRefittableWeights(); 58 | 59 | //! T is the working type. 60 | //! TConvertFunc is a functor for converting ShapedWeights to an array of type T. 61 | //! It should return a T*. 62 | template 63 | size_t batchnormWeightRefitter( 64 | ::ONNX_NAMESPACE::NodeProto const& node, std::vector& inputs, TConvertFunc&& f); 65 | 66 | void refitOnnxWeights(::ONNX_NAMESPACE::ModelProto const& onnx_model); 67 | void refitOnnxGraph(::ONNX_NAMESPACE::GraphProto const& graph); 68 | void refitOnnxNode(::ONNX_NAMESPACE::NodeProto const& node, ::ONNX_NAMESPACE::GraphProto const& graph); 69 | void refitOnnxConstantNode(::ONNX_NAMESPACE::NodeProto const& node, std::string const& graphName); 70 | void refitOnnxBatchNormNode(::ONNX_NAMESPACE::NodeProto const& node, ::ONNX_NAMESPACE::GraphProto const& graph); 71 | void refitOnnxIfNode(::ONNX_NAMESPACE::NodeProto const& node); 72 | void refitOnnxLoopNode(::ONNX_NAMESPACE::NodeProto const& node); 73 | void refitOnnxScanNode(::ONNX_NAMESPACE::NodeProto const& node); 74 | 75 | public: 76 | ModelRefitter(nvinfer1::IRefitter* refitter, nvinfer1::ILogger* logger) 77 | : mRefitter{refitter} 78 | , mLogger{logger} 79 | , mWeightsContext{WeightsContext{logger}} 80 | { 81 | } 82 | 83 | bool refitFromBytes(void const* serializedOnnxModel, size_t serializedOnnxModelSize, 84 | char const* modelPath = nullptr) noexcept override; 85 | bool refitFromFile(char const* onnxModelFile) noexcept override; 86 | 87 | int32_t getNbErrors() const noexcept override 88 | { 89 | return mErrors.size(); 90 | } 91 | 92 | nvonnxparser::IParserError const* getError(int32_t index) const noexcept override 93 | { 94 | ONNXTRT_TRY 95 | { 96 | return (index >= 0 && index < mErrors.size()) ? 
&mErrors.at(index) : nullptr; 97 | } 98 | ONNXTRT_CATCH_LOG(mLogger) 99 | return nullptr; 100 | } 101 | 102 | void clearErrors() noexcept override 103 | { 104 | mErrors.clear(); 105 | } 106 | }; 107 | 108 | } // namespace onnx2trt 109 | -------------------------------------------------------------------------------- /NvOnnxParser.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "NvOnnxParser.h" 6 | #include "ModelImporter.hpp" 7 | #include "ModelRefitter.hpp" 8 | #include "NvInferRuntime.h" 9 | 10 | extern "C" void* createNvOnnxParser_INTERNAL(void* network_, void* logger_, int version) noexcept 11 | { 12 | auto network = static_cast(network_); 13 | auto logger = static_cast(logger_); 14 | return new onnx2trt::ModelImporter(network, logger); 15 | } 16 | 17 | extern "C" void* createNvOnnxParserRefitter_INTERNAL(void* refitter_, void* logger_, int32_t version) noexcept 18 | { 19 | auto refitter = static_cast(refitter_); 20 | auto logger = static_cast(logger_); 21 | return new onnx2trt::ModelRefitter(refitter, logger); 22 | } 23 | 24 | extern "C" int getNvOnnxParserVersion() noexcept 25 | { 26 | return NV_ONNX_PARSER_VERSION; 27 | } 28 | -------------------------------------------------------------------------------- /OnnxAttrs.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "OnnxAttrs.hpp" 6 | #include "ShapedWeights.hpp" 7 | #include "importerUtils.hpp" 8 | #include 9 | 10 | bool isExternalAttribute(std::string const& key, onnx2trt::ImporterContext* ctx) 11 | { 12 | return !key.empty() && !ctx->localFunctionStack().empty() && ctx->localFunctionStack().back().attrs.count(key); 13 | } 14 | 15 | template <> 16 | float OnnxAttrs::get(std::string const& key) const 17 | { 18 | std::string extName = this->at(key)->ref_attr_name(); 19 | bool isExtAttr = isExternalAttribute(extName, mCtx); 20 | return isExtAttr ? mCtx->localFunctionStack().back().attrs.at(extName)->f() : this->at(key)->f(); 21 | } 22 | 23 | template <> 24 | int32_t OnnxAttrs::get(std::string const& key) const 25 | { 26 | std::string extName = this->at(key)->ref_attr_name(); 27 | bool isExtAttr = isExternalAttribute(extName, mCtx); 28 | return isExtAttr ? mCtx->localFunctionStack().back().attrs.at(extName)->i() : this->at(key)->i(); 29 | } 30 | 31 | template <> 32 | int64_t OnnxAttrs::get(std::string const& key) const 33 | { 34 | std::string extName = this->at(key)->ref_attr_name(); 35 | bool isExtAttr = isExternalAttribute(extName, mCtx); 36 | return isExtAttr ? mCtx->localFunctionStack().back().attrs.at(extName)->i() : this->at(key)->i(); 37 | } 38 | 39 | template <> 40 | bool OnnxAttrs::get(std::string const& key) const 41 | { 42 | std::string extName = this->at(key)->ref_attr_name(); 43 | bool isExtAttr = isExternalAttribute(extName, mCtx); 44 | int64_t value = isExtAttr ? mCtx->localFunctionStack().back().attrs.at(extName)->i() : this->at(key)->i(); 45 | assert(value == bool(value)); 46 | return static_cast(value); 47 | } 48 | 49 | template <> 50 | std::string OnnxAttrs::get(std::string const& key) const 51 | { 52 | std::string extName = this->at(key)->ref_attr_name(); 53 | bool isExtAttr = isExternalAttribute(extName, mCtx); 54 | return isExtAttr ? 
mCtx->localFunctionStack().back().attrs.at(extName)->s() : this->at(key)->s(); 55 | } 56 | 57 | template <> 58 | std::vector OnnxAttrs::get>(std::string const& key) const 59 | { 60 | std::string extName = this->at(key)->ref_attr_name(); 61 | bool isExtAttr = isExternalAttribute(extName, mCtx); 62 | auto attr = isExtAttr ? mCtx->localFunctionStack().back().attrs.at(extName)->ints() : this->at(key)->ints(); 63 | return std::vector(attr.begin(), attr.end()); 64 | } 65 | 66 | template <> 67 | std::vector OnnxAttrs::get>(std::string const& key) const 68 | { 69 | std::string extName = this->at(key)->ref_attr_name(); 70 | bool isExtAttr = isExternalAttribute(extName, mCtx); 71 | auto attr = isExtAttr ? mCtx->localFunctionStack().back().attrs.at(extName)->ints() : this->at(key)->ints(); 72 | return std::vector(attr.begin(), attr.end()); 73 | } 74 | 75 | template <> 76 | std::vector OnnxAttrs::get>(std::string const& key) const 77 | { 78 | std::string extName = this->at(key)->ref_attr_name(); 79 | bool isExtAttr = isExternalAttribute(extName, mCtx); 80 | auto attr = isExtAttr ? mCtx->localFunctionStack().back().attrs.at(extName)->floats() : this->at(key)->floats(); 81 | return std::vector(attr.begin(), attr.end()); 82 | } 83 | 84 | template <> 85 | nvinfer1::Dims OnnxAttrs::get(std::string const& key) const 86 | { 87 | auto values = this->get>(key); 88 | nvinfer1::Dims dims; 89 | dims.nbDims = values.size(); 90 | if (dims.nbDims > nvinfer1::Dims::MAX_DIMS) 91 | { 92 | throw std::runtime_error{"Number of dimensions values exceed the maximum amount supported by TensorRT!"}; 93 | } 94 | std::copy(values.begin(), values.end(), dims.d); 95 | // Note: No dimension type information is included 96 | return dims; 97 | } 98 | 99 | template <> 100 | nvinfer1::DimsHW OnnxAttrs::get(std::string const& key) const 101 | { 102 | nvinfer1::Dims dims = this->get(key); 103 | assert(dims.nbDims == 2); 104 | return nvinfer1::DimsHW(dims.d[0], dims.d[1]); 105 | } 106 | 107 | template <> 108 | nvinfer1::Permutation OnnxAttrs::get(std::string const& key) const 109 | { 110 | auto values = this->get>(key); 111 | nvinfer1::Permutation perm; 112 | if (values.size() > nvinfer1::Dims::MAX_DIMS) 113 | { 114 | throw std::runtime_error{"Number of permutations values exceed the maximum amount supported by TensorRT!"}; 115 | } 116 | std::copy(values.begin(), values.end(), perm.order); 117 | // Fill unused values with identity permutation 118 | for (int32_t i = values.size(); i < nvinfer1::Dims::MAX_DIMS; ++i) 119 | { 120 | perm.order[i] = i; 121 | } 122 | return perm; 123 | } 124 | 125 | template <> 126 | onnx2trt::ShapedWeights OnnxAttrs::get(std::string const& key) const 127 | { 128 | // Check for reference attribute in parent function 129 | std::string extName = this->at(key)->ref_attr_name(); 130 | bool isExtAttr = isExternalAttribute(extName, mCtx); 131 | 132 | ::ONNX_NAMESPACE::TensorProto const& onnxTensor = isExtAttr ? 
mCtx->localFunctionStack().back().attrs.at(extName)->t() : this->at(key)->t(); 133 | onnx2trt::ShapedWeights weights; 134 | bool success = mCtx->getWeightsContext().convertOnnxWeights(onnxTensor, &weights, true); 135 | if (!success) 136 | { 137 | throw std::runtime_error{"Unable to convert ONNX weights"}; 138 | } 139 | return weights; 140 | } 141 | 142 | template <> 143 | nvinfer1::DataType OnnxAttrs::get(std::string const& key) const 144 | { 145 | ::ONNX_NAMESPACE::TensorProto::DataType onnx_dtype 146 | = static_cast<::ONNX_NAMESPACE::TensorProto::DataType>(this->at(key)->i()); 147 | nvinfer1::DataType dtype{}; 148 | if (!onnx2trt::convertDtype(onnx_dtype, &dtype)) 149 | { 150 | dtype = static_cast(-1); 151 | } 152 | return dtype; 153 | } 154 | 155 | template <> 156 | std::vector OnnxAttrs::get>(std::string const& key) const 157 | { 158 | auto attr = this->at(key)->ints(); 159 | auto onnx_dtypes = std::vector(attr.begin(), attr.end()); 160 | std::vector dtypes{}; 161 | for (auto onnx_dtype : onnx_dtypes) 162 | { 163 | nvinfer1::DataType dtype{}; 164 | if (!onnx2trt::convertDtype(static_cast(onnx_dtype), &dtype)) 165 | { 166 | dtype = static_cast(-1); 167 | } 168 | dtypes.push_back(dtype); 169 | } 170 | return dtypes; 171 | } 172 | 173 | inline nvinfer1::ActivationType activationStringToEnum(std::string const& type) 174 | { 175 | if (type == "Relu") 176 | { 177 | return nvinfer1::ActivationType::kRELU; 178 | } 179 | if (type == "Tanh") 180 | { 181 | return nvinfer1::ActivationType::kTANH; 182 | } 183 | if (type == "Sigmoid") 184 | { 185 | return nvinfer1::ActivationType::kSIGMOID; 186 | } 187 | if (type == "LeakyRelu") 188 | { 189 | return nvinfer1::ActivationType::kLEAKY_RELU; 190 | } 191 | if (type == "ThresholdedRelu") 192 | { 193 | return nvinfer1::ActivationType::kTHRESHOLDED_RELU; 194 | } 195 | if (type == "ScaledTanh") 196 | { 197 | return nvinfer1::ActivationType::kSCALED_TANH; 198 | } 199 | if (type == "HardSigmoid") 200 | { 201 | return nvinfer1::ActivationType::kHARD_SIGMOID; 202 | } 203 | if (type == "Elu") 204 | { 205 | return nvinfer1::ActivationType::kELU; 206 | } 207 | if (type == "Softsign") 208 | { 209 | return nvinfer1::ActivationType::kSOFTSIGN; 210 | } 211 | if (type == "Softplus") 212 | { 213 | return nvinfer1::ActivationType::kSOFTPLUS; 214 | } 215 | throw std::runtime_error("Unknown activation type: " + type); 216 | } 217 | 218 | template <> 219 | nvinfer1::ActivationType OnnxAttrs::get(std::string const& key) const 220 | { 221 | const std::string type = this->get(key); 222 | return activationStringToEnum(type); 223 | } 224 | 225 | template <> 226 | std::vector OnnxAttrs::get>( 227 | std::string const& key) const 228 | { 229 | const auto strings = this->at(key)->strings(); 230 | std::vector actTypes; 231 | for (const auto& str : strings) 232 | { 233 | actTypes.emplace_back(activationStringToEnum(str)); 234 | } 235 | return actTypes; 236 | } 237 | 238 | template <> 239 | const ::ONNX_NAMESPACE::GraphProto& OnnxAttrs::get(std::string const& key) const 240 | { 241 | return this->at(key)->g(); 242 | } 243 | 244 | template <> 245 | std::vector OnnxAttrs::get>(std::string const& key) const 246 | { 247 | auto attr = this->at(key)->strings(); 248 | return std::vector(attr.begin(), attr.end()); 249 | } 250 | 251 | template <> 252 | nvinfer1::ScaleMode OnnxAttrs::get(std::string const& key) const 253 | { 254 | std::string s = this->get(key); 255 | if (s == "uniform") 256 | { 257 | return nvinfer1::ScaleMode::kUNIFORM; 258 | } 259 | if (s == "channel") 260 | { 261 | return 
nvinfer1::ScaleMode::kCHANNEL; 262 | } 263 | if (s == "elementwise") 264 | { 265 | return nvinfer1::ScaleMode::kELEMENTWISE; 266 | } 267 | throw std::runtime_error("Unknown ScaleMode: " + s); 268 | } 269 | 270 | template <> 271 | nvinfer1::MatrixOperation OnnxAttrs::get(std::string const& key) const 272 | { 273 | std::string s = this->get(key); 274 | if (s == "none") 275 | { 276 | return nvinfer1::MatrixOperation::kNONE; 277 | } 278 | if (s == "transpose") 279 | { 280 | return nvinfer1::MatrixOperation::kTRANSPOSE; 281 | } 282 | if (s == "vector") 283 | { 284 | return nvinfer1::MatrixOperation::kVECTOR; 285 | } 286 | throw std::runtime_error("Unknown MatrixOperation: " + s); 287 | } 288 | 289 | template <> 290 | nvinfer1::InterpolationMode OnnxAttrs::get(std::string const& key) const 291 | { 292 | const auto& mode = this->get(key); 293 | if (mode == "nearest") 294 | { 295 | return nvinfer1::InterpolationMode::kNEAREST; 296 | } 297 | if (mode == "linear" || mode == "bilinear") 298 | { 299 | return nvinfer1::InterpolationMode::kLINEAR; 300 | } 301 | throw std::runtime_error("Unknown InterpolationMode: " + mode); 302 | } 303 | 304 | template <> 305 | nvinfer1::ResizeCoordinateTransformation OnnxAttrs::get( 306 | std::string const& key) const 307 | { 308 | const auto& transformation = this->get(key); 309 | if (transformation == "align_corners") 310 | { 311 | return nvinfer1::ResizeCoordinateTransformation::kALIGN_CORNERS; 312 | } 313 | if (transformation == "asymmetric") 314 | { 315 | return nvinfer1::ResizeCoordinateTransformation::kASYMMETRIC; 316 | } 317 | if (transformation == "half_pixel") 318 | { 319 | return nvinfer1::ResizeCoordinateTransformation::kHALF_PIXEL; 320 | } 321 | throw std::runtime_error("Unknown ResizeCoordinateTransformation: " + transformation); 322 | } 323 | 324 | template <> 325 | nvinfer1::ResizeSelector OnnxAttrs::get(std::string const& key) const 326 | { 327 | const auto& selector = this->get(key); 328 | if (selector == "formula") 329 | { 330 | return nvinfer1::ResizeSelector::kFORMULA; 331 | } 332 | if (selector == "upper") 333 | { 334 | return nvinfer1::ResizeSelector::kUPPER; 335 | } 336 | throw std::runtime_error("Unknown ResizeSelector: " + selector); 337 | } 338 | 339 | template <> 340 | nvinfer1::ResizeRoundMode OnnxAttrs::get(std::string const& key) const 341 | { 342 | const auto& roundMode = this->get(key); 343 | if (roundMode == "half_up") 344 | { 345 | return nvinfer1::ResizeRoundMode::kHALF_UP; 346 | } 347 | if (roundMode == "half_down") 348 | { 349 | return nvinfer1::ResizeRoundMode::kHALF_DOWN; 350 | } 351 | if (roundMode == "floor") 352 | { 353 | return nvinfer1::ResizeRoundMode::kFLOOR; 354 | } 355 | if (roundMode == "ceil") 356 | { 357 | return nvinfer1::ResizeRoundMode::kCEIL; 358 | } 359 | throw std::runtime_error("Unknown ResizeRoundMode: " + roundMode); 360 | } 361 | -------------------------------------------------------------------------------- /OnnxAttrs.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "ImporterContext.hpp" 13 | 14 | class OnnxAttrs 15 | { 16 | template 17 | using string_map = std::unordered_map; 18 | typedef string_map<::ONNX_NAMESPACE::AttributeProto const*> AttrMap; 19 | AttrMap _attrs; 20 | onnx2trt::ImporterContext* mCtx; 21 | 22 | public: 23 | explicit OnnxAttrs(::ONNX_NAMESPACE::NodeProto const& onnx_node, onnx2trt::ImporterContext* 
ctx) 24 | : mCtx{ctx} 25 | { 26 | for (auto const& attr : onnx_node.attribute()) 27 | { 28 | _attrs.insert({attr.name(), &attr}); 29 | } 30 | } 31 | 32 | bool count(std::string const& key) const 33 | { 34 | return _attrs.count(key); 35 | } 36 | 37 | ::ONNX_NAMESPACE::AttributeProto const* at(std::string key) const 38 | { 39 | if (!_attrs.count(key)) 40 | { 41 | throw std::out_of_range("Attribute not found: " + key); 42 | } 43 | return _attrs.at(key); 44 | } 45 | 46 | ::ONNX_NAMESPACE::AttributeProto::AttributeType type(std::string const& key) const 47 | { 48 | return this->at(key)->type(); 49 | } 50 | 51 | template <typename T> 52 | T get(std::string const& key) const; 53 | 54 | template <typename T> 55 | T get(std::string const& key, T const& default_value) const 56 | { 57 | return _attrs.count(key) ? this->get(key) : default_value; 58 | } 59 | }; 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # TensorRT Backend For ONNX 4 | 5 | Parses ONNX models for execution with [TensorRT](https://developer.nvidia.com/tensorrt). 6 | 7 | See also the [TensorRT documentation](https://docs.nvidia.com/deeplearning/tensorrt/). 8 | 9 | For the list of recent changes, see the [changelog](docs/Changelog.md). 10 | 11 | For a list of commonly seen issues and questions, see the [FAQ](docs/faq.md). 12 | 13 | For business inquiries, please contact researchinquiries@nvidia.com 14 | 15 | For press and other inquiries, please contact Hector Marinez at hmarinez@nvidia.com 16 | 17 | ## Supported TensorRT Versions 18 | 19 | Development on this branch is for the latest version of [TensorRT 10.11](https://developer.nvidia.com/nvidia-tensorrt-download) with full-dimensions and dynamic shape support. 20 | 21 | For previous versions of TensorRT, refer to their respective branches. 22 | 23 | ## Supported Operators 24 | 25 | Currently supported ONNX operators are found in the [operator support matrix](docs/operators.md). 26 | 27 | # Installation 28 | 29 | ### Dependencies 30 | 31 | - [Protobuf >= 3.0.x](https://github.com/google/protobuf/releases) 32 | - [TensorRT 10.11](https://developer.nvidia.com/tensorrt) 33 | - [TensorRT 10.11 open source libraries](https://github.com/NVIDIA/TensorRT/) 34 | 35 | ### Building 36 | 37 | For building within docker, we recommend setting up the docker containers as instructed in the main [TensorRT repository](https://github.com/NVIDIA/TensorRT#setting-up-the-build-environment) to build the onnx-tensorrt library. 38 | 39 | Once you have cloned the repository, you can build the parser libraries and executables by running: 40 | 41 | cd onnx-tensorrt 42 | mkdir build && cd build 43 | cmake .. -DTENSORRT_ROOT=<path_to_trt> && make -j 44 | # Ensure that you update your LD_LIBRARY_PATH to pick up the location of the newly built library: 45 | export LD_LIBRARY_PATH=$PWD:$LD_LIBRARY_PATH 46 | 47 | Note that this project has a dependency on CUDA. By default the build will look in `/usr/local/cuda` for the CUDA toolkit installation. If your CUDA path is different, override the default path by providing `-DCUDA_TOOLKIT_ROOT_DIR=<path_to_cuda_toolkit>` in the CMake command. 48 | 49 | To build with `protobuf-lite` support, add `-DONNX_USE_LITE_PROTO=1` to the end of the `cmake` command. 50 | 51 | ### InstanceNormalization Performance 52 | 53 | There are two implementations of InstanceNormalization that may perform differently depending on various parameters.
By default, the parser will use the native TensorRT implementation of InstanceNorm. Users who want to benchmark using the plugin implementation of InstanceNorm can unset the parser flag `kNATIVE_INSTANCENORM` prior to parsing the model. Note that the plugin implementation cannot be used for building version-compatible or hardware-compatible engines, and attempting to do so will result in an error. 54 | 55 | C++ Example: 56 | 57 | // Unset the kNATIVE_INSTANCENORM flag to use the plugin implementation. 58 | parser->unsetFlag(nvonnxparser::OnnxParserFlag::kNATIVE_INSTANCENORM); 59 | 60 | Python Example: 61 | 62 | # Unset the NATIVE_INSTANCENORM flag to use the plugin implementation. 63 | parser.clear_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM) 64 | 65 | ## Executable Usage 66 | 67 | There are currently two officially supported tools for users to quickly check whether an ONNX model can be parsed and built into a TensorRT engine. 68 | 69 | For C++ users, there is the [trtexec](https://github.com/NVIDIA/TensorRT/tree/main/samples/opensource/trtexec) binary that is typically found in the `<TensorRT root directory>/bin` directory. The basic command for running an ONNX model is: 70 | 71 | `trtexec --onnx=model.onnx` 72 | 73 | Refer to the link or run `trtexec -h` for more information on CLI options. 74 | 75 | For Python users, there is the [polygraphy](https://github.com/NVIDIA/TensorRT/tree/main/tools/Polygraphy) tool. The basic command for running an ONNX model is: 76 | 77 | `polygraphy run model.onnx --trt` 78 | 79 | Refer to the link or run `polygraphy run -h` for more information on CLI options. 80 | 81 | ### Python Modules 82 | 83 | Python bindings for the ONNX-TensorRT parser are packaged in the shipped `.whl` files. 84 | 85 | TensorRT 10.11 supports ONNX release 1.18.0. Install it with: 86 | 87 | python3 -m pip install onnx==1.18.0 88 | 89 | The ONNX-TensorRT backend can be installed by running: 90 | 91 | python3 setup.py install 92 | 93 | ## ONNX-TensorRT Python Backend Usage 94 | 95 | The TensorRT backend for ONNX can be used in Python as follows: 96 | 97 | ```python 98 | import onnx 99 | import onnx_tensorrt.backend as backend 100 | import numpy as np 101 | 102 | model = onnx.load("/path/to/model.onnx") 103 | engine = backend.prepare(model, device='CUDA:1') 104 | input_data = np.random.random(size=(32, 3, 224, 224)).astype(np.float32) 105 | output_data = engine.run(input_data)[0] 106 | print(output_data) 107 | print(output_data.shape) 108 | ``` 109 | 110 | ## C++ Library Usage 111 | 112 | The model parser library, libnvonnxparser.so, has its C++ API declared in this header: 113 | 114 | NvOnnxParser.h 115 | 116 | ### Tests 117 | 118 | After installation (or inside the Docker container), ONNX backend tests can be run as follows: 119 | 120 | Real model tests only: 121 | 122 | python onnx_backend_test.py OnnxBackendRealModelTest 123 | 124 | All tests: 125 | 126 | python onnx_backend_test.py 127 | 128 | You can use the `-v` flag to make output more verbose.
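For reference, the following is a minimal, illustrative sketch of driving the parser directly through the C++ API declared in `NvOnnxParser.h` (see the C++ Library Usage section above). The model path, the error-handling details, and the use of the `TRT_Logger` helper from `common.hpp` are assumptions made for illustration; this is not an officially supported sample:

```cpp
// Illustrative sketch: parse an ONNX model into a TensorRT network definition.
// Assumes "model.onnx" is a placeholder path and that common.hpp (TRT_Logger) is on the include path.
#include "NvInfer.h"
#include "NvOnnxParser.h"
#include "common.hpp"

#include <iostream>
#include <memory>

int main()
{
    common::TRT_Logger logger(nvinfer1::ILogger::Severity::kWARNING);

    auto builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(logger));
    auto network = std::unique_ptr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(0U));
    auto parser = std::unique_ptr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, logger));

    // parseFromFile populates the network and records any failures in the parser's error list.
    if (!parser->parseFromFile("model.onnx", static_cast<int32_t>(nvinfer1::ILogger::Severity::kWARNING)))
    {
        for (int32_t i = 0; i < parser->getNbErrors(); ++i)
        {
            std::cerr << parser->getError(i)->desc() << std::endl;
        }
        return 1;
    }

    // The populated network definition can now be built into an engine with an IBuilderConfig.
    return 0;
}
```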
129 | 130 | ## Pre-trained Models 131 | 132 | Pre-trained models in ONNX format can be found at the [ONNX Model Zoo](https://github.com/onnx/models) 133 | -------------------------------------------------------------------------------- /RNNHelpers.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "RNNHelpers.hpp" 6 | #include "LoopHelpers.hpp" 7 | #include "importerUtils.hpp" 8 | #include 9 | 10 | namespace onnx2trt 11 | { 12 | 13 | nvinfer1::ITensor* addRNNInput(ImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ILoop* loop, 14 | std::vector& inputs, const std::string& direction) 15 | { 16 | // In the forward/reverse cases, we only use a single iterator. In the bidirectional case, a forward and reverse 17 | // iterator must be concatenated. 18 | // Input dimensions: [1, B, E] 19 | nvinfer1::ITensor* iterationInput{nullptr}; 20 | nvinfer1::ITensor* input = &convertToTensor(inputs.at(0), ctx); 21 | 22 | const int sequenceLenIndex = 4; 23 | bool isRagged = inputs.size() > sequenceLenIndex && inputs.at(sequenceLenIndex); 24 | 25 | if (direction == "forward") 26 | { 27 | iterationInput = unsqueezeTensor(ctx, *N_CHECK(loop->addIterator(*input)->getOutput(0)), std::vector{0}); 28 | 29 | if (isRagged) 30 | { 31 | nvinfer1::ITensor* seqLens = &convertToTensor(inputs.at(sequenceLenIndex), ctx); 32 | auto maxLen = getAxisLength(ctx, input, 0); 33 | iterationInput = clearMissingSequenceElements(ctx, node, loop, seqLens, iterationInput, maxLen); 34 | } 35 | } 36 | else if (direction == "reverse") 37 | { 38 | nvinfer1::IIteratorLayer* reverseIterator = N_CHECK(loop->addIterator(*input)); 39 | reverseIterator->setReverse(true); 40 | auto reverseIteratorOutput = N_CHECK(reverseIterator->getOutput(0)); 41 | iterationInput = unsqueezeTensor(ctx, *reverseIteratorOutput, std::vector{0}); 42 | if (isRagged) 43 | { 44 | nvinfer1::ITensor* seqLens = &convertToTensor(inputs.at(sequenceLenIndex), ctx); 45 | auto maxLen = getAxisLength(ctx, input, 0); 46 | iterationInput = clearMissingSequenceElements(ctx, node, loop, seqLens, iterationInput, maxLen, true); 47 | } 48 | } 49 | else if (direction == "bidirectional") 50 | { 51 | nvinfer1::IIteratorLayer* forward = N_CHECK(loop->addIterator(*input)); 52 | nvinfer1::IIteratorLayer* reverse = N_CHECK(loop->addIterator(*input)); 53 | reverse->setReverse(true); 54 | 55 | auto forwardInput = unsqueezeTensor(ctx, *N_CHECK(forward->getOutput(0)), std::vector{0}); 56 | auto reverseInput = unsqueezeTensor(ctx, *N_CHECK(reverse->getOutput(0)), std::vector{0}); 57 | if (isRagged) 58 | { 59 | nvinfer1::ITensor* seqLens = &convertToTensor(inputs.at(sequenceLenIndex), ctx); 60 | auto counter = addLoopCounter(ctx, loop); 61 | auto maxLen = getAxisLength(ctx, input, 0); 62 | forwardInput = clearMissingSequenceElements(ctx, node, loop, seqLens, forwardInput, maxLen, false, counter); 63 | reverseInput = clearMissingSequenceElements(ctx, node, loop, seqLens, reverseInput, maxLen, true, counter); 64 | } 65 | 66 | // Stack on the 0th axis to create a (numDirections, B, E) tensor. 
67 | std::array tensors{{forwardInput, reverseInput}}; 68 | nvinfer1::IConcatenationLayer* concat = N_CHECK(ctx->network()->addConcatenation(tensors.data(), 2)); 69 | concat->setAxis(0); 70 | iterationInput = N_CHECK(concat->getOutput(0)); 71 | } 72 | if (iterationInput) 73 | { 74 | LOG_VERBOSE("Input shape: " << iterationInput->getDimensions()); 75 | } 76 | return iterationInput; 77 | } 78 | 79 | nvinfer1::ITensor* clearMissingSequenceElements(ImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, 80 | nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* toMask, nvinfer1::ITensor* maxLen, 81 | bool reverse, nvinfer1::ITensor* counter) 82 | { 83 | nvinfer1::ITensor* zero 84 | = addConstantScalar(ctx, 0.f, ::ONNX_NAMESPACE::TensorProto::FLOAT, nvinfer1::Dims3(1, 1, 1))->getOutput(0); 85 | nvinfer1::ITensor* seqMask = getRaggedMask(ctx, node, loop, seqLens, maxLen, reverse, counter); 86 | auto selectLayer = N_CHECK(ctx->network()->addSelect(*seqMask, *toMask, *zero)); 87 | return N_CHECK(selectLayer->getOutput(0)); 88 | } 89 | 90 | nvinfer1::ITensor* maskRNNHidden(ImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ILoop* loop, 91 | nvinfer1::ITensor* seqLens, nvinfer1::ITensor* prevH, nvinfer1::ITensor* Ht, nvinfer1::ITensor* maxLen, 92 | bool reverse, nvinfer1::ITensor* counter) 93 | { 94 | // maxLen must be provided if reverse is true 95 | // Forwards previous hidden state if invalid 96 | nvinfer1::ITensor* valid = getRaggedMask(ctx, node, loop, seqLens, maxLen, reverse, counter); 97 | auto selectLayer = N_CHECK(ctx->network()->addSelect(*valid, *Ht, *prevH)); 98 | return N_CHECK(selectLayer->getOutput(0)); 99 | } 100 | 101 | nvinfer1::ITensor* maskBidirRNNHidden(ImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, 102 | nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* maxLen, nvinfer1::ITensor* Ht1, 103 | nvinfer1::ITensor* Ht, nvinfer1::ITensor* singlePassShape) 104 | { 105 | // Splits hidden state into forward and backward states, masks each accordingly, then concatenates 106 | 107 | nvinfer1::ITensor* forwardStart = addConstant(ctx, std::vector{0, 0, 0}, 108 | ::ONNX_NAMESPACE::TensorProto::INT32, 109 | nvinfer1::Dims{1, {3}})->getOutput(0); 110 | nvinfer1::ITensor* reverseStart = addConstant(ctx, std::vector{1, 0, 0}, 111 | ::ONNX_NAMESPACE::TensorProto::INT32, 112 | nvinfer1::Dims{1, {3}})->getOutput(0); 113 | 114 | nvinfer1::ISliceLayer* HtForwardLayer 115 | = N_CHECK(ctx->network()->addSlice(*Ht, nvinfer1::Dims3{0, 0, 0}, nvinfer1::Dims3{0, 0, 0}, nvinfer1::Dims3{1, 1, 1})); 116 | HtForwardLayer->setInput(1, *forwardStart); 117 | HtForwardLayer->setInput(2, *singlePassShape); 118 | 119 | nvinfer1::ISliceLayer* HtBackwardLayer 120 | = N_CHECK(ctx->network()->addSlice(*Ht, nvinfer1::Dims3{0, 0, 0}, nvinfer1::Dims3{0, 0, 0}, nvinfer1::Dims3{1, 1, 1})); 121 | HtBackwardLayer->setInput(1, *reverseStart); 122 | HtBackwardLayer->setInput(2, *singlePassShape); 123 | 124 | nvinfer1::ISliceLayer* Ht1ForwardLayer 125 | = N_CHECK(ctx->network()->addSlice(*Ht1, nvinfer1::Dims3{0, 0, 0}, nvinfer1::Dims3{0, 0, 0}, nvinfer1::Dims3{1, 1, 1})); 126 | Ht1ForwardLayer->setInput(1, *forwardStart); 127 | Ht1ForwardLayer->setInput(2, *singlePassShape); 128 | 129 | nvinfer1::ISliceLayer* Ht1BackwardLayer 130 | = N_CHECK(ctx->network()->addSlice(*Ht1, nvinfer1::Dims3{0, 0, 0}, nvinfer1::Dims3{0, 0, 0}, nvinfer1::Dims3{1, 1, 1})); 131 | Ht1BackwardLayer->setInput(1, *reverseStart); 132 | Ht1BackwardLayer->setInput(2, 
*singlePassShape); 133 | 134 | auto forwardHt = N_CHECK(HtForwardLayer->getOutput(0)); 135 | auto backwardHt = N_CHECK(HtBackwardLayer->getOutput(0)); 136 | auto forwardHt1 = N_CHECK(Ht1ForwardLayer->getOutput(0)); 137 | auto backwardHt1 = N_CHECK(Ht1BackwardLayer->getOutput(0)); 138 | 139 | auto counter = addLoopCounter(ctx, loop, 0); 140 | forwardHt = maskRNNHidden(ctx, node, loop, seqLens, forwardHt1, forwardHt, maxLen, false, counter); 141 | backwardHt = maskRNNHidden(ctx, node, loop, seqLens, backwardHt1, backwardHt, maxLen, true, counter); 142 | std::array tensors{{forwardHt, backwardHt}}; 143 | nvinfer1::IConcatenationLayer* concat = N_CHECK(ctx->network()->addConcatenation(tensors.data(), 2)); 144 | concat->setAxis(0); 145 | return N_CHECK(concat->getOutput(0)); 146 | } 147 | 148 | nvinfer1::ITensor* getRaggedMask(ImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ILoop* loop, 149 | nvinfer1::ITensor* seqLens, nvinfer1::ITensor* maxLen, bool reverse, nvinfer1::ITensor* counter) 150 | { 151 | // Returns a bool tensor which is true where the elements are valid (within the sequence) and false when outside the 152 | // sequence. 153 | // maxLen must be provided if reverse is true 154 | assert(!reverse || maxLen); 155 | 156 | if (!counter) 157 | { 158 | counter = addLoopCounter(ctx, loop, 0); 159 | } 160 | 161 | // ONNX spec currently requires seqLens to be int32 162 | counter = castHelper(ctx, counter, nvinfer1::DataType::kINT32); 163 | 164 | // Create Mask 165 | nvinfer1::ITensor* seqMask; 166 | if (reverse) 167 | { 168 | counter = getElementWiseResult( 169 | ctx, *unsqueezeTensor(ctx, *maxLen, {0}), *counter, nvinfer1::ElementWiseOperation::kSUB); 170 | seqMask = getElementWiseResult(ctx, *seqLens, *counter, nvinfer1::ElementWiseOperation::kLESS); 171 | seqMask = getUnaryResult(ctx, *seqMask, nvinfer1::UnaryOperation::kNOT); 172 | } 173 | else 174 | { 175 | seqMask = getElementWiseResult(ctx, *counter, *seqLens, nvinfer1::ElementWiseOperation::kLESS); 176 | } 177 | return unsqueezeTensor(ctx, *seqMask, std::vector{0, 2}); 178 | } 179 | 180 | } // namespace onnx2trt 181 | -------------------------------------------------------------------------------- /RNNHelpers.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "TensorOrWeights.hpp" 12 | #include "ImporterContext.hpp" 13 | 14 | namespace onnx2trt 15 | { 16 | 17 | nvinfer1::ITensor* addRNNInput(ImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ILoop* loop, 18 | std::vector& inputs, const std::string& direction); 19 | 20 | // Zeros out invalid timesteps in toMask. 
maxLen must be provided if reverse is true 21 | nvinfer1::ITensor* clearMissingSequenceElements(ImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, 22 | nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* toMask, nvinfer1::ITensor* maxLen, 23 | bool reverse = false, nvinfer1::ITensor* counter = nullptr); 24 | 25 | // Returns a bool tensor which is true during valid timesteps 26 | nvinfer1::ITensor* getRaggedMask(ImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ILoop* loop, 27 | nvinfer1::ITensor* seqLens, nvinfer1::ITensor* maxLen = nullptr, bool reverse = false, 28 | nvinfer1::ITensor* counter = nullptr); 29 | 30 | // Selects between prevH and Ht to forward previous hidden state through invalid timesteps 31 | nvinfer1::ITensor* maskRNNHidden(ImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, nvinfer1::ILoop* loop, 32 | nvinfer1::ITensor* seqLens, nvinfer1::ITensor* prevH, nvinfer1::ITensor* Ht, nvinfer1::ITensor* maxLen = nullptr, 33 | bool reverse = false, nvinfer1::ITensor* counter = nullptr); 34 | 35 | // Splits a bidirectional hidden state into forward and reverse passes, masks each using maskRNNHidden, then 36 | // concatenates 37 | nvinfer1::ITensor* maskBidirRNNHidden(ImporterContext* ctx, const ::ONNX_NAMESPACE::NodeProto& node, 38 | nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* maxLen, nvinfer1::ITensor* Ht1, 39 | nvinfer1::ITensor* Ht, nvinfer1::ITensor* singlePassShape); 40 | 41 | } // namespace onnx2trt 42 | -------------------------------------------------------------------------------- /ShapeTensor.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace onnx2trt 13 | { 14 | 15 | class ImporterContext; 16 | class TensorOrWeights; 17 | 18 | //! Represents a 0D or 1D tensor of int64_t. 19 | class ShapeTensor 20 | { 21 | public: 22 | //! Create undefined ShapeTensor. 23 | ShapeTensor() = default; 24 | 25 | //! Create ShapeTensor with known rank and int64_t values. 26 | ShapeTensor(int32_t rank_, std::vector&& values_); 27 | 28 | //! Create ShapeTensor with known rank and float values. 29 | ShapeTensor(int32_t rank_, std::vector&& values_); 30 | 31 | //! Create ShapeTensor representing value of TensorOrWeights. 32 | ShapeTensor(ImporterContext* ctx, TensorOrWeights& t); 33 | 34 | //! Construct ShapeTensor equivalent to applying IShapeLayer depth times. 35 | //! The depth may be in [0,3]. 36 | explicit ShapeTensor(nvinfer1::ITensor& t, int depth = 0); 37 | 38 | //! True if rank is known. 39 | bool rankKnown() const 40 | { 41 | return mRank != kRANK_UNKNOWN; 42 | } 43 | 44 | //! Number of dimensions. Always 0 or 1. 45 | int32_t rank() const 46 | { 47 | assert(rankKnown()); 48 | return mRank; 49 | } 50 | 51 | //! True if number of elements in tensor is known. 52 | bool sizeKnown() const 53 | { 54 | return mSize != kSIZE_UNKNOWN; 55 | } 56 | 57 | //! Number of elements in the tensor. Asserts that sizeKnown()==true. 58 | int32_t size() const 59 | { 60 | assert(sizeKnown()); 61 | return mSize; 62 | } 63 | 64 | //! True if tensor is known to be an empty vector. 65 | bool isEmpty() const 66 | { 67 | // No need to check rank because if rank is 0, then mSize==1, 68 | // and if rank is unknown, mSize = kSIZE_UNKNOWN. 69 | return mSize == 0; 70 | } 71 | 72 | //! True if all element values are known. 
73 | bool allValuesKnown() const 74 | { 75 | return mAllValuesKnown; 76 | } 77 | 78 | //! True if all element values equal the given value. 79 | bool isAll(int64_t value) const; 80 | 81 | //! True if floating-point shape tensor. 82 | bool isFloat() const 83 | { 84 | return mIsFloat; 85 | } 86 | 87 | using const_iterator = std::vector::const_iterator; 88 | 89 | //! Iterator pointing to beginning of sequence of element values. 90 | //! Requires that allValuesKnown() is true. 91 | const_iterator begin() const 92 | { 93 | assert(mAllValuesKnown); 94 | return mValues.begin(); 95 | } 96 | 97 | //! Iterator pointing to end of sequence of element values. 98 | //! Requires that allValuesKnown() is true. 99 | const_iterator end() const 100 | { 101 | assert(mAllValuesKnown); 102 | return mValues.end(); 103 | } 104 | 105 | //! True if operator[](k) is valid. 106 | bool valueKnown(int k) const; 107 | 108 | //! Return kth value. 109 | //! For a 0D tensor, k must be 0. 110 | //! Requires that valueKnown(k) is true. 111 | int64_t operator[](int k) const 112 | { 113 | assert(valueKnown(k)); 114 | return mValues[k]; 115 | } 116 | 117 | //! Return true if x and y always have the same value. 118 | friend bool operator==(const ShapeTensor& x, const ShapeTensor& y); 119 | friend ShapeTensor shapeOf(const ShapeTensor& t); 120 | 121 | //! Get TensorRT tensor representation. 122 | nvinfer1::ITensor& tensor(ImporterContext* ctx) const; 123 | 124 | private: 125 | //! Number of IShapeLayer to apply to mTensor to get ITensor representing value of *this. 126 | //! -1 for undefined *this, a value in [0,2] otherwise. 127 | //! 0: *this represents value of the tensor (always 0D or 1D) 128 | //! 1: *this represents shape of mTensor (always 1D) 129 | //! 2: *this represents rank of mTensor (always 1D tensor of length 1) 130 | mutable int8_t mDepth{-1}; 131 | 132 | //! True if all values are known. 133 | bool mAllValuesKnown{false}; 134 | 135 | static constexpr int kRANK_UNKNOWN = -1; 136 | static constexpr int kSIZE_UNKNOWN = -1; 137 | 138 | //! Rank of *this. 139 | //! Always -1, 0 or 1. 140 | int8_t mRank{kRANK_UNKNOWN}; 141 | 142 | //! Number of elements in the tensor, or -1 if unknown. 143 | int32_t mSize{kSIZE_UNKNOWN}; 144 | 145 | //! Must be non-null if mAllValuesKnown. 146 | mutable nvinfer1::ITensor* mTensor{nullptr}; 147 | 148 | //! Values of elements if some might be known. 149 | //! mValues.size() is always zero or equal to mSize. 150 | //! When mAllValuesKnown==true, all the values in mValues are correct 151 | //! and mValues.size() == mSize. 152 | //! When mAllValuesKnown==false, only the non-negative values in mValues 153 | //! are guaranteed to be correct, and only so if mValues.size() == mSize. 154 | std::vector mValues{}; 155 | 156 | bool mIsFloat{false}; 157 | }; 158 | 159 | //! Print ShapeTensor. Unknown values are printed as _. 160 | std::ostream& operator<<(std::ostream& stream, const ShapeTensor& x); 161 | 162 | //! Create 1D ShapeTensor of length n filled with value. 163 | //! count must be 1D ShapeTensor of size 1. 164 | ShapeTensor fillShapeVector(ImporterContext* ctx, int64_t value, const ShapeTensor& count); 165 | 166 | //! Create 1D ShapeTensor of length 1 containing given value. 167 | ShapeTensor shapeVector(int64_t value); 168 | 169 | //! Create 0D ShapeTensor containing the given value. 170 | ShapeTensor shapeScalar(int64_t value); 171 | 172 | //! Create 1D ShapeTensor containing [0,n). 173 | ShapeTensor iotaShapeVector(int32_t n); 174 | 175 | //! 
Create ShapeTensor filled with value that has same shape as exemplar. 176 | //! The exemplar must be 1D. 177 | ShapeTensor similar(ImporterContext* ctx, const ShapeTensor& exemplar, int64_t value); 178 | 179 | //! Elementwise addition 180 | ShapeTensor add(ImporterContext* ctx, const ShapeTensor& x, const ShapeTensor& y); 181 | 182 | //! Elementwise subtraction 183 | ShapeTensor sub(ImporterContext* ctx, const ShapeTensor& x, const ShapeTensor& y); 184 | 185 | //! Elementwise multiplication 186 | ShapeTensor mul(ImporterContext* ctx, const ShapeTensor& x, const ShapeTensor& y); 187 | 188 | //! Elementwise min 189 | ShapeTensor min(ImporterContext* ctx, const ShapeTensor& x, const ShapeTensor& y); 190 | 191 | //! Elementwise max 192 | ShapeTensor max(ImporterContext* ctx, const ShapeTensor& x, const ShapeTensor& y); 193 | 194 | //! Elementwise floor division 195 | ShapeTensor floorDiv(ImporterContext* ctx, const ShapeTensor& x, const ShapeTensor& y); 196 | 197 | //! Elementwise f, for a partial function f defined by: 198 | //! f(x,x) = x 199 | //! f(1,x) = x 200 | //! f(x,1) = x 201 | //! Undefined otherwise or if x < 0. 202 | ShapeTensor broadcast(ImporterContext* ctx, const ShapeTensor& x, const ShapeTensor& y); 203 | 204 | //! Return product of x[i] for i in [first..last), as 0D or one-element 1D tensor of given rank. 205 | ShapeTensor product(ImporterContext* ctx, const ShapeTensor& x, int first, int last, int rank); 206 | 207 | //! Gather where data is 1D tensor and indices can be 0D or 1D 208 | ShapeTensor gather(ImporterContext* ctx, const ShapeTensor& data, const ShapeTensor& indices); 209 | 210 | //! Concatenation of two 1D tensors 211 | ShapeTensor concat(ImporterContext* ctx, const ShapeTensor& x, const ShapeTensor& y); 212 | 213 | //! Cast to int32_t shape tensor. 214 | ShapeTensor castToInt32(ImporterContext* ctx, ShapeTensor const& x); 215 | 216 | //! Cast to int64_t shape tensor. 217 | ShapeTensor castToInt64(ImporterContext* ctx, ShapeTensor const& x); 218 | 219 | //! Return gather(concat(x,y),subscripts) 220 | inline ShapeTensor interlace( 221 | ImporterContext* ctx, const ShapeTensor& x, const ShapeTensor& y, const ShapeTensor& subscripts) 222 | { 223 | return gather(ctx, concat(ctx, x, y), subscripts); 224 | } 225 | 226 | //! Return shape of a tensor. 227 | ShapeTensor shapeOf(nvinfer1::ITensor& tensor); 228 | ShapeTensor shapeOf(const ShapeTensor& tensor); 229 | ShapeTensor shapeOf(TensorOrWeights& t); 230 | 231 | //! Reshape 0D tensor to 1D tensor. 232 | ShapeTensor convertTo1D(ImporterContext* ctx, const ShapeTensor& tensor); 233 | 234 | //! Reshape single value 1D tensor to a 0D tensor. 235 | ShapeTensor convertTo0D(ImporterContext* ctx, const ShapeTensor& tensor); 236 | 237 | //! Convert ShapeTensor to Dims, with bounds checking. 238 | nvinfer1::Dims shapeTensorToDims(const ShapeTensor& x, const char* what, int32_t minAllowed, int32_t maxAllowed); 239 | 240 | //! Add an ISliceLayer. 241 | nvinfer1::ISliceLayer* addSlice(ImporterContext* ctx, nvinfer1::ITensor& data, const ShapeTensor& starts, 242 | const ShapeTensor& sizes, const ShapeTensor& strides); 243 | 244 | //! Add an IShuffleLayer. 245 | //! If the result does not need to have its parameters changed, and 246 | //! optimizing the no-op case away is okay, use function reshape instead. 247 | //! 248 | //! In general the default zeroIsPlaceholder=false should be used so 249 | //! that reshaping to empty tensors works correctly. Calling with 250 | //! 
zeroIsPlaceholder=true should happen only when replicating the 251 | //! semantics of the ONNX Reshape operator. 252 | nvinfer1::IShuffleLayer* addShuffle( 253 | ImporterContext* ctx, nvinfer1::ITensor& data, const ShapeTensor& reshapeDims, bool zeroIsPlaceholder = false); 254 | 255 | //! Add an IFillLayer. 256 | nvinfer1::IFillLayer* addFill(ImporterContext* ctx, const ShapeTensor& shape, nvinfer1::FillOperation op); 257 | 258 | //! Reshape a tensor. 259 | //! 260 | //! Treats any zeros in newShape as dimensions, not placeholders. 261 | //! Implementation note: does not insert shuffle if it's a no-op. 262 | nvinfer1::ITensor& reshape(ImporterContext* ctx, nvinfer1::ITensor& data, const ShapeTensor& newShape); 263 | 264 | } // namespace onnx2trt 265 | -------------------------------------------------------------------------------- /ShapedWeights.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "ShapedWeights.hpp" 6 | #include "importerUtils.hpp" 7 | #include 8 | #include 9 | #include 10 | 11 | namespace onnx2trt 12 | { 13 | 14 | size_t ShapedWeights::count() const 15 | { 16 | assert(shape.nbDims >= 0); 17 | size_t c = 1; 18 | for (int32_t i = 0; i < this->shape.nbDims; ++i) 19 | { 20 | if (shape.d[i] == 0) 21 | { 22 | c = 0; 23 | break; 24 | } 25 | if (c > std::numeric_limits::max() / shape.d[i]) 26 | { 27 | throw std::runtime_error("Count of weights exceeds maximum!"); 28 | } 29 | c *= this->shape.d[i]; 30 | } 31 | return c; 32 | } 33 | 34 | ShapedWeights ShapedWeights::empty(DataType type) 35 | { 36 | return ShapedWeights(type, nullptr, nvinfer1::Dims{1, {0}}); 37 | } 38 | 39 | ShapedWeights::ShapedWeights(DataType type_, void* values_, nvinfer1::Dims shape_) 40 | : type(type_) 41 | , values(values_) 42 | , shape(shape_) 43 | { 44 | // Note: this->shape.type[] is not used 45 | } 46 | 47 | size_t ShapedWeights::size_bytes() const 48 | { 49 | return getTensorOrWeightsSizeBytes(this->count(), this->type); 50 | } 51 | 52 | ShapedWeights::operator bool() const 53 | { 54 | return (bool) this->values; 55 | } 56 | 57 | ShapedWeights::operator nvinfer1::Weights() const 58 | { 59 | nvinfer1::Weights w{}; 60 | w.values = this->values; 61 | bool supported_type = convertDtype(this->type, &w.type); 62 | (void) supported_type; 63 | assert(supported_type); 64 | w.count = this->count(); 65 | return w; 66 | } 67 | 68 | const char* ShapedWeights::getName() const 69 | { 70 | return this->name; 71 | } 72 | 73 | void ShapedWeights::setName(const char* n) 74 | { 75 | this->name = n; 76 | } 77 | 78 | } // namespace onnx2trt 79 | -------------------------------------------------------------------------------- /ShapedWeights.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | 10 | namespace onnx2trt 11 | { 12 | 13 | class ShapedWeights 14 | { 15 | public: 16 | using DataType = int32_t; 17 | 18 | //! Create 1D zero-length ShapedWeights of given type, count()==0, and values=nullptr. 19 | static ShapedWeights empty(DataType type); 20 | 21 | //! Construct ShapedWeights that is not expected to be usable, 22 | //! except with `operator=` and method `setName()`. 
23 | ShapedWeights() = default; 24 | 25 | explicit ShapedWeights(DataType type, void* values, nvinfer1::Dims shape_); 26 | 27 | size_t count() const; 28 | 29 | size_t size_bytes() const; 30 | 31 | const char* getName() const; 32 | 33 | void setName(const char* name); 34 | 35 | //! True if values exist. 36 | explicit operator bool() const; 37 | 38 | operator nvinfer1::Weights() const; 39 | 40 | template 41 | T& at(size_t index) 42 | { 43 | assert(values && index >= 0 && index < count()); 44 | return static_cast(values)[index]; 45 | } 46 | 47 | template 48 | const T& at(size_t index) const 49 | { 50 | assert(values && index >= 0 && index < count()); 51 | return static_cast(values)[index]; 52 | } 53 | 54 | public: 55 | DataType type{static_cast(-1)}; 56 | void* values{nullptr}; 57 | nvinfer1::Dims shape{-1, {}}; 58 | const char* name{}; 59 | }; 60 | 61 | class ImporterContext; 62 | 63 | } // namespace onnx2trt 64 | -------------------------------------------------------------------------------- /TensorOrWeights.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "TensorOrWeights.hpp" 6 | #include 7 | 8 | namespace onnx2trt 9 | { 10 | 11 | std::string TensorOrWeights::getType() const 12 | { 13 | if (is_tensor()) 14 | { 15 | switch (tensor().getType()) 16 | { 17 | case nvinfer1::DataType::kFLOAT: return "FLOAT"; 18 | case nvinfer1::DataType::kHALF: return "HALF"; 19 | case nvinfer1::DataType::kBF16: return "BF16"; 20 | case nvinfer1::DataType::kINT8: return "INT8"; 21 | case nvinfer1::DataType::kUINT8: return "UINT8"; 22 | case nvinfer1::DataType::kINT32: return "INT32"; 23 | case nvinfer1::DataType::kINT64: return "INT64"; 24 | case nvinfer1::DataType::kBOOL: return "BOOL"; 25 | case nvinfer1::DataType::kFP8: return "FP8"; 26 | case nvinfer1::DataType::kINT4: return "INT4"; 27 | case nvinfer1::DataType::kFP4: return "FP4"; 28 | } 29 | } 30 | else 31 | { 32 | switch (weights().type) 33 | { 34 | // Demote double to float. 
35 | case ::ONNX_NAMESPACE::TensorProto::DOUBLE: return "FLOAT"; 36 | case ::ONNX_NAMESPACE::TensorProto::FLOAT: return "FLOAT"; 37 | case ::ONNX_NAMESPACE::TensorProto::INT8: return "INT8"; 38 | case ::ONNX_NAMESPACE::TensorProto::UINT8: return "UINT8"; 39 | case ::ONNX_NAMESPACE::TensorProto::FLOAT16: return "HALF"; 40 | case ::ONNX_NAMESPACE::TensorProto::BFLOAT16: return "BF16"; 41 | case ::ONNX_NAMESPACE::TensorProto::BOOL: return "BOOL"; 42 | case ::ONNX_NAMESPACE::TensorProto::INT32: return "INT32"; 43 | case ::ONNX_NAMESPACE::TensorProto::INT64: return "INT64"; 44 | case ::ONNX_NAMESPACE::TensorProto::FLOAT8E4M3FN: return "FP8"; 45 | case ::ONNX_NAMESPACE::TensorProto::INT4: return "INT4"; 46 | case ::ONNX_NAMESPACE::TensorProto::FLOAT4E2M1: return "FP4"; 47 | } 48 | } 49 | return "UNKNOWN TYPE"; 50 | } 51 | 52 | nvinfer1::DataType TensorOrWeights::convertONNXDataType(ShapedWeights::DataType datatype) const 53 | { 54 | switch (datatype) 55 | { 56 | case ::ONNX_NAMESPACE::TensorProto::DOUBLE: return nvinfer1::DataType::kFLOAT; 57 | case ::ONNX_NAMESPACE::TensorProto::FLOAT: return nvinfer1::DataType::kFLOAT; 58 | case ::ONNX_NAMESPACE::TensorProto::INT8: return nvinfer1::DataType::kINT8; 59 | case ::ONNX_NAMESPACE::TensorProto::UINT8: return nvinfer1::DataType::kUINT8; 60 | case ::ONNX_NAMESPACE::TensorProto::FLOAT16: return nvinfer1::DataType::kHALF; 61 | case ::ONNX_NAMESPACE::TensorProto::BFLOAT16: return nvinfer1::DataType::kBF16; 62 | case ::ONNX_NAMESPACE::TensorProto::BOOL: return nvinfer1::DataType::kBOOL; 63 | case ::ONNX_NAMESPACE::TensorProto::INT32: return nvinfer1::DataType::kINT32; 64 | case ::ONNX_NAMESPACE::TensorProto::INT64: return nvinfer1::DataType::kINT64; 65 | case ::ONNX_NAMESPACE::TensorProto::FLOAT8E4M3FN: return nvinfer1::DataType::kFP8; 66 | case ::ONNX_NAMESPACE::TensorProto::INT4: return nvinfer1::DataType::kINT4; 67 | case ::ONNX_NAMESPACE::TensorProto::FLOAT4E2M1: return nvinfer1::DataType::kFP4; 68 | } 69 | assert(false && "Unknown datatype"); 70 | return nvinfer1::DataType::kFLOAT; 71 | } 72 | 73 | ShapedWeights::DataType TensorOrWeights::convertTRTDataType(nvinfer1::DataType datatype) const 74 | { 75 | switch (datatype) 76 | { 77 | case nvinfer1::DataType::kFLOAT: return ::ONNX_NAMESPACE::TensorProto::FLOAT; 78 | case nvinfer1::DataType::kINT8: return ::ONNX_NAMESPACE::TensorProto::INT8; 79 | case nvinfer1::DataType::kUINT8: return ::ONNX_NAMESPACE::TensorProto::UINT8; 80 | case nvinfer1::DataType::kHALF: return ::ONNX_NAMESPACE::TensorProto::FLOAT16; 81 | case nvinfer1::DataType::kBF16: return ::ONNX_NAMESPACE::TensorProto::BFLOAT16; 82 | case nvinfer1::DataType::kBOOL: return ::ONNX_NAMESPACE::TensorProto::BOOL; 83 | case nvinfer1::DataType::kINT32: return ::ONNX_NAMESPACE::TensorProto::INT32; 84 | case nvinfer1::DataType::kINT64: return ::ONNX_NAMESPACE::TensorProto::INT64; 85 | case nvinfer1::DataType::kFP8: return ::ONNX_NAMESPACE::TensorProto::FLOAT8E4M3FN; 86 | case nvinfer1::DataType::kINT4: return ::ONNX_NAMESPACE::TensorProto::INT4; 87 | case nvinfer1::DataType::kFP4: return ::ONNX_NAMESPACE::TensorProto::FLOAT4E2M1; 88 | } 89 | assert(false && "Unknown datatype"); 90 | return ::ONNX_NAMESPACE::TensorProto::FLOAT; 91 | } 92 | 93 | } // namespace onnx2trt 94 | -------------------------------------------------------------------------------- /TensorOrWeights.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 
"ShapedWeights.hpp" 8 | #include 9 | #include 10 | #include 11 | 12 | namespace onnx2trt 13 | { 14 | 15 | //! Abstract representation of a tensor, which might be a nvinfer1::ITensor or ShapedWeights. 16 | class TensorOrWeights 17 | { 18 | union 19 | { 20 | nvinfer1::ITensor* _tensor; 21 | ShapedWeights _weights; 22 | }; 23 | enum 24 | { 25 | NODE_TENSOR, 26 | NODE_WEIGHTS 27 | } _variant; 28 | 29 | public: 30 | //! Represents "null tensor", which is used to denote "missing tensor". 31 | TensorOrWeights() 32 | : _tensor(nullptr) 33 | , _variant(NODE_TENSOR) 34 | { 35 | } 36 | TensorOrWeights(nvinfer1::ITensor* tensor) 37 | : _tensor(tensor) 38 | , _variant(NODE_TENSOR) 39 | { 40 | } 41 | TensorOrWeights(ShapedWeights const& weights) 42 | : _weights(weights) 43 | , _variant(NODE_WEIGHTS) 44 | { 45 | } 46 | bool is_tensor() const 47 | { 48 | return _variant == NODE_TENSOR; 49 | } 50 | bool is_weights() const 51 | { 52 | return _variant == NODE_WEIGHTS; 53 | } 54 | bool isNullTensor() const 55 | { 56 | return is_tensor() && _tensor == nullptr; 57 | } 58 | nvinfer1::ITensor& tensor() 59 | { 60 | if (is_weights() || isNullTensor()) 61 | { 62 | throw std::runtime_error("Trying to access weights or a null tensor!"); 63 | } 64 | return *_tensor; 65 | } 66 | nvinfer1::ITensor const& tensor() const 67 | { 68 | if (is_weights() || isNullTensor()) 69 | { 70 | throw std::runtime_error("Trying to access weights or a null tensor!"); 71 | } 72 | return *_tensor; 73 | } 74 | ShapedWeights& weights() 75 | { 76 | if (is_tensor()) 77 | { 78 | throw std::runtime_error("Trying to access a null weights!"); 79 | } 80 | return _weights; 81 | } 82 | ShapedWeights const& weights() const 83 | { 84 | if (is_tensor()) 85 | { 86 | throw std::runtime_error("Trying to access a null weights!"); 87 | } 88 | return _weights; 89 | } 90 | nvinfer1::Dims shape() const 91 | { 92 | return is_tensor() ? tensor().getDimensions() : weights().shape; 93 | } 94 | explicit operator bool() const 95 | { 96 | return is_tensor() ? _tensor != nullptr : static_cast(_weights); 97 | } 98 | bool isFp32() const 99 | { 100 | return is_tensor() ? tensor().getType() == nvinfer1::DataType::kFLOAT 101 | : weights().type == ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT; 102 | } 103 | bool isFp16() const 104 | { 105 | return is_tensor() ? tensor().getType() == nvinfer1::DataType::kHALF 106 | : weights().type == ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; 107 | } 108 | bool isBFp16() const 109 | { 110 | return is_tensor() ? tensor().getType() == nvinfer1::DataType::kBF16 111 | : weights().type == ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16; 112 | } 113 | bool isInt32() const 114 | { 115 | return is_tensor() ? tensor().getType() == nvinfer1::DataType::kINT32 116 | : weights().type == ::ONNX_NAMESPACE::TensorProto_DataType_INT32; 117 | } 118 | bool isInt64() const 119 | { 120 | return is_tensor() ? tensor().getType() == nvinfer1::DataType::kINT64 121 | : weights().type == ::ONNX_NAMESPACE::TensorProto_DataType_INT64; 122 | } 123 | bool isInt8() const 124 | { 125 | return is_tensor() ? tensor().getType() == nvinfer1::DataType::kINT8 126 | : weights().type == ::ONNX_NAMESPACE::TensorProto_DataType_INT8; 127 | } 128 | bool isUint8() const 129 | { 130 | return is_tensor() ? tensor().getType() == nvinfer1::DataType::kUINT8 131 | : weights().type == ::ONNX_NAMESPACE::TensorProto_DataType_UINT8; 132 | } 133 | bool isBool() const 134 | { 135 | return is_tensor() ? 
tensor().getType() == nvinfer1::DataType::kBOOL 136 | : weights().type == ::ONNX_NAMESPACE::TensorProto_DataType_BOOL; 137 | } 138 | bool isFp8() const 139 | { 140 | return is_tensor() ? tensor().getType() == nvinfer1::DataType::kFP8 : weights().type == ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN; 141 | } 142 | std::string getName() const 143 | { 144 | return is_tensor() ? tensor().getName() : weights().getName(); 145 | } 146 | std::string getType() const; 147 | 148 | nvinfer1::DataType convertONNXDataType(ShapedWeights::DataType datatype) const; 149 | 150 | ShapedWeights::DataType convertTRTDataType(nvinfer1::DataType datatype) const; 151 | 152 | nvinfer1::DataType getDataType() const 153 | { 154 | if (is_tensor()) 155 | { 156 | return tensor().getType(); 157 | } 158 | else 159 | { 160 | return convertONNXDataType(weights().type); 161 | } 162 | } 163 | 164 | ShapedWeights::DataType getONNXDataType() const 165 | { 166 | if (is_tensor()) 167 | { 168 | return convertTRTDataType(tensor().getType()); 169 | } 170 | else 171 | { 172 | return weights().type; 173 | } 174 | } 175 | }; 176 | 177 | } // namespace onnx2trt 178 | -------------------------------------------------------------------------------- /WeightsContext.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "ShapedWeights.hpp" 8 | #include "Status.hpp" 9 | #include "errorHelpers.hpp" 10 | #include "weightUtils.hpp" 11 | #include 12 | #include 13 | 14 | namespace onnx2trt 15 | { 16 | 17 | // Class reponsible for reading, casting, and converting weight values from an ONNX model and into ShapedWeights 18 | // objects. All temporary weights are stored in a buffer owned by the class so they do not go out of scope. 19 | 20 | class WeightsContext 21 | { 22 | struct BufferDeleter 23 | { 24 | void operator()(void* ptr) 25 | { 26 | operator delete(ptr); 27 | } 28 | }; 29 | 30 | using BufferPtr = std::unique_ptr; 31 | 32 | nvinfer1::ILogger* mLogger; 33 | 34 | // Vector of hunks to maintain ownership of weights. 35 | std::vector mWeightBuffers; 36 | 37 | // Keeps track of the absolute location of the file in order to read external weights. 38 | std::string mOnnxFileLocation; 39 | 40 | public: 41 | WeightsContext(nvinfer1::ILogger* logger) 42 | : mLogger(logger){}; 43 | 44 | int32_t* convertUINT8(uint8_t const* weightValues, nvinfer1::Dims const& shape); 45 | 46 | float* convertDouble(double const* weightValues, nvinfer1::Dims const& shape); 47 | 48 | template 49 | DataType* convertInt32Data(int32_t const* weightValues, nvinfer1::Dims const& shape, int32_t onnxdtype); 50 | 51 | uint8_t* convertPackedInt32Data( 52 | int32_t const* weightValues, nvinfer1::Dims const& shape, size_t nbytes, int32_t onnxdtype); 53 | 54 | // Function to create an internal buffer to own the weights without any type conversions. 55 | void* ownWeights(void const* weightValues, ShapedWeights::DataType const dataType, nvinfer1::Dims const& shape, 56 | size_t const nBytes); 57 | 58 | // Function to read bytes from an external file and return the data in a buffer. 59 | bool parseExternalWeights( 60 | 61 | std::string const& file, int64_t offset, int64_t length, std::vector& weightsBuf, size_t& size); 62 | 63 | // Function to read data from an ONNX Tensor and move it into a ShapedWeights object. 64 | // Handles external weights as well. 
65 | bool convertOnnxWeights( 66 | ::ONNX_NAMESPACE::TensorProto const& onnxTensor, ShapedWeights* weights, bool ownAllWeights = false); 67 | 68 | // Helper function to convert weightValues' type from fp16/bf16 to fp32. 69 | template 70 | [[nodiscard]] float* convertToFp32(ShapedWeights const& w); 71 | 72 | // Helper function to get fp32 representation of fp16, bf16, or fp32 weights. 73 | float* getFP32Values(ShapedWeights const& w); 74 | 75 | // Register an unique name for the created weights. 76 | ShapedWeights createNamedTempWeights(ShapedWeights::DataType type, nvinfer1::Dims const& shape, 77 | std::set& namesSet, int64_t& suffixCounter, bool batchNormNode = false); 78 | 79 | // Create weights with a given name. 80 | ShapedWeights createNamedWeights(ShapedWeights::DataType type, nvinfer1::Dims const& shape, std::string const& name, 81 | std::set* bufferedNames = nullptr); 82 | 83 | // Creates a ShapedWeights object class of a given type and shape. 84 | ShapedWeights createTempWeights(ShapedWeights::DataType type, nvinfer1::Dims const& shape); 85 | 86 | // Sets the absolute filepath of the loaded ONNX model in order to read external weights. 87 | void setOnnxFileLocation(std::string location) 88 | { 89 | mOnnxFileLocation = location; 90 | } 91 | 92 | // Returns the absolutate filepath of the loaded ONNX model. 93 | std::string getOnnxFileLocation() 94 | { 95 | return mOnnxFileLocation; 96 | } 97 | 98 | // Returns the logger object. 99 | nvinfer1::ILogger& logger() 100 | { 101 | return *mLogger; 102 | } 103 | }; 104 | 105 | template 106 | DataType* WeightsContext::convertInt32Data(int32_t const* weightValues, nvinfer1::Dims const& shape, int32_t onnxdtype) 107 | { 108 | size_t const nbWeights = volume(shape); 109 | DataType* newWeights{static_cast(createTempWeights(onnxdtype, shape).values)}; 110 | 111 | for (size_t i = 0; i < nbWeights; i++) 112 | { 113 | newWeights[i] = static_cast(weightValues[i]); 114 | } 115 | return newWeights; 116 | } 117 | template 118 | [[nodiscard]] float* WeightsContext::convertToFp32(ShapedWeights const& w) 119 | { 120 | int64_t const nbWeights = volume(w.shape); 121 | auto result = static_cast(createTempWeights(::ONNX_NAMESPACE::TensorProto::FLOAT, w.shape).values); 122 | std::copy_n(static_cast(w.values), nbWeights, result); 123 | 124 | return result; 125 | } 126 | 127 | } // namespace onnx2trt 128 | -------------------------------------------------------------------------------- /bfloat16.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "bfloat16.hpp" 6 | #include 7 | 8 | namespace onnx2trt 9 | { 10 | 11 | BFloat16::operator float() const 12 | { 13 | static_assert(sizeof(uint32_t) == sizeof(float), ""); 14 | float val{0.F}; 15 | auto bits = static_cast(mRep) << 16; 16 | std::memcpy(&val, &bits, sizeof(uint32_t)); 17 | return val; 18 | } 19 | 20 | BFloat16::BFloat16(float x) 21 | { 22 | static_assert(sizeof(uint32_t) == sizeof(float), ""); 23 | uint32_t bits{0}; 24 | std::memcpy(&bits, &x, sizeof(float)); 25 | 26 | // FP32 format: 1 sign bit, 8 bit exponent, 23 bit mantissa 27 | // BF16 format: 1 sign bit, 8 bit exponent, 7 bit mantissa 28 | 29 | // Mask for exponent 30 | constexpr uint32_t exponent = 0xFFU << 23; 31 | 32 | // Check if exponent is all 1s (NaN or infinite) 33 | if ((bits & exponent) != exponent) 34 | { 35 | // x is finite - round to even 36 | bits += 0x7FFFU + (bits >> 16 & 1); 37 | } 38 | 39 | mRep = static_cast(bits >> 16); 40 | } 
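// Worked example of the round-to-nearest-even conversion above (values checked by hand):
//   BFloat16(1.0f):
//     bits = 0x3F800000, exponent field is not all 1s, so the finite path applies
//     bits += 0x7FFF + ((bits >> 16) & 1)  ->  0x3F807FFF
//     mRep = bits >> 16                    ->  0x3F80   (BF16 encoding of 1.0)
//   BFloat16(1.00390625f), exactly halfway between BF16 0x3F80 (1.0) and 0x3F81 (1.0078125):
//     bits = 0x3F808000
//     bits += 0x7FFF + 0                   ->  0x3F80FFFF
//     mRep = bits >> 16                    ->  0x3F80   (the tie rounds to the even encoding)
// NaN and infinity skip the rounding step, so the exponent bits are preserved; a NaN payload
// survives the truncation only if it has a set bit among the top 7 mantissa bits.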
41 | 42 | } // namespace onnx2trt 43 | -------------------------------------------------------------------------------- /bfloat16.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | namespace onnx2trt 10 | { 11 | 12 | //! Implements "Brain Floating Point": like an IEEE FP32, 13 | //! but the significand is only 7 bits instead of 23 bits. 14 | class BFloat16 15 | { 16 | public: 17 | BFloat16() 18 | : mRep(0) 19 | { 20 | } 21 | 22 | // Rounds to even if there is a tie. 23 | BFloat16(float x); 24 | 25 | operator float() const; 26 | 27 | private: 28 | //! Value stored in BFloat16 representation. 29 | uint16_t mRep; 30 | }; 31 | 32 | } // namespace onnx2trt 33 | -------------------------------------------------------------------------------- /common.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include // For ::open 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | // Namespace for common functions used throughout onnx-trt 18 | namespace common 19 | { 20 | struct InferDeleter { 21 | template 22 | void operator()(T* obj) const { 23 | if( obj ) { 24 | obj->destroy(); 25 | } 26 | } 27 | }; 28 | 29 | template 30 | inline std::shared_ptr infer_object(T* obj) { 31 | if( !obj ) { 32 | throw std::runtime_error("Failed to create object"); 33 | } 34 | return std::shared_ptr(obj, InferDeleter()); 35 | } 36 | 37 | // Logger for TensorRT info/warning/errors 38 | class TRT_Logger : public nvinfer1::ILogger { 39 | nvinfer1::ILogger::Severity _verbosity; 40 | std::ostream* _ostream; 41 | public: 42 | TRT_Logger(Severity verbosity=Severity::kWARNING, 43 | std::ostream& ostream=std::cout) 44 | : _verbosity(verbosity), _ostream(&ostream) {} 45 | void log(Severity severity, const char* msg) noexcept override { 46 | if( severity <= _verbosity ) { 47 | time_t rawtime = std::time(0); 48 | char buf[256]; 49 | strftime(&buf[0], 256, 50 | "%Y-%m-%d %H:%M:%S", 51 | std::gmtime(&rawtime)); 52 | const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? " BUG" : 53 | severity == Severity::kERROR ? " ERROR" : 54 | severity == Severity::kWARNING ? "WARNING" : 55 | severity == Severity::kINFO ? " INFO" : 56 | "UNKNOWN"); 57 | (*_ostream) << "[" << buf << " " << sevstr << "] " 58 | << msg 59 | << std::endl; 60 | } 61 | } 62 | }; 63 | 64 | inline bool ParseFromFile_WAR(google::protobuf::Message* msg, 65 | const char* filename) { 66 | int fd = ::open(filename, O_RDONLY); 67 | google::protobuf::io::FileInputStream raw_input(fd); 68 | raw_input.SetCloseOnDelete(true); 69 | google::protobuf::io::CodedInputStream coded_input(&raw_input); 70 | #if GOOGLE_PROTOBUF_VERSION >= 3011000 71 | // Starting Protobuf 3.11 accepts only single parameter. 
72 | coded_input.SetTotalBytesLimit(std::numeric_limits::max()); 73 | #else 74 | // Note: This WARs the very low default size limit (64MB) 75 | coded_input.SetTotalBytesLimit(std::numeric_limits::max(), 76 | std::numeric_limits::max()/4); 77 | #endif 78 | return msg->ParseFromCodedStream(&coded_input); 79 | } 80 | 81 | inline bool MessageToFile(const google::protobuf::Message* msg, 82 | const char* filename) { 83 | int fd = ::open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); 84 | google::protobuf::io::FileOutputStream raw_output(fd); 85 | raw_output.SetCloseOnDelete(true); 86 | google::protobuf::io::CodedOutputStream output(&raw_output); 87 | 88 | // Write the size. 89 | const int size = msg->ByteSize(); 90 | 91 | uint8_t* buffer = output.GetDirectBufferForNBytesAndAdvance(size); 92 | if (buffer != NULL) { 93 | // Optimization: The msg fits in one buffer, so use the faster 94 | // direct-to-array serialization path. 95 | msg->SerializeWithCachedSizesToArray(buffer); 96 | } else { 97 | // Slightly-slower path when the msg is multiple buffers. 98 | msg->SerializeWithCachedSizes(&output); 99 | if (output.HadError()) return false; 100 | } 101 | 102 | return true; 103 | } 104 | 105 | inline bool ParseFromTextFile(google::protobuf::Message* msg, 106 | const char* filename) { 107 | int fd = ::open(filename, O_RDONLY); 108 | google::protobuf::io::FileInputStream raw_input(fd); 109 | raw_input.SetCloseOnDelete(true); 110 | return google::protobuf::TextFormat::Parse(&raw_input, msg); 111 | } 112 | 113 | inline std::string onnx_ir_version_string(int64_t ir_version=::ONNX_NAMESPACE::IR_VERSION) { 114 | int onnx_ir_major = ir_version / 1000000; 115 | int onnx_ir_minor = ir_version % 1000000 / 10000; 116 | int onnx_ir_patch = ir_version % 10000; 117 | return (std::to_string(onnx_ir_major) + "." + 118 | std::to_string(onnx_ir_minor) + "." + 119 | std::to_string(onnx_ir_patch)); 120 | } 121 | 122 | inline void print_version() { 123 | std::cout << "Parser built against:" << std::endl; 124 | std::cout << " ONNX IR version: " << onnx_ir_version_string(::ONNX_NAMESPACE::IR_VERSION) << std::endl; 125 | std::cout << " TensorRT version: " 126 | << NV_TENSORRT_MAJOR << "." 127 | << NV_TENSORRT_MINOR << "." 
128 | << NV_TENSORRT_PATCH << std::endl; 129 | } 130 | } // namespace common 131 | -------------------------------------------------------------------------------- /docs/Changelog.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # ONNX-TensorRT Changelog 4 | 5 | # TensorRT 10.11 GA Release - 2025-5-16 6 | For more details, see the 10.11 GA release notes 7 | 8 | - Added `kENABLE_UINT8_AND_ASYMMETRIC_QUANTIZATION_DLA` parser flag to enable UINT8 asymmetric quantization on engines targeting DLA 9 | - Removed restriction that inputs to `RandomNormalLike` and `RandomUniformLike` must be tensors 10 | - Clarified limitations of scan outputs for `Loop` nodes 11 | - Updated ONNX version to `1.18` 12 | 13 | # TensorRT 10.10 GA Release - 2025-5-8 14 | For more details, see the 10.10 GA release notes 15 | 16 | - Cleaned up log spam when the ONNX network contained a mixture Plugins and LocalFunctions 17 | - UINT8 constants are now properly imported for QuantizeLinear & DequantizeLinear nodes 18 | - Plugin fallback importer now also reads its namespace from a Node's domain field 19 | 20 | # TensorRT 10.9 GA Release - 2025-3-7 21 | For more details, see the 10.9 GA release notes 22 | 23 | - Added support for Python AOT plugins 24 | - Added support for opset 21 GroupNorm 25 | - Fixed support for opset 18+ ScatterND 26 | 27 | # TensorRT 10.8 GA Release - 2025-1-30 28 | For more details, see the 10.8 GA release notes 29 | 30 | - Added support for `FLOAT4E2M1` types for quantized networks 31 | - Added support for dynamic axes and improved performance of `CumSum` operations 32 | - Fixed the import of local functions when their input tensor names aliased one from an outside scope 33 | - Added support for `Pow` ops with integer-typed exponent values 34 | 35 | # TensorRT 10.7 GA Release - 2024-11-26 36 | For more details, see the 10.7 GA release notes 37 | 38 | - Now prioritizes using plugins over local functions when a corresponding plugin is available in the registry 39 | - Added dynamic axes support for `Squeeze` and `Unsqueeze` operations 40 | - Added support for parsing mixed-precision `BatchNormalization` nodes in strongly-typed mode 41 | 42 | # TensorRT 10.6 GA Release - 2024-11-1 43 | For more details, see the 10.6 GA release notes 44 | 45 | - Updated ONNX submodule version to 1.17.0 46 | - Fix issue where conditional layers were incorrectly being added 47 | - Updated local function metadata to contain more information 48 | - Added support for parsing nodes with Quickly Deployable Plugins 49 | - Fixed handling of optional outputs 50 | 51 | # TensorRT 10.5 GA Release - 2024-10-1 52 | For more details, see the 10.5 GA release notes. 53 | 54 | - Added support for real-valued `STFT` operations 55 | - Improved error handling in `IParser` 56 | 57 | # TensorRT 10.4 GA Release - 2024-9-5 58 | For more details, see the 10.4 GA release notes. 59 | 60 | - Added support for tensor `axes` for `Pad` operations 61 | - Added support for `BlackmanWindow`, `HammingWindow`, and `HannWindow` operations 62 | - Improved error handling in `IParserRefitter` 63 | - Fixed kernel shape inference in multi-input convolutions 64 | 65 | # TensorRT 10.3 GA Release - 2024-8-7 66 | For more details, see the 10.3 GA release notes. 67 | 68 | - Added support for tensor `axes` inputs for `Slice` nodes 69 | - Updated `ScatterElements` importer to use an updated plugin 70 | 71 | # TensorRT 10.2 GA Release - 2024-7-10 72 | For more details, see the 10.2 GA release notes. 
73 | 74 | - Improved error handling with new macros and classes 75 | - Minor changes to op importers for `GRU` and `Squeeze` 76 | 77 | # TensorRT 10.1 GA Release - 2024-6-10 78 | For more details, see the 10.1 GA release notes. 79 | 80 | - Added `supportsModelV2` API 81 | - Added support for `DeformConv` operation 82 | - Added support for `PluginV3` TensorRT Plugins 83 | - Marked all IParser and IParserRefitter APIs as `noexcept` 84 | - Shape inputs can be passed to custom ops supported by `IPluginV3`-based plugins by indicating the input indices to be interpreted as shape inputs by a node attribute named `tensorrt_plugin_shape_input_indices`. 85 | 86 | # TensorRT 10.0 GA Release - 2024-4-25 87 | For more details, see the 10.0 GA release notes. 88 | 89 | - Added support for building with with `protobuf-lite` 90 | - Fixed issue when parsing and refitting models with nested `BatchNormalization` nodes 91 | - Added support for empty inputs in custom plugin nodes 92 | 93 | # TensorRT 10.0 EA Release - 2024-4-1 94 | For more details, see the 10.0 EA release notes. 95 | 96 | - Added new class `IParserRefitter` that can be used to refit a TensorRT engine with the weights of an ONNX model 97 | - `kNATIVE_INSTANCENORM` is now set to ON by default 98 | - Added support for `IPluginV3` interfaces from TensorRT 99 | - Added support for `INT4` quantization 100 | - Added support for the `reduction` attribute in `ScatterElements` 101 | - Added support for `wrap` padding mode in `Pad` 102 | 103 | # TensorRT 9.3 GA Release - 2024-2-8 104 | For more details, see the 9.3 GA release notes for the fixes since 9.2 GA. 105 | 106 | - Added native support for `INT32` and `INT64` types for `ArgMin` and `ArgMax` nodes 107 | - Fixed check for valid `zero_point` values in `QuantizeLinear` and `DequantizeLinear` nodes 108 | 109 | # TensorRT 9.2 GA Release - 2023-11-8 110 | For more details, see the 9.2 GA release notes for the fixes since 9.1 GA. 111 | 112 | - Added support for `Hardmax` 113 | - Fixed type inference for few operators to use native ONNX types 114 | 115 | # TensorRT 9.1 GA Release - 2023-10-18 116 | For more details, see the 9.1 GA release notes for the fixes since 9.0 GA. 117 | 118 | - Added new `ErrorCode` enums to improve error logging 119 | - Added new members to `IParserError` to improve error logging 120 | - Added static checkers when parsing nodes, resulting better reporting of errors 121 | 122 | # TensorRT 9.0 GA Release - 2023-9-5 123 | For more details, see the 9.0 GA release notes for the fixes since 9.0 EA. 124 | 125 | - Added support for FP8 and BF16 datatypes. 126 | - Fixed a bug that previously caused `If` nodes to fail import due to branch output size mismatch 127 | - Improved support for importing ONNX Local Functions 128 | 129 | # TensorRT 9.0 EA Release - 2023-8-4 130 | For more details, see the 9.0 EA release notes for the fixes since 8.6 GA. 131 | 132 | - Added support for INT64 data type. The ONNX parser no longer automatically casts INT64 to INT32. 133 | - Added support for ONNX local functions when parsing ONNX models with the ONNX parser. 134 | - Breaking API Change: In TensorRT 9.0, due to the introduction of INT64 as a supported data type, ONNX models with INT64 I/O require INT64 bindings. Note that prior to this release, such models required INT32 bindings. 135 | - Updated ONNX submodule to v1.14.0. 136 | 137 | # TensorRT 8.6 GA Release - 2023-5-1 138 | For more details, see the 8.6 GA release notes for the fixes since 8.6 EA. 
139 | 140 | - Renamed `kVERSION_COMPATIBLE` flag to `kNATIVE_INSTANCENORM` 141 | - Added support for N-D `Trilu` 142 | - Removed old LSTM importer 143 | - Updated ONNX submodule to v1.13.1. 144 | 145 | # TensorRT 8.6 EA Release - 2023-3-13 146 | 147 | ## Added 148 | 149 | For more details, see the 8.6 EA release notes for new features added in TensorRT 8.6. 150 | 151 | - Added support for `GroupNormalization`, `LayerNormalization`, `IsInf` operations 152 | - Added support for INT32 input types for `Argmin`, `Argmax`, and `TopK` 153 | - Added support for `ReverseSequence` operators with dynamic shapes 154 | - Added support for `TopK` operators with dynamic `K` values 155 | - Added `OnnxParserFlag` enum and `setFlag` interfaces to the ONNX parser to modify the default parsing behavior 156 | - Added metadata tracking, now ONNX node metadata will be embedded into TensorRT layers 157 | 158 | ## Changed 159 | 160 | - All cast operations will now use the new `CastLayer` over the pervious `IdentityLayer`. 161 | 162 | # TensorRT 8.5 GA Release - 2022-11-2 163 | 164 | ## Added 165 | 166 | For more details, see the 8.5 GA release notes for new features added in TensorRT 8.5 167 | 168 | - Added the `RandomNormal`, `RandomUniform`, `MeanVarianceNormalization`, `RoiAlign`, `Mod`, `Trilu`, `GridSample` and `NonZero` operations 169 | - Added native support for the `NonMaxSuppression` operator 170 | - Added support for importing ONNX networks with `UINT8` I/O types 171 | 172 | ## Fixed 173 | - Fixed an issue with output padding with 1D deconv 174 | - Fixed an issue when flattening 1D tensors 175 | - Fixed an issue when parsing String attributes from TRT plugins 176 | - Fixed an issue when importing `If` subgraphs with shared initializer names 177 | - Fixied an issue when importing `Loop` subgraphs with `INT_MAX` trip counts 178 | 179 | ## Removed 180 | - Removed `onnx2trt` binary. See the README.md for alternative binaries to run ONNX model with TensorRT. 181 | 182 | ## TensorRT 22.08 Release 2022-8-16 183 | ### Updated 184 | - Updated TensorRT version to 8.4.2 185 | - Updated ONNX submodule version to 1.12 186 | - Updated operators support documentation 187 | 188 | ### Fixes 189 | - Fixed handling of no-op `Flatten` operations 190 | - Fixed `allowZero` logic in Reshape operation 191 | 192 | ### Deprecated 193 | - Deprecated `onnx2trt` binary. This will be removed in the next release of TensorRT. 194 | 195 | ## TensorRT 8.4 GA Release - 2022-6-6 196 | 197 | ### Added 198 | 199 | For more details, see the 8.4 GA release notes for new features added in TensorRT 8.4 200 | 201 | - Added native FP16 support for importing and manipulating FP16 initializers 202 | - Added support for `Shrink` 203 | - Added support for `Xor` 204 | - Added dynamic shape support for `ArgMax` and `ArgMin` 205 | - Added dynamic shape support for `Range` for floating point types 206 | 207 | ### Fixes 208 | - Fixed an issue in tensor name scoping in ONNX models with nested subgraphs 209 | - Fixed misc issues when dealing with empty tensors 210 | - Fixed the operations in the `Celu` importer function 211 | - Removed unnecessary reshapes in the `GEMM` importer function 212 | 213 | ## TensorRT 8.2 GA Release - 2021-11-23 214 | 215 | ### Added 216 | 217 | See the 8.2 EA release notes for new features added in TensorRT 8.2. 
218 | 219 | ### Fixes 220 | - Removed duplicate constant layer checks that caused some performance regressions 221 | - Fixed expand dynamic shape calculations 222 | - Added parser-side checks for Scatter layer support 223 | 224 | ## TensorRT 8.2 EA Release - 2021-10-04 225 | ### Added 226 | - Added support for the following ONNX operators: 227 | - Einsum 228 | - IsNan 229 | - GatherND 230 | - Scatter 231 | - ScatterElements 232 | - ScatterND 233 | - Sign 234 | - Round 235 | 236 | ### Updated 237 | - Updated `Gather` and `GatherElements` implementations to natively support negative indices 238 | - Updated `Pad` layer to support ND padding, along with `edge` and `reflect` padding mode support 239 | - Updated `If` layer with general performance improvements. 240 | 241 | ## TensorRT 8.0 Release - 2021-07-02 242 | ### Added 243 | - Rehauled resize operator, now fully supporting the following modes: 244 | - Coordinate Transformation modes: `half_pixel`, `pytorch_half_pixel`, `tf_half_pixel_for_nn`, `asymmetric`, and `align_corners` 245 | - Modes: `nearest`, `linear` 246 | - Nearest Modes: `floor`, `ceil`, `round_prefer_floor`, `round_prefer_ceil` 247 | - QuantizeLinear/DequantizeLinear updates: 248 | - Added support for tensor scales 249 | - Added support for per-axis quantization 250 | - Added support for multi-input ConvTranpose 251 | - Added support for generic 2D padding 252 | - Added experimental support for `NonMaxSuppression` 253 | 254 | ### Updated 255 | - Moved `RefitMap` API to core TensorRT. 256 | - Added Datatype column to [operators.md](https://github.com/onnx/onnx-tensorrt/blob/master/docs/operators.md) 257 | 258 | ## 21.05 Container Release - 2021-05-17 259 | ### Added 260 | - Added library only build target [#659](https://github.com/onnx/onnx-tensorrt/pull/659) 261 | - Added support for negative gather indices [#681](https://github.com/onnx/onnx-tensorrt/pull/681) 262 | - Added support for `DOUBLE`-typed inputs and weights through downcast to float [#674](https://github.com/onnx/onnx-tensorrt/pull/674) 263 | - Added support for optional plugin fields in FallbackPlugin path [#676](https://github.com/onnx/onnx-tensorrt/pull/676) 264 | 265 | ### Updated 266 | - Updated license [#657](https://github.com/onnx/onnx-tensorrt/pull/657) 267 | 268 | ### Fixes 269 | - Fixed index offset calculation in GatherElements [#675](https://github.com/onnx/onnx-tensorrt/pull/675) 270 | - Clarified dynamic shape support for ReverseSequence 271 | 272 | ## 21.03 Container Release - 2021-03-09 273 | ### Added 274 | - Added opset13 support for `SoftMax`, `LogSoftmax`, `Squeeze`, and `Unsqueeze` 275 | - Added support for the `EyeLike` operator 276 | - Added support for the `GatherElements` operator 277 | 278 | ### Fixes 279 | ### Removed 280 | 281 | ## 21.02 Container Release - 2021-01-18 282 | ### Added 283 | - Added support for the `ReverseSequence` operator [#590](https://github.com/onnx/onnx-tensorrt/pull/590) 284 | - Updated `parse()` and `supportsModel()` API calls with an optional `model_path` parameter to support models with external weights [#621](https://github.com/onnx/onnx-tensorrt/pull/621) 285 | - Added support for the `Celu` operator 286 | - Added support for the `CumSum` operator 287 | - Added support for the `LessOrEqual` operator 288 | - Added support for the `LpNormalization` operator 289 | - Added support for the `LpPool` operator 290 | - Added support for the `GreaterOrEqual` operator 291 | - Added support for dynamic inputs in `onnx_tensorrt` python backend 292 | - Added FAQ section 
for commonly asked questions 293 | 294 | ### Fixes 295 | - Fixed relative path imports for models with external weights [#619]https://github.com/onnx/onnx-tensorrt/pull/619 296 | - Fixed importing loops operators with no loop-carried depedencies [#619](https://github.com/onnx/onnx-tensorrt/pull/619) 297 | - Worked around unsupported BOOL concats through casting [#620](https://github.com/onnx/onnx-tensorrt/pull/620) 298 | - Fixed compilation error with GCC9 [#568](https://github.com/onnx/onnx-tensorrt/pull/568) 299 | 300 | ### Removed 301 | - Removed `onnx_tensorrt/config.py` as it is no longer needed 302 | 303 | ## 20.12 Container Release - 2020-12-17 304 | 305 | ### Added 306 | - Added `setup.py` to properly install `onnx_tensorrt` python backend 307 | - Added 4D transpose for ONNX weights [#557](https://github.com/onnx/onnx-tensorrt/pull/557) 308 | 309 | ### Fixes 310 | - Fixed slice computations for large slices [#558](https://github.com/onnx/onnx-tensorrt/pull/558) 311 | 312 | ## TensorRT 7.2.1 Release - 2020-10-20 313 | 314 | ### Added 315 | - Added support for parsing large models with external data 316 | - Added API for interfacing with TensorRT's refit feature 317 | - Updated `onnx_tensorrt` backend to support dynamic shapes 318 | - Added support for 3D instance normalizations [#515](https://github.com/onnx/onnx-tensorrt/pull/515) 319 | - Improved clarity on the resize modes TRT supports [#512](https://github.com/onnx/onnx-tensorrt/pull/521) 320 | - Added Changelog 321 | 322 | ### Changed 323 | - Unified docker usage between ONNX-TensorRT and TensorRT. 324 | 325 | ## Removed 326 | - Removed deprecated docker files. 327 | - Removed deprecated `setup.py`. 328 | -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributing 3 | 4 | Contributions are always welcome to improve the onnx-tensorrt parser. For those looking to contribute, please follow the PR process as outlined in the [TensorRT Open Source Software repository](https://github.com/NVIDIA/TensorRT/blob/master/CONTRIBUTING.md). 5 | 6 | #### Signing Your Work 7 | 8 | * We require that all contributors "sign-off" on their commits. This certifies that the contribution is your original work, or you have rights to submit it under the same license, or a compatible license. 9 | 10 | * Any contribution which contains commits that are not Signed-Off will not be accepted. 11 | 12 | * To sign off on a commit you simply use the `--signoff` (or `-s`) option when committing your changes: 13 | ```bash 14 | $ git commit -s -m "Add cool feature." 15 | ``` 16 | This will append the following to your commit message: 17 | ``` 18 | Signed-off-by: Your Name 19 | ``` 20 | 21 | * Full text of the DCO: 22 | 23 | ``` 24 | Developer Certificate of Origin 25 | Version 1.1 26 | 27 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors. 28 | 1 Letterman Drive 29 | Suite D4700 30 | San Francisco, CA, 94129 31 | 32 | Everyone is permitted to copy and distribute verbatim copies of this 33 | license document, but changing it is not allowed. 
34 | 35 | 36 | Developer's Certificate of Origin 1.1 37 | 38 | By making a contribution to this project, I certify that: 39 | 40 | (a) The contribution was created in whole or in part by me and I 41 | have the right to submit it under the open source license 42 | indicated in the file; or 43 | 44 | (b) The contribution is based upon previous work that, to the best 45 | of my knowledge, is covered under an appropriate open source 46 | license and I have the right under that license to submit that 47 | work with modifications, whether created in whole or in part 48 | by me, under the same open source license (unless I am 49 | permitted to submit under a different license), as indicated 50 | in the file; or 51 | 52 | (c) The contribution was provided directly to me by some other 53 | person who certified (a), (b) or (c) and I have not modified 54 | it. 55 | 56 | (d) I understand and agree that this project and the contribution 57 | are public and that a record of the contribution (including all 58 | personal information I submit with it, including my sign-off) is 59 | maintained indefinitely and may be redistributed consistent with 60 | this project or the open source license(s) involved. 61 | ``` 62 | -------------------------------------------------------------------------------- /docs/faq.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # ONNX-TensorRT FAQ 4 | 5 | For all uses we recommend installing the following tools: 6 | * [ONNX-Graphsurgeon](https://github.com/NVIDIA/TensorRT/tree/main/tools/onnx-graphsurgeon) 7 | * [Polygraphy](https://github.com/NVIDIA/TensorRT/tree/main/tools/Polygraphy) 8 | 9 | ## How do I import and run an ONNX model through TensorRT? 10 | 11 | There are currently two officially supported tools for users to quickly check if an ONNX model can parse and build into a TensorRT engine from an ONNX file. 12 | 13 | For C++ users, there is the [trtexec](https://github.com/NVIDIA/TensorRT/tree/main/samples/trtexec) binary that is typically found in the `/bin` directory. The basic command of running an ONNX model is: 14 | 15 | `trtexec --onnx=model.onnx` 16 | 17 | Refer to the link or run `trtexec -h` for more information on CLI options. 18 | 19 | For Python users, there is the [polygraphy](https://github.com/NVIDIA/TensorRT/tree/main/tools/Polygraphy) tool. The basic command for running an onnx model is: 20 | 21 | `polygraphy run model.onnx --trt` 22 | 23 | Refer to the link or run `polygraphy run -h` for more information on CLI options. 24 | 25 | ## Common Assertion Errors 26 | 27 | ### `inputs.at(0) must be an initializer!` or `inputs.at(0).is_weights()` 28 | 29 | This is a common error seen when importing some ONNX models into TensorRT. Currently for some TensorRT layers (such as TopK and Padding), some attributes are required to be graph-level constants. We’ve seen some examples in the past where some convertors will insert a subgraph to compute these constants rather than use an initializer for these nodes. In the majority of these cases, constant-folding these subgraphs will result in an ONNX model that can be imported by TensorRT. Polygraphy's surgeon tool provides a constant folding function. It can be run through: 30 | 31 | `polygraphy surgeon sanitize model.onnx --fold-constants --output model_folded.onnx` 32 | 33 | ### `Network must have at least one output!` 34 | 35 | This is a generic error which is seen when there was an issue parsing the ONNX model. 
To better root cause where the error has occurred, re-run the parsing step with verbose logging to better understand where the parser failed. 36 | 37 | ### `getPluginCreator() could not find Plugin version 1` 38 | 39 | This is an error stating that onnx-tensorrt does not have an import function defined for a particular operator. The TensorRT team is continuously working on improving the operator coverage in onnx-tensorrt. Feel free to open an issue on any unsupported operators that you come across in your models. For more information on how to write implementation for unsupported operators yourself, see the `custom layer support` section below. 40 | 41 | ## Custom Layer Support 42 | 43 | Custom layer support in onnx-tensorrt is done through TensorRT plugins. Any custom plugins must be registered with TensorRT’s plugin registry in order for it to be visible to the onnx-tensorrt parser. 44 | 45 | For writing a plugin for a custom ONNX operator, the quickest way to do so without modifying the parser code is by utilizing the `fallbackPluginImporter` function. As long as the inputs, outputs, and attributes of your custom operator are consistent with those of your plugin, the ONNX-TensorRT parser will do the mapping for you. You can refer to [this blog post](https://developer.nvidia.com/blog/estimating-depth-beyond-2d-using-custom-layers-on-tensorrt-and-onnx-models/) on how to write a plugin for a custom ONNX operator. 46 | 47 | For writing a plugin for existing ONNX operators that requires modification of the parser code, you can refer to the InstanceNormalization import function and the [corresponding plugin implementation](https://github.com/NVIDIA/TensorRT/tree/main/plugin/instanceNormalizationPlugin) in the main TensorRT repository. 48 | 49 | ## Quantized Operator Support 50 | 51 | As of the latest release version of TensorRT, the only two ONNX quantizing operators we support are INT8 and FP8 `QuantizeLinear` and `DequantizeLinear`. 
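
## How do I inspect parser errors from C++?

When an import fails with any of the errors above, the parser keeps a record of every error it encountered, and that record can be queried after `parse()` returns instead of relying on log output alone. The following is a minimal sketch rather than a complete program: it assumes the serialized model bytes (`data`/`size`) have already been read from disk and that a TensorRT network and logger exist.

```cpp
#include "NvOnnxParser.h"

#include <cstddef>
#include <cstdint>
#include <iostream>

// Minimal sketch: print every error recorded by the ONNX parser for a failed parse.
// `network` and `logger` are assumed to be created elsewhere (see the TensorRT samples).
void reportParseErrors(
    nvinfer1::INetworkDefinition& network, nvinfer1::ILogger& logger, void const* data, size_t size)
{
    auto* parser = nvonnxparser::createParser(network, logger);
    if (!parser->parse(data, size))
    {
        for (int32_t i = 0; i < parser->getNbErrors(); ++i)
        {
            nvonnxparser::IParserError const* err = parser->getError(i);
            std::cerr << "In node " << err->node() << " (" << err->nodeName() << ", op "
                      << err->nodeOperator() << "): " << err->desc() << std::endl;
        }
    }
    delete parser;
}
```

The same information is available from Python through `parser.get_error(i)`, which is how the `onnx_tensorrt` backend reports parse failures.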
52 | -------------------------------------------------------------------------------- /errorHelpers.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | #include "errorHelpers.hpp" 5 | 6 | namespace onnx2trt 7 | { 8 | OnnxTrtException::OnnxTrtException(Status status) 9 | : mStatus(status) 10 | { 11 | } 12 | Status OnnxTrtException::getStatus() const noexcept 13 | { 14 | return mStatus; 15 | } 16 | char const* OnnxTrtException::what() const noexcept 17 | { 18 | if (mMessage.empty()) 19 | { 20 | mMessage = parserErrorStr(&mStatus); 21 | } 22 | return mMessage.c_str(); 23 | } 24 | 25 | nvinfer1::ErrorCode errorCodeToTrtCode(ErrorCode const code) 26 | { 27 | switch (code) 28 | { 29 | case ErrorCode::kSUCCESS: return nvinfer1::ErrorCode::kSUCCESS; 30 | 31 | case ErrorCode::kINTERNAL_ERROR: 32 | case ErrorCode::kMODEL_DESERIALIZE_FAILED: 33 | case ErrorCode::kREFIT_FAILED: 34 | { 35 | return nvinfer1::ErrorCode::kINTERNAL_ERROR; 36 | } 37 | 38 | case ErrorCode::kMEM_ALLOC_FAILED: 39 | { 40 | return nvinfer1::ErrorCode::kFAILED_ALLOCATION; 41 | } 42 | 43 | case ErrorCode::kINVALID_VALUE: 44 | case ErrorCode::kINVALID_GRAPH: 45 | case ErrorCode::kINVALID_NODE: 46 | case ErrorCode::kUNSUPPORTED_GRAPH: 47 | case ErrorCode::kUNSUPPORTED_NODE: 48 | case ErrorCode::kUNSUPPORTED_NODE_ATTR: 49 | case ErrorCode::kUNSUPPORTED_NODE_INPUT: 50 | case ErrorCode::kUNSUPPORTED_NODE_DATATYPE: 51 | case ErrorCode::kUNSUPPORTED_NODE_DYNAMIC: 52 | case ErrorCode::kUNSUPPORTED_NODE_SHAPE: 53 | { 54 | return nvinfer1::ErrorCode::kINVALID_ARGUMENT; 55 | } 56 | } 57 | return nvinfer1::ErrorCode::kINTERNAL_ERROR; 58 | } 59 | } // namespace onnx2trt 60 | -------------------------------------------------------------------------------- /errorHelpers.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | #pragma once 5 | 6 | #include "Status.hpp" 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #define ONNXTRT_TRY try 13 | 14 | #define ONNXTRT_CATCH_RECORD \ 15 | catch (OnnxTrtException & e) \ 16 | { \ 17 | Status status = e.getStatus(); \ 18 | mImporterCtx.getErrorRecorder()->reportError(errorCodeToTrtCode(status.code()), e.what()); \ 19 | mErrors.push_back(status); \ 20 | } \ 21 | catch (std::exception & e) \ 22 | { \ 23 | mImporterCtx.getErrorRecorder()->reportError(nvinfer1::ErrorCode::kUNSPECIFIED_ERROR, e.what()); \ 24 | mErrors.push_back(Status{ErrorCode::kINTERNAL_ERROR, e.what()}); \ 25 | } 26 | 27 | #define ONNXTRT_CATCH_LOG(logger) \ 28 | catch (OnnxTrtException & e) \ 29 | { \ 30 | Status status = e.getStatus(); \ 31 | (logger)->log(nvinfer1::ILogger::Severity::kINTERNAL_ERROR, e.what()); \ 32 | mErrors.push_back(status); \ 33 | } \ 34 | catch (std::exception & e) \ 35 | { \ 36 | (logger)->log(nvinfer1::ILogger::Severity::kINTERNAL_ERROR, e.what()); \ 37 | mErrors.push_back(Status{ErrorCode::kINTERNAL_ERROR, e.what()}); \ 38 | } 39 | 40 | #define ONNXTRT_THROW(status) throw OnnxTrtException(status) 41 | 42 | #define ONNXTRT_CHECK(cond, desc, code) \ 43 | if (!(cond)) \ 44 | { \ 45 | std::ostringstream ss; \ 46 | ss << "Assertion failed: " << #cond << ": " << desc; \ 47 | ONNXTRT_THROW(MAKE_ERROR(ss.str(), (code))); \ 48 | } 49 | 50 | #define ONNXTRT_CHECK_NODE(cond, desc, node, nodeIdx, code) \ 51 | if (!(cond)) \ 52 | { \ 53 | std::ostringstream ss; \ 54 | ss << "Assertion failed: " << #cond << ": " << desc; \ 55 | 
ONNXTRT_THROW(MAKE_NODE_ERROR((ss.str()), (code), (node), (nodeIdx))); \ 56 | } 57 | 58 | namespace onnx2trt 59 | { 60 | inline char const* errorCodeStr(ErrorCode code) 61 | { 62 | switch (code) 63 | { 64 | case ErrorCode::kSUCCESS: return "SUCCESS"; 65 | case ErrorCode::kINTERNAL_ERROR: return "INTERNAL_ERROR"; 66 | case ErrorCode::kMEM_ALLOC_FAILED: return "MEM_ALLOC_FAILED"; 67 | case ErrorCode::kMODEL_DESERIALIZE_FAILED: return "MODEL_DESERIALIZE_FAILED"; 68 | case ErrorCode::kINVALID_VALUE: return "INVALID_VALUE"; 69 | case ErrorCode::kINVALID_GRAPH: return "INVALID_GRAPH"; 70 | case ErrorCode::kINVALID_NODE: return "INVALID_NODE"; 71 | case ErrorCode::kUNSUPPORTED_GRAPH: return "UNSUPPORTED_GRAPH"; 72 | case ErrorCode::kUNSUPPORTED_NODE: return "UNSUPPORTED_NODE"; 73 | case ErrorCode::kUNSUPPORTED_NODE_ATTR: return "UNSUPPORTED_NODE_ATTR"; 74 | case ErrorCode::kUNSUPPORTED_NODE_INPUT: return "UNSUPPORTED_NODE_INPUT"; 75 | case ErrorCode::kUNSUPPORTED_NODE_DATATYPE: return "UNSUPPORTED_NODE_DATATYPE"; 76 | case ErrorCode::kUNSUPPORTED_NODE_DYNAMIC: return "UNSUPPORTED_NODE_DYNAMIC"; 77 | case ErrorCode::kUNSUPPORTED_NODE_SHAPE: return "UNSUPPORTED_NODE_SHAPE"; 78 | case ErrorCode::kREFIT_FAILED: return "REFIT_FAILED"; 79 | } 80 | return "UNKNOWN"; 81 | }; 82 | 83 | inline std::string const parserErrorStr(nvonnxparser::IParserError const* error) 84 | { 85 | std::string const nodeInfo = "In node " + std::to_string(error->node()) + " with name: " + error->nodeName() 86 | + " and operator: " + error->nodeOperator() + " "; 87 | std::string const errorInfo 88 | = std::string("(") + error->func() + "): " + errorCodeStr(error->code()) + ": " + error->desc(); 89 | if (error->code() == ErrorCode::kMODEL_DESERIALIZE_FAILED || error->code() == ErrorCode::kREFIT_FAILED) 90 | { 91 | return errorInfo.c_str(); 92 | } 93 | return (nodeInfo + errorInfo).c_str(); 94 | } 95 | 96 | nvinfer1::ErrorCode errorCodeToTrtCode(ErrorCode const code); 97 | 98 | class OnnxTrtException : public std::exception 99 | { 100 | Status mStatus; 101 | mutable std::string mMessage; 102 | 103 | public: 104 | OnnxTrtException(Status status); 105 | 106 | Status getStatus() const noexcept; 107 | 108 | virtual char const* what() const noexcept override; 109 | 110 | virtual ~OnnxTrtException() {} 111 | }; 112 | 113 | } // namespace onnx2trt 114 | -------------------------------------------------------------------------------- /getSupportedAPITest.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include 6 | #include 7 | #include // For ::getopt 8 | #include 9 | #include "NvOnnxParser.h" 10 | #include "NvInferPlugin.h" 11 | #include "onnx_utils.hpp" 12 | #include "common.hpp" 13 | 14 | using std::cout; 15 | using std::cerr; 16 | using std::endl; 17 | 18 | void print_usage() { 19 | cout << "This program will determine whether or not an ONNX model is compatible with TensorRT. " 20 | << "If it isn't, a list of supported subgraphs and unsupported operations will be printed." << endl; 21 | cout << "Usage: getSupportedAPITest -m onnx_model.pb" << endl; 22 | cout << "Optional argument: -e TRT_engine" << endl; 23 | } 24 | 25 | void printSubGraphs(SubGraphCollection_t& subGraphs, ::ONNX_NAMESPACE::ModelProto onnx_model) 26 | { 27 | if (subGraphs.size() != 1) 28 | { 29 | cout << "The model contains unsupported Nodes. It has been partitioned to a set of supported subGraphs." 
<< endl; 30 | cout << "There are "<< subGraphs.size() << " supported subGraphs: " << endl; 31 | cout << "NOTE: Due to some limitations with the parser, the support of specific subgraphs may not have been determined." 32 | << " Please refer to the printed subgraphs to see if they are truly supported or not." << endl; 33 | } 34 | else 35 | { 36 | cout << "The model is fully supported by TensorRT. Printing the parsed graph:" << endl; 37 | } 38 | 39 | for (auto subGraph: subGraphs) 40 | { 41 | cout << "\t{"; 42 | for (auto idx: subGraph.first) cout << "\t" << idx << "," <createNetworkV2(1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH))); 91 | auto trt_parser = common::infer_object(nvonnxparser::createParser(*trt_network, trt_logger)); 92 | 93 | initLibNvInferPlugins(&trt_logger, ""); 94 | 95 | cout << "Parsing model: " << onnx_filename << endl; 96 | 97 | std::ifstream onnx_file(onnx_filename.c_str(), 98 | std::ios::binary | std::ios::ate); 99 | std::streamsize file_size = onnx_file.tellg(); 100 | onnx_file.seekg(0, std::ios::beg); 101 | std::vector onnx_buf(file_size); 102 | 103 | if( !onnx_file.read(onnx_buf.data(), onnx_buf.size()) ) { 104 | cerr << "ERROR: Failed to read from file " << onnx_filename << endl; 105 | return -1; 106 | } 107 | 108 | ::ONNX_NAMESPACE::ModelProto onnx_model; 109 | if (!common::ParseFromFile_WAR(&onnx_model, onnx_filename.c_str())) 110 | { 111 | cout << "Failure while parsing ONNX file" << endl; 112 | return -1; 113 | } 114 | 115 | SubGraphCollection_t SubGraphCollection; 116 | 117 | // supportsModel() parses the graph and returns a list of supported subgraphs. 118 | if (!trt_parser->supportsModel(onnx_buf.data(), onnx_buf.size(), SubGraphCollection)) 119 | { 120 | cout << "Model cannot be fully parsed by TensorRT!" << endl; 121 | printSubGraphs(SubGraphCollection, onnx_model); 122 | return -1; 123 | } 124 | 125 | printSubGraphs(SubGraphCollection, onnx_model); 126 | 127 | // If -e was specified, create and save the TensorRT engine to disk. 
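// Note: setMaxBatchSize() and IBuilderConfig::setMaxWorkspaceSize() below are legacy
// builder APIs; newer TensorRT releases replace them with optimization profiles and
// IBuilderConfig::setMemoryPoolLimit(), respectively.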
128 | // Note we do not call trt_parser->parse() here since it's already done above in parser->supportsModel() 129 | if( !engine_filename.empty() ) { 130 | trt_builder->setMaxBatchSize(max_batch_size); 131 | auto builder_config = common::infer_object(trt_builder->createBuilderConfig()); 132 | builder_config->setMaxWorkspaceSize(max_workspace_size); 133 | 134 | cout << "input name: " << trt_network->getInput(0)->getName() << endl; 135 | cout << "output name: " << trt_network->getOutput(0)->getName() << endl; 136 | cout << "num layers: " << trt_network->getNbLayers() << endl; 137 | cout << "outputs: " << trt_network->getNbOutputs() << endl; 138 | 139 | auto trt_engine = common::infer_object(trt_builder->buildEngineWithConfig(*trt_network.get(), *builder_config.get())); 140 | 141 | if( verbosity >= (int)nvinfer1::ILogger::Severity::kWARNING ) { 142 | cout << "Writing TensorRT engine to " << engine_filename << endl; 143 | } 144 | auto engine_plan = common::infer_object(trt_engine->serialize()); 145 | std::ofstream engine_file(engine_filename.c_str(), std::ios::binary); 146 | engine_file.write(reinterpret_cast(engine_plan->data()), engine_plan->size()); 147 | engine_file.close(); 148 | } 149 | 150 | if( verbosity >= (int)nvinfer1::ILogger::Severity::kWARNING ) { 151 | cout << "All done" << endl; 152 | } 153 | return 0; 154 | } 155 | -------------------------------------------------------------------------------- /half.h: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | * SPDX-License-Identifier: Apache-2.0 4 | * 5 | * Licensed under the Apache License, Version 2.0 (the "License"); 6 | * you may not use this file except in compliance with the License. 7 | * You may obtain a copy of the License at 8 | * 9 | * http://www.apache.org/licenses/LICENSE-2.0 10 | * 11 | * Unless required by applicable law or agreed to in writing, software 12 | * distributed under the License is distributed on an "AS IS" BASIS, 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | * See the License for the specific language governing permissions and 15 | * limitations under the License. 16 | */ 17 | // 18 | // Custom wrapper around external half-precision header 19 | // 20 | // Header has some "extra parentheses" warnings when different rounding modes are used. 
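// The push/pop pragmas below relax those diagnostics only around the third-party
// include, so builds using -Wall (and -Werror) are not affected by ieee_half.h.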
21 | 22 | #if defined(__GNUC__) 23 | #pragma GCC diagnostic push 24 | #pragma GCC diagnostic ignored "-Wparentheses" 25 | #endif 26 | 27 | 28 | #if defined(__clang__) 29 | #pragma clang diagnostic push 30 | #pragma clang diagnostic ignored "-Wmismatched-tags" 31 | #endif 32 | 33 | #include "ieee_half.h" 34 | typedef half_float::half float16; 35 | 36 | #if defined(__clang__) 37 | #pragma clang diagnostic pop 38 | #endif 39 | 40 | #if defined(__GNUC__) 41 | #pragma GCC diagnostic pop 42 | #endif 43 | -------------------------------------------------------------------------------- /libnvonnxparser.version: -------------------------------------------------------------------------------- 1 | { 2 | global: 3 | createNvOnnxParser_INTERNAL; 4 | createNvOnnxParserRefitter_INTERNAL; 5 | getNvOnnxParserVersion; 6 | extern "C++" { 7 | vtable*nvonnxparser::*; 8 | }; 9 | local: 10 | *; 11 | }; 12 | -------------------------------------------------------------------------------- /onnx2trt_common.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | 10 | #if NV_TENSORRT_MAJOR < 4 11 | namespace nvinfer1 12 | { 13 | 14 | enum class PluginFormat : uint8_t 15 | { 16 | kNCHW = 0, //!< NCHW 17 | kNC2HW2 = 1, //!< NCHW with 2-element packed channels 18 | kNHWC8 = 2 //!< NHWC with 8-element packed channels (C must be a multiple of 8) 19 | }; 20 | // from NvInfer.h 21 | class IPluginExt : public IPlugin 22 | { 23 | public: 24 | virtual int getTensorRTVersion() const noexcept 25 | { 26 | return NV_TENSORRT_VERSION; 27 | } 28 | virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept = 0; 29 | virtual void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, 30 | DataType type, PluginFormat format, int maxBatchSize) noexcept 31 | = 0; 32 | 33 | protected: 34 | void configure( 35 | const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, int maxBatchSize) noexcept final 36 | { 37 | try 38 | { 39 | DataType type = nvinfer1::DataType::kFLOAT; 40 | PluginFormat format = nvinfer1::PluginFormat::kLINEAR; 41 | return this->configureWithFormat(inputDims, nbInputs, outputDims, nbOutputs, type, format, maxBatchSize); 42 | } 43 | catch (const std::exception& e) 44 | { 45 | nvinfer1::getLogger()->log(nvinfer1::ILogger::Severity::kERROR, e.what().c_str()); 46 | } 47 | } 48 | virtual ~IPluginExt() 49 | { 50 | } 51 | }; 52 | 53 | } // namespace nvinfer1 54 | #endif 55 | 56 | namespace onnx2trt 57 | { 58 | 59 | struct IOwnable 60 | { 61 | virtual void destroy() = 0; 62 | 63 | protected: 64 | virtual ~IOwnable() 65 | { 66 | } 67 | }; 68 | 69 | struct OwnableDeleter 70 | { 71 | void operator()(IOwnable* obj) const 72 | { 73 | obj->destroy(); 74 | } 75 | }; 76 | 77 | using UniqueOwnable = std::unique_ptr; 78 | class Plugin; 79 | class PluginV2; 80 | 81 | } // namespace onnx2trt 82 | -------------------------------------------------------------------------------- /onnx2trt_runtime.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "onnx2trt_common.hpp" 8 | 9 | namespace onnx2trt 10 | { 11 | 12 | typedef Plugin* (*plugin_deserializer)(const void* serialData, size_t serialLength); 13 | 14 | } // namespace onnx2trt 15 | 
-------------------------------------------------------------------------------- /onnxErrorRecorder.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "onnxErrorRecorder.hpp" 6 | #include 7 | 8 | namespace onnx2trt 9 | { 10 | 11 | 12 | ONNXParserErrorRecorder* ONNXParserErrorRecorder::create( 13 | nvinfer1::ILogger* logger, nvinfer1::IErrorRecorder* otherRecorder) 14 | { 15 | try 16 | { 17 | auto recorder = new ONNXParserErrorRecorder(logger, otherRecorder); 18 | if (recorder) 19 | { 20 | recorder->incRefCount(); 21 | } 22 | return recorder; 23 | } 24 | catch (const std::exception& e) 25 | { 26 | logError(logger, e.what()); 27 | return nullptr; 28 | } 29 | } 30 | 31 | void ONNXParserErrorRecorder::destroy(ONNXParserErrorRecorder*& recorder) 32 | { 33 | if (recorder) 34 | { 35 | recorder->decRefCount(); 36 | recorder = nullptr; 37 | } 38 | } 39 | 40 | void ONNXParserErrorRecorder::logError(nvinfer1::ILogger* logger, const char* str) 41 | { 42 | if (logger) 43 | { 44 | logger->log(ILogger::Severity::kERROR, str); 45 | } 46 | } 47 | 48 | ONNXParserErrorRecorder::ONNXParserErrorRecorder( 49 | nvinfer1::ILogger* logger, nvinfer1::IErrorRecorder* otherRecorder) 50 | : mUserRecorder(otherRecorder) 51 | , mLogger(logger) 52 | { 53 | if (mUserRecorder) 54 | { 55 | mUserRecorder->incRefCount(); 56 | } 57 | } 58 | 59 | ONNXParserErrorRecorder::~ONNXParserErrorRecorder() noexcept 60 | { 61 | if (mUserRecorder) 62 | { 63 | mUserRecorder->decRefCount(); 64 | } 65 | } 66 | 67 | void ONNXParserErrorRecorder::clear() noexcept 68 | { 69 | try 70 | { 71 | // grab a lock so that there is no addition while clearing. 72 | std::lock_guard guard(mStackLock); 73 | mErrorStack.clear(); 74 | } 75 | catch (const std::exception& e) 76 | { 77 | logError(mLogger, e.what()); 78 | } 79 | }; 80 | 81 | bool ONNXParserErrorRecorder::reportError( 82 | nvinfer1::ErrorCode val, nvinfer1::IErrorRecorder::ErrorDesc desc) noexcept 83 | { 84 | try 85 | { 86 | std::lock_guard guard(mStackLock); 87 | mErrorStack.push_back(errorPair(val, desc)); 88 | if (mUserRecorder) 89 | { 90 | mUserRecorder->reportError(val, desc); 91 | } 92 | else 93 | { 94 | logError(mLogger, desc); 95 | } 96 | } 97 | catch (const std::exception& e) 98 | { 99 | logError(mLogger, e.what()); 100 | } 101 | // All errors are considered fatal. 102 | return true; 103 | } 104 | 105 | nvinfer1::IErrorRecorder::RefCount ONNXParserErrorRecorder::incRefCount() noexcept 106 | { 107 | // Atomically increment or decrement the ref counter. 108 | return ++mRefCount; 109 | } 110 | 111 | nvinfer1::IErrorRecorder::RefCount ONNXParserErrorRecorder::decRefCount() noexcept 112 | { 113 | auto newVal = --mRefCount; 114 | if (newVal == 0) 115 | { 116 | delete this; 117 | } 118 | return newVal; 119 | } 120 | 121 | } // namespace onnx2trt 122 | -------------------------------------------------------------------------------- /onnxErrorRecorder.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "NvInferRuntime.h" 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | namespace onnx2trt 16 | { 17 | 18 | //! 19 | //! A simple implementation of the IErrorRecorder interface for 20 | //! use by ONNX importer. 21 | //! ONNX-importer Error recorder is based on a vector that pairs the error 22 | //! 
code and the error string into a single element. It also uses 23 | //! standard mutex and atomics in order to make sure that the code 24 | //! works in a multi-threaded environment. 25 | //! 26 | class ONNXParserErrorRecorder : public nvinfer1::IErrorRecorder 27 | { 28 | using RefCount = nvinfer1::IErrorRecorder::RefCount; 29 | using ErrorDesc = nvinfer1::IErrorRecorder::ErrorDesc; 30 | using ErrorCode = nvinfer1::ErrorCode; 31 | using IErrorRecorder = nvinfer1::IErrorRecorder; 32 | using ILogger = nvinfer1::ILogger; 33 | 34 | using errorPair = std::pair; 35 | using errorStack = std::vector; 36 | 37 | public: 38 | static ONNXParserErrorRecorder* create( 39 | ILogger* logger, IErrorRecorder* otherRecorder = nullptr); 40 | 41 | static void destroy(ONNXParserErrorRecorder*& recorder); 42 | 43 | void clear() noexcept final; 44 | RefCount incRefCount() noexcept final; 45 | RefCount decRefCount() noexcept final; 46 | bool reportError(ErrorCode val, ErrorDesc desc) noexcept final; 47 | 48 | int32_t getNbErrors() const noexcept final 49 | { 50 | return mErrorStack.size(); 51 | } 52 | 53 | ErrorCode getErrorCode(int32_t errorIdx) const noexcept final 54 | { 55 | return invalidIndexCheck(errorIdx) ? ErrorCode::kINVALID_ARGUMENT : (*this)[errorIdx].first; 56 | } 57 | 58 | ErrorDesc getErrorDesc(int32_t errorIdx) const noexcept final 59 | { 60 | return invalidIndexCheck(errorIdx) ? "errorIdx out of range." : (*this)[errorIdx].second.c_str(); 61 | } 62 | 63 | bool hasOverflowed() const noexcept final 64 | { 65 | // This class can never overflow since we have dynamic resize via std::vector usage. 66 | return false; 67 | } 68 | 69 | protected: 70 | ONNXParserErrorRecorder(ILogger* logger, IErrorRecorder* otherRecorder = nullptr); 71 | 72 | virtual ~ONNXParserErrorRecorder() noexcept; 73 | 74 | static void logError(ILogger* logger, const char* str); 75 | 76 | // Simple helper functions. 77 | const errorPair& operator[](size_t index) const noexcept 78 | { 79 | return mErrorStack[index]; 80 | } 81 | 82 | bool invalidIndexCheck(int32_t index) const noexcept 83 | { 84 | // By converting signed to unsigned, we only need a single check since 85 | // negative numbers turn into large positive greater than the size. 86 | size_t sIndex = index; 87 | return sIndex >= mErrorStack.size(); 88 | } 89 | // Mutex to hold when locking mErrorStack. 90 | std::mutex mStackLock; 91 | 92 | // Reference count of the class. Destruction of the class when mRefCount 93 | // is not zero causes undefined behavior. 94 | std::atomic mRefCount{0}; 95 | 96 | // The error stack that holds the errors recorded by TensorRT. 
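// Pushes and clears are serialized through mStackLock (see reportError() and clear()).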
97 | errorStack mErrorStack; 98 | 99 | // Original error recorder (set by user) 100 | IErrorRecorder* mUserRecorder{nullptr}; 101 | 102 | // logger 103 | ILogger* mLogger{nullptr}; 104 | }; // class ONNXParserErrorRecorder 105 | 106 | } // namespace onnx2trt 107 | -------------------------------------------------------------------------------- /onnxOpCheckers.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "ImporterContext.hpp" 8 | 9 | namespace onnx2trt 10 | { 11 | 12 | StringMap& getOpStaticErrorCheckerMap(); 13 | 14 | } // namespace onnx2trt 15 | -------------------------------------------------------------------------------- /onnxOpImporters.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "ImporterContext.hpp" 8 | 9 | namespace onnx2trt 10 | { 11 | 12 | StringMap& getBuiltinOpImporterMap(); 13 | 14 | } // namespace onnx2trt 15 | -------------------------------------------------------------------------------- /onnxProtoUtils.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "onnxProtoUtils.hpp" 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | namespace onnx2trt 16 | { 17 | void removeRawDataStrings(std::string& s) 18 | { 19 | std::string::size_type beg = 0; 20 | const std::string key = "raw_data: \""; 21 | const std::string sub = "..."; 22 | while ((beg = s.find(key, beg)) != std::string::npos) 23 | { 24 | beg += key.length(); 25 | std::string::size_type end = beg - 1; 26 | // Note: Must skip over escaped end-quotes 27 | while (s[(end = s.find("\"", ++end)) - 1] == '\\') 28 | { 29 | } 30 | if (end - beg > 128) 31 | { // Only remove large data strings 32 | s.replace(beg, end - beg, "..."); 33 | } 34 | beg += sub.length(); 35 | } 36 | } 37 | 38 | std::string removeRepeatedDataStrings(std::string const& s) 39 | { 40 | std::istringstream iss(s); 41 | std::ostringstream oss; 42 | bool is_repeat = false; 43 | for (std::string line; std::getline(iss, line);) 44 | { 45 | if (line.find("float_data:") != std::string::npos || line.find("int32_data:") != std::string::npos 46 | || line.find("int64_data:") != std::string::npos) 47 | { 48 | if (!is_repeat) 49 | { 50 | is_repeat = true; 51 | oss << line.substr(0, line.find(":") + 1) << " ...\n"; 52 | } 53 | } 54 | else 55 | { 56 | is_repeat = false; 57 | oss << line << "\n"; 58 | } 59 | } 60 | return oss.str(); 61 | } 62 | 63 | std::string onnxIRVersionAsString(int64_t irVersion) 64 | { 65 | int64_t verMajor = irVersion / 1000000; 66 | int64_t verMinor = irVersion % 1000000 / 10000; 67 | int64_t verPatch = irVersion % 10000; 68 | return (std::to_string(verMajor) + "." + std::to_string(verMinor) + "." 
+ std::to_string(verPatch)); 69 | } 70 | 71 | } // namespace onnx2trt 72 | -------------------------------------------------------------------------------- /onnxProtoUtils.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "Status.hpp" 8 | #include "errorHelpers.hpp" 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | 19 | #if USE_LITE_PROTOBUF 20 | #include 21 | #else // !USE_LITE_PROTOBUF 22 | #include 23 | #include 24 | #endif // USE_LITE_PROTOBUF 25 | 26 | // This file contains the declaration of helper functions used for converting and working with Protobuf files. 27 | 28 | namespace onnx2trt 29 | { 30 | 31 | // Removes raw data from the text representation of an ONNX model. 32 | void removeRawDataStrings(std::string& s); 33 | 34 | // Removes float_data, int32_data etc. from the text representation of an ONNX model. 35 | std::string removeRepeatedDataStrings(std::string const& s); 36 | 37 | // Returns the ONNX IR version as a string. 38 | std::string onnxIRVersionAsString(int64_t ir_version = ::ONNX_NAMESPACE::IR_VERSION); 39 | 40 | // Converts a raw protobuf::Message or protobuf::MessageLite into a string representation. 41 | template 42 | std::string convertProtoToString(ProtoMessage const& message) 43 | { 44 | std::string s{}; 45 | // Textformat available in full proto only. Return only the name when using protobuf-lite. 46 | #if USE_LITE_PROTOBUF 47 | s = "Node name: " + message.name(); 48 | return s; 49 | #else 50 | ::google::protobuf::TextFormat::PrintToString(message, &s); 51 | removeRawDataStrings(s); 52 | s = removeRepeatedDataStrings(s); 53 | return s; 54 | #endif // USE_LITE_PROTOBUF 55 | } 56 | 57 | // Deserializes an ONNX ModelProto passed in as a protobuf::Message or a protobuf::MessageLite. 58 | template 59 | void deserializeOnnxModel(void const* serializedModel, size_t serializedModelSize, ProtoMessage* model) 60 | { 61 | google::protobuf::io::ArrayInputStream rawInput(serializedModel, serializedModelSize); 62 | google::protobuf::io::CodedInputStream codedInput(&rawInput); 63 | #if GOOGLE_PROTOBUF_VERSION >= 3011000 64 | // Starting Protobuf 3.11 accepts only single parameter. 65 | codedInput.SetTotalBytesLimit(std::numeric_limits::max()); 66 | #else 67 | // Note: This WARs the very low default size limit (64MB) 68 | codedInput.SetTotalBytesLimit(std::numeric_limits::max(), std::numeric_limits::max() / 4); 69 | #endif 70 | ONNXTRT_CHECK(model->ParseFromCodedStream(&codedInput), "Failed to parse the ONNX model.", 71 | ErrorCode::kMODEL_DESERIALIZE_FAILED); 72 | } 73 | 74 | // Helper function to dispatch to deserializeOnnxModel when user provides a path to the model. 75 | template 76 | bool ParseFromFileAsBinary(ProtoMessage* msg, char const* filename) 77 | { 78 | std::ifstream onnxFile(filename, std::ios::ate | std::ios::binary); 79 | if (!onnxFile) 80 | { 81 | std::cerr << "Could not open file " << std::string(filename) << std::endl; 82 | return false; 83 | } 84 | // Determine the file size 85 | auto fileSize = onnxFile.tellg(); 86 | onnxFile.seekg(0, std::ios::beg); 87 | 88 | // Create buffer and read tne entire file to the buffer. 
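// The whole model is buffered in memory; deserializeOnnxModel() below applies the
// protobuf size-limit workaround before parsing the bytes.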
89 | std::vector buffer(fileSize); 90 | if (!onnxFile.read(buffer.data(), fileSize)) 91 | { 92 | std::cerr << "Error reading file: " << filename << std::endl; 93 | return false; 94 | } 95 | 96 | deserializeOnnxModel(buffer.data(), buffer.size(), msg); 97 | return true; 98 | } 99 | 100 | // ostream overload for printing NodeProtos. 101 | inline std::ostream& operator<<(std::ostream& stream, ::ONNX_NAMESPACE::NodeProto const& message) 102 | { 103 | stream << convertProtoToString(message); 104 | return stream; 105 | } 106 | 107 | } // namespace onnx2trt 108 | -------------------------------------------------------------------------------- /onnx_backend_test.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | from __future__ import unicode_literals 7 | 8 | import os 9 | 10 | import unittest 11 | import onnx.backend.test 12 | 13 | import onnx_tensorrt.backend as trt 14 | 15 | # This is a pytest magic variable to load extra plugins 16 | pytest_plugins = 'onnx.backend.test.report', 17 | 18 | backend_test = onnx.backend.test.BackendTest(trt, __name__) 19 | 20 | # Include all of the nodes that we support. 21 | # Onnx native node tests 22 | backend_test.include(r'.*test_abs.*') 23 | backend_test.include(r'.*test_acos.*') 24 | backend_test.include(r'.*test_acosh.*') 25 | backend_test.include(r'.*test_add.*') 26 | backend_test.include(r'.*test_argmax.*') 27 | backend_test.include(r'.*test_argmin.*') 28 | backend_test.include(r'.*test_asin.*') 29 | backend_test.include(r'.*test_asinh.*') 30 | backend_test.include(r'.*test_atan.*') 31 | backend_test.include(r'.*test_atanh.*') 32 | backend_test.include(r'.*test_averagepool.*') 33 | backend_test.include(r'.*test_AvgPool.*') 34 | backend_test.include(r'.*test_BatchNorm.*eval.*') 35 | backend_test.include(r'.*test_ceil.*') 36 | backend_test.include(r'.*test_celu.*') 37 | backend_test.include(r'.*test_clip.*') 38 | backend_test.include(r'.*test_concat.*') 39 | backend_test.include(r'.*test_constant.*') 40 | backend_test.include(r'.*test_Conv[1-3]d*') 41 | backend_test.include(r'.*test_cos.*') 42 | backend_test.include(r'.*test_cosh.*') 43 | backend_test.include(r'.*test_depthtospace.*') 44 | backend_test.include(r'.*test_div.*') 45 | backend_test.include(r'.*test_dropout.*') 46 | backend_test.include(r'.*test_ELU*') 47 | backend_test.include(r'.*test_elu.*') 48 | backend_test.include(r'.*test_equal.*') 49 | backend_test.include(r'.*test_Embedding*') 50 | backend_test.include(r'.*test_exp.*') 51 | backend_test.include(r'.*test_eyelike.*') 52 | backend_test.include(r'.*test_flatten.*') 53 | backend_test.include(r'.*test_floor.*') 54 | backend_test.include(r'.*test_gather.*') 55 | backend_test.include(r'.*test_gemm.*') 56 | backend_test.include(r'.*test_globalaveragepool.*') 57 | backend_test.include(r'.*test_globalmaxpool.*') 58 | backend_test.include(r'.*test_greater.*') 59 | backend_test.include(r'.*test_hardsigmoid.*') 60 | backend_test.include(r'.*test_identity.*') 61 | backend_test.include(r'.*test_LeakyReLU*') 62 | backend_test.include(r'.*test_leakyrelu.*') 63 | backend_test.include(r'.*test_less.*') 64 | backend_test.include(r'.*test_Linear.*') 65 | backend_test.include(r'.*test_log.*') 66 | backend_test.include(r'.*test_logsoftmax.*') 67 | backend_test.include(r'.*test_LogSoftmax.*') 68 | backend_test.include(r'.*test_log_softmax.*') 69 | 
backend_test.include(r'.*test_lrn.*') 70 | backend_test.include(r'.*test_matmul.*') 71 | backend_test.include(r'.*test_max.*') 72 | backend_test.include(r'.*test_MaxPool[1-9]d.*') 73 | backend_test.include(r'.*test_mean.*') 74 | backend_test.include(r'.*test_min.*') 75 | backend_test.include(r'.*test_mul.*') 76 | backend_test.include(r'.*test_neg.*') 77 | backend_test.include(r'.*test_not.*') 78 | backend_test.include(r'.*test_operator_addmm.*') 79 | backend_test.include(r'.*test_operator_basic.*') 80 | backend_test.include(r'.*test_operator_chunk.*') 81 | backend_test.include(r'.*test_operator_clip.*') 82 | backend_test.include(r'.*test_operator_concat2.*') 83 | backend_test.include(r'.*test_operator_conv_.*') 84 | backend_test.include(r'.*test_operator_exp.*') 85 | backend_test.include(r'.*test_operator_flatten.*') 86 | backend_test.include(r'.*test_operator_index.*') 87 | backend_test.include(r'.*test_operator_max_.*') 88 | backend_test.include(r'.*test_operator_maxpool.*') 89 | backend_test.include(r'.*test_operator_min.*') 90 | backend_test.include(r'.*test_operator_mm.*') 91 | backend_test.include(r'.*test_operator_non_float_params.*') 92 | backend_test.include(r'.*test_operator_params.*') 93 | backend_test.include(r'.*test_operator_permute2.*') 94 | backend_test.include(r'.*test_operator_pow.*') 95 | backend_test.include(r'.*test_operator_reduced_mean_.*') 96 | backend_test.include(r'.*test_operator_reduced_mean_keepdim.*') 97 | backend_test.include(r'.*test_operator_reduced_sum_.*') 98 | backend_test.include(r'.*test_operator_reduced_sum_keepdim.*') 99 | backend_test.include(r'.*test_operator_selu.*') 100 | backend_test.include(r'.*test_operator_sqrt.*') 101 | backend_test.include(r'.*test_operator_symbolic_override.*') 102 | backend_test.include(r'.*test_operator_symbolic_override_nested.*') 103 | backend_test.include(r'.*test_operator_view.*') 104 | backend_test.include(r'.*test_pow.*') 105 | backend_test.include(r'.*test_PoissonNLLLLoss_no_reduce*') 106 | backend_test.include(r'.*test_reciprocal.*') 107 | backend_test.include(r'.*test_reduce.*') 108 | backend_test.include(r'.*test_ReLU*') 109 | backend_test.include(r'.*test_relu.*') 110 | backend_test.include(r'.*test_selu.*') 111 | backend_test.include(r'.*test_shape.*') 112 | backend_test.include(r'.*test_Sigmoid*') 113 | backend_test.include(r'.*test_sigmoid.*') 114 | backend_test.include(r'.*test_sin.*') 115 | backend_test.include(r'.*test_sinh.*') 116 | backend_test.include(r'.*test_size.*') 117 | backend_test.include(r'.*test_Softmax*') 118 | backend_test.include(r'.*test_softmax.*') 119 | backend_test.include(r'.*test_Softmin*') 120 | backend_test.include(r'.*test_Softplus*') 121 | backend_test.include(r'.*test_softplus.*') 122 | backend_test.include(r'.*test_softsign.*') 123 | backend_test.include(r'.*test_sqrt.*') 124 | backend_test.include(r'.*test_squeeze_cuda') 125 | backend_test.include(r'.*test_sub.*') 126 | backend_test.include(r'.*test_sum.*') 127 | backend_test.include(r'.*test_tan.*') 128 | backend_test.include(r'.*test_Tanh*') 129 | backend_test.include(r'.*test_tanh.*') 130 | backend_test.include(r'.*test_thresholdedrelu.*') 131 | backend_test.include(r'.*test_transpose.*') 132 | backend_test.include(r'.*test_unsqueeze.*') 133 | backend_test.include(r'.*test_ZeroPad2d*') 134 | 135 | # # Onnx native model tests 136 | backend_test.include(r'.*test_bvlc_alexnet.*') 137 | backend_test.include(r'.*test_densenet121.*') 138 | backend_test.include(r'.*test_inception_v1.*') 139 | 
backend_test.include(r'.*test_inception_v2.*') 140 | backend_test.include(r'.*test_resnet50.*') 141 | backend_test.include(r'.*test_shufflenet.*') 142 | backend_test.include(r'.*test_squeezenet.*') 143 | backend_test.include(r'.*test_vgg19.*') 144 | backend_test.include(r'.*test_zfnet512.*') 145 | 146 | 147 | # TRT custom tests 148 | backend_test.include(r'.*test_basic_conv_.*custom.*') 149 | backend_test.include(r'.*test_conv_.*custom.*') 150 | backend_test.include(r'.*test_convtranspose.*custom.*') 151 | backend_test.include(r'.*test_batchnorm.*custom.*') 152 | backend_test.include(r'.*test_reshape.*custom.*') 153 | backend_test.include(r'.*test_prelu.*custom.*') 154 | backend_test.include(r'.*test_topk.*custom.*') 155 | backend_test.include(r'.*test_upsample.*custom.*') 156 | backend_test.include(r'.*test_constant_pad_custom.*') 157 | backend_test.include(r'.*test_resize.*custom.*') 158 | backend_test.include(r'.*test_split.*custom.*') 159 | backend_test.include(r'.*test_instancenorm_.*_custom.*') 160 | backend_test.include(r'.*test_slice.*custom.*') 161 | 162 | 163 | # Exclude unenabled ops that get pulled in by the wildcards above. 164 | # test_constant_pad gets pulled in with the test_constant* wildcard. Explicitly disable padding tests for now. 165 | backend_test.exclude(r'.*test_constant_pad.*') 166 | backend_test.exclude(r'.*test_constantofshape.*') 167 | backend_test.exclude(r'.*test_expand.*') 168 | # Operator MATMULINTEGER is not supported by TRT 169 | backend_test.exclude(r'.*test_matmulinteger.*') 170 | backend_test.exclude(r'.*test_maxpool.*') 171 | backend_test.exclude(r'.*test_maxunpool.*') 172 | # Mismatch: 0.476%, relative diff is good. 173 | # Absolute diff failed because 174 | # numpy compares the difference between actual and desired to atol + rtol * abs(desired) 175 | backend_test.exclude(r'.*test_convtranspose_3d_custom_cuda') 176 | # Dilations are not supported in the ConvTranspose layer 177 | backend_test.exclude(r'.*test_convtranspose_dilations_custom_cuda') 178 | 179 | globals().update(backend_test 180 | .enable_report() 181 | .test_cases) 182 | 183 | if __name__ == '__main__': 184 | unittest.main() 185 | -------------------------------------------------------------------------------- /onnx_tensorrt/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | from __future__ import absolute_import 4 | 5 | from . 
import backend 6 | 7 | __version__ = "10.11.0" 8 | -------------------------------------------------------------------------------- /onnx_tensorrt/backend.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | from __future__ import print_function 4 | from .tensorrt_engine import Engine 5 | import tensorrt as trt 6 | from onnx.backend.base import Backend, BackendRep, Device, DeviceType, namedtupledict 7 | import onnx 8 | from onnx import helper as onnx_helper 9 | from onnx import numpy_helper 10 | import numpy as np 11 | import six 12 | 13 | # HACK Should look for a better way/place to do this 14 | from ctypes import cdll, c_char_p 15 | libcudart = cdll.LoadLibrary('libcudart.so') 16 | libcudart.cudaGetErrorString.restype = c_char_p 17 | def cudaSetDevice(device_idx): 18 | ret = libcudart.cudaSetDevice(device_idx) 19 | if ret != 0: 20 | error_string = libcudart.cudaGetErrorString(ret) 21 | if isinstance(error_string, bytes): 22 | error_string = error_string.decode("utf-8") 23 | raise RuntimeError("cudaSetDevice: " + error_string) 24 | 25 | def count_trailing_ones(vals): 26 | count = 0 27 | for val in reversed(vals): 28 | if val != 1: 29 | return count 30 | count += 1 31 | return count 32 | 33 | TRT_LOGGER = trt.Logger(trt.Logger.WARNING) 34 | 35 | class TensorRTBackendRep(BackendRep): 36 | def __init__(self, model, device, 37 | max_workspace_size=None, serialize_engine=False, verbose=False, **kwargs): 38 | if not isinstance(device, Device): 39 | device = Device(device) 40 | self._set_device(device) 41 | self._logger = TRT_LOGGER 42 | self.builder = trt.Builder(self._logger) 43 | self.network = self.builder.create_network(flags=0) 44 | self.parser = trt.OnnxParser(self.network, self._logger) 45 | self.config = self.builder.create_builder_config() 46 | self.serialize_engine = serialize_engine 47 | self.verbose = verbose 48 | self.dynamic = False 49 | 50 | if self.verbose: 51 | print(f'\nRunning {model.graph.name}...') 52 | TRT_LOGGER.min_severity = trt.Logger.VERBOSE 53 | 54 | if not isinstance(model, six.string_types): 55 | model_str = model.SerializeToString() 56 | else: 57 | model_str = model 58 | 59 | if not trt.init_libnvinfer_plugins(TRT_LOGGER, ""): 60 | msg = "Failed to initialize TensorRT's plugin library." 61 | raise RuntimeError(msg) 62 | 63 | if not self.parser.parse(model_str): 64 | error = self.parser.get_error(0) 65 | msg = "While parsing node number %i:\n" % error.node() 66 | msg += ("%s:%i In function %s:\n[%i] %s" % 67 | (error.file(), error.line(), error.func(), 68 | error.code(), error.desc())) 69 | raise RuntimeError(msg) 70 | if max_workspace_size is None: 71 | max_workspace_size = 1 << 28 72 | 73 | self.config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, max_workspace_size) 74 | 75 | num_inputs = self.network.num_inputs 76 | for idx in range(num_inputs): 77 | inp_tensor = self.network.get_input(idx) 78 | if inp_tensor.is_shape_tensor or -1 in inp_tensor.shape: 79 | self.dynamic = True 80 | break 81 | 82 | if self.verbose: 83 | for layer in self.network: 84 | print(layer) 85 | 86 | print(f'Output shape: {self.network[-1].get_output(0).shape}') 87 | 88 | if self.dynamic: 89 | if self.verbose: 90 | print("Found dynamic inputs! 
Deferring engine build to run stage") 91 | else: 92 | self._build_engine() 93 | 94 | self._output_shapes = {} 95 | self._output_dtype = {} 96 | for output in model.graph.output: 97 | dims = output.type.tensor_type.shape.dim 98 | output_shape = tuple([dim.dim_value for dim in dims]) 99 | self._output_shapes[output.name] = output_shape 100 | self._output_dtype[output.name] = output.type.tensor_type.elem_type 101 | 102 | def _build_engine(self, inputs=None): 103 | """ 104 | Builds TensorRT Engine, with BuilderConfig if needed 105 | :param inputs: inputs to the model; if not None, this means we are building the engine at run time, 106 | because we need to register optimization profiles for some inputs 107 | :type inputs: List of np.ndarray 108 | """ 109 | 110 | if inputs: 111 | opt_profile = self.builder.create_optimization_profile() 112 | 113 | # Set optimization profiles for the input bindings that need them 114 | for i in range(self.network.num_inputs): 115 | inp_tensor = self.network.get_input(i) 116 | name = inp_tensor.name 117 | # Set profiles for shape tensors 118 | if inp_tensor.is_shape_tensor: 119 | if inputs[i].ndim > 0: 120 | val_list = inputs[i].tolist() 121 | opt_profile.set_shape_input(name, val_list, val_list, val_list) 122 | else: 123 | opt_profile.set_shape_input(name, [inputs[i]], [inputs[i]], [inputs[i]]) 124 | # Set profiles for dynamic execution tensors 125 | elif -1 in inp_tensor.shape: 126 | opt_profile.set_shape(name, inputs[i].shape, inputs[i].shape, inputs[i].shape) 127 | 128 | self.config.add_optimization_profile(opt_profile) 129 | trt_blob = self.builder.build_serialized_network(self.network, self.config) 130 | 131 | if trt_blob is None: 132 | raise RuntimeError("Failed to build TensorRT engine from network") 133 | 134 | trt_engine = self._deserialize(trt_blob) 135 | self.engine = Engine(trt_engine) 136 | 137 | def _set_device(self, device): 138 | self.device = device 139 | assert(device.type == DeviceType.CUDA) 140 | cudaSetDevice(device.device_id) 141 | 142 | def _deserialize(self, trt_blob): 143 | self.runtime = trt.Runtime(TRT_LOGGER) 144 | del self.parser # Parser no longer needed for ownership of plugins 145 | trt_engine = self.runtime.deserialize_cuda_engine(trt_blob) 146 | return trt_engine 147 | 148 | def run(self, inputs, **kwargs): 149 | """Execute the prepared engine and return the outputs as a named tuple. 150 | inputs -- Input tensor(s) as a Numpy array or list of Numpy arrays. 
151 | """ 152 | if isinstance(inputs, np.ndarray): 153 | inputs = [inputs] 154 | 155 | if self.dynamic: 156 | self._build_engine(inputs) 157 | 158 | outputs = self.engine.run(inputs) 159 | output_names = [output.name for output in self.engine.outputs] 160 | 161 | for i, (name, array) in enumerate(zip(output_names, outputs)): 162 | output_shape = self._output_shapes[name] 163 | # HACK WAR for unknown output shape in run_node 164 | if output_shape == (-99,): 165 | # WAR for TRT requiring at least 2 dims (NC) 166 | min_dims = 2 167 | if _tensorrt_version()[0] < 4: 168 | # WAR for TRT only supporting 4D (NCHW) tensors 169 | min_dims = 4 170 | if array.ndim == min_dims: 171 | npadding_dims = count_trailing_ones(array.shape) 172 | if npadding_dims > 0: 173 | outputs[i] = array.reshape( 174 | array.shape[:-npadding_dims]) 175 | else: 176 | # HACK WAR replace fixed batch dim with variable 177 | if self._output_dtype[name] == onnx.TensorProto.INT64 and array.dtype == np.int32: 178 | casted_output = np.array(outputs[i], dtype=np.int64) 179 | if np.equal(outputs[i], casted_output).all(): 180 | outputs[i] = np.array(outputs[i], dtype=np.int64) 181 | if self._output_dtype[name] == onnx.TensorProto.DOUBLE and array.dtype == np.float32: 182 | casted_output = np.array(outputs[i], dtype=np.double) 183 | if np.equal(outputs[i], casted_output).all(): 184 | outputs[i] = np.array(outputs[i], dtype=np.double) 185 | 186 | outputs_tuple = namedtupledict('Outputs', output_names)(*outputs) 187 | return namedtupledict('Outputs', output_names)(*outputs) 188 | 189 | def np2onnx_dtype(np_dtype): 190 | if np_dtype == np.dtype('float32'): 191 | return onnx.TensorProto.FLOAT 192 | elif np_dtype == np.dtype('float16'): 193 | return onnx.TensorProto.FLOAT16 194 | elif np_dtype == np.dtype('int64'): 195 | return onnx.TensorProto.INT64 196 | elif np_dtype == np.dtype('int32'): 197 | return onnx.TensorProto.INT32 198 | elif np_dtype == np.dtype('int8'): 199 | return onnx.TensorProto.INT8 200 | elif np_dtype == np.dtype('double'): 201 | return onnx.TensorProto.DOUBLE 202 | else: 203 | raise TypeError("Unsupported data type:", np_dtype) 204 | 205 | def make_node_test_model(node, inputs, use_weights=True): 206 | # HACK TODO: The output info is unknown here; not sure what the best solution is 207 | output_dtype = np.float32 # Dummy value only 208 | output_shape = [-99] # Dummy value only 209 | graph_inputs = [onnx_helper.make_tensor_value_info( 210 | name, np2onnx_dtype(array.dtype), array.shape) 211 | for name, array in zip(node.input, inputs)] 212 | graph_outputs = [onnx_helper.make_tensor_value_info( 213 | name, np2onnx_dtype(output_dtype), output_shape) 214 | for name in node.output] 215 | if use_weights: 216 | # Add initializers for all inputs except the first 217 | initializers = [onnx_helper.make_tensor( 218 | name, np2onnx_dtype(array.dtype), array.shape, array.flatten().tolist()) 219 | for name, array in zip(node.input[1:], inputs[1:])] 220 | else: 221 | initializers = [] 222 | graph = onnx_helper.make_graph( 223 | [node], "RunNodeGraph_" + node.op_type, 224 | graph_inputs, graph_outputs, initializer=initializers) 225 | model = onnx_helper.make_model(graph) 226 | return model 227 | 228 | class TensorRTBackend(Backend): 229 | @classmethod 230 | def prepare(cls, model, device='CUDA:0', **kwargs): 231 | """Build an engine from the given model. 232 | model -- An ONNX model as a deserialized protobuf, or a string or file- 233 | object containing a serialized protobuf. 
234 | """ 235 | super(TensorRTBackend, cls).prepare(model, device, **kwargs) 236 | return TensorRTBackendRep(model, device, **kwargs) 237 | @classmethod 238 | def run_model(cls, model, inputs, device='CUDA:0', **kwargs): 239 | """Build and run an engine from the given model. 240 | model -- An ONNX model as a deserialized protobuf, or a string or file- 241 | object containing a serialized protobuf. 242 | inputs -- Input tensor(s) as a Numpy array or list of Numpy arrays. 243 | """ 244 | return cls.prepare(model, device, **kwargs).run(inputs) 245 | @classmethod 246 | def run_node(cls, node, inputs, device='CUDA:0'): 247 | """Build and run an engine from the given node. 248 | node -- An ONNX node as a deserialized protobuf. 249 | Note: This function is intended for testing purposes only; 250 | use prepare() or run_model() for other purposes. 251 | """ 252 | super(TensorRTBackend, cls).run_node(node, inputs, device) 253 | # HACK TODO: This is somewhat dodgy. We first try with weights for all 254 | # inputs but the first, then we try again with no weights if 255 | # the first try fails. 256 | model = make_node_test_model(node, inputs, use_weights=True) 257 | try: results = TensorRTBackend.prepare(model, device).run(inputs[:1]) 258 | except RuntimeError: 259 | model = make_node_test_model(node, inputs, use_weights=False) 260 | results = TensorRTBackend.prepare(model, device).run(inputs) 261 | return results 262 | @classmethod 263 | def supports_device(cls, device_str): 264 | device = Device(device_str) 265 | return device.type == DeviceType.CUDA 266 | 267 | prepare = TensorRTBackend.prepare 268 | run_node = TensorRTBackend.run_node 269 | run_model = TensorRTBackend.run_model 270 | supports_device = TensorRTBackend.supports_device 271 | -------------------------------------------------------------------------------- /onnx_tensorrt/tensorrt_engine.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | import tensorrt as trt 4 | import pycuda.driver 5 | import pycuda.gpuarray 6 | import pycuda.autoinit 7 | import numpy as np 8 | from six import string_types 9 | 10 | class Binding(object): 11 | def __init__(self, engine, idx_or_name): 12 | if isinstance(idx_or_name, string_types): 13 | self.name = idx_or_name 14 | else: 15 | self.index = idx_or_name 16 | self.name = engine.get_tensor_name(self.index) 17 | if self.name is None: 18 | raise IndexError("Binding index out of range: %i" % self.index) 19 | self.is_input = engine.get_tensor_mode(self.name) == trt.TensorIOMode.INPUT 20 | 21 | 22 | dtype = engine.get_tensor_dtype(self.name) 23 | dtype_map = {trt.DataType.FLOAT: np.float32, 24 | trt.DataType.HALF: np.float16, 25 | trt.DataType.INT8: np.int8, 26 | trt.DataType.BOOL: np.bool_,} 27 | if hasattr(trt.DataType, 'INT32'): 28 | dtype_map[trt.DataType.INT32] = np.int32 29 | if hasattr(trt.DataType, 'INT64'): 30 | dtype_map[trt.DataType.INT64] = np.int64 31 | 32 | self.dtype = dtype_map[dtype] 33 | shape = engine.get_tensor_shape(self.name) 34 | 35 | self.shape = tuple(shape) 36 | self._host_buf = None 37 | self._device_buf = None 38 | @property 39 | def host_buffer(self): 40 | if self._host_buf is None: 41 | self._host_buf = pycuda.driver.pagelocked_empty(self.shape, self.dtype) 42 | return self._host_buf 43 | @property 44 | def device_buffer(self): 45 | if self._device_buf is None: 46 | self._device_buf = pycuda.gpuarray.empty(self.shape, self.dtype) 47 | return self._device_buf 48 | def get_async(self, stream): 49 | src = 
self.device_buffer 50 | dst = self.host_buffer 51 | src.get_async(stream, dst) 52 | return dst 53 | 54 | def squeeze_hw(x): 55 | if x.shape[-2:] == (1, 1): 56 | x = x.reshape(x.shape[:-2]) 57 | elif x.shape[-1] == 1: 58 | x = x.reshape(x.shape[:-1]) 59 | return x 60 | 61 | def check_input_validity(input_idx, input_array, input_binding): 62 | # Check shape 63 | trt_shape = tuple(input_binding.shape) 64 | onnx_shape = tuple(input_array.shape) 65 | 66 | if onnx_shape != trt_shape: 67 | if not (trt_shape == (1,) and onnx_shape == ()) : 68 | raise ValueError("Wrong shape for input %i. Expected %s, got %s." % 69 | (input_idx, trt_shape, onnx_shape)) 70 | 71 | # Check dtype 72 | if input_array.dtype != input_binding.dtype: 73 | #TRT does not support INT64, need to convert to INT32 74 | if input_array.dtype == np.int64 and input_binding.dtype == np.int32: 75 | casted_input_array = np.array(input_array, copy=True, dtype=np.int32) 76 | if np.equal(input_array, casted_input_array).all(): 77 | input_array = casted_input_array 78 | else: 79 | raise TypeError("Wrong dtype for input %i. Expected %s, got %s. Cannot safely cast." % 80 | (input_idx, input_binding.dtype, input_array.dtype)) 81 | else: 82 | raise TypeError("Wrong dtype for input %i. Expected %s, got %s." % 83 | (input_idx, input_binding.dtype, input_array.dtype)) 84 | return input_array 85 | 86 | 87 | class Engine(object): 88 | def __init__(self, trt_engine): 89 | self.engine = trt_engine 90 | 91 | bindings = [Binding(self.engine, i) 92 | for i in range(self.engine.num_io_tensors)] 93 | self.binding_addrs = [b.device_buffer.ptr for b in bindings] 94 | self.inputs = [b for b in bindings if b.is_input] 95 | self.outputs = [b for b in bindings if not b.is_input] 96 | 97 | for binding in self.inputs + self.outputs: 98 | _ = binding.device_buffer # Force buffer allocation 99 | for binding in self.outputs: 100 | _ = binding.host_buffer # Force buffer allocation 101 | self.context = self.engine.create_execution_context() 102 | self.stream = pycuda.driver.Stream() 103 | def __del__(self): 104 | if self.engine is not None: 105 | del self.engine 106 | 107 | def run(self, inputs): 108 | # len(inputs) > len(self.inputs) with Shape operator, input is never used 109 | # len(inputs) == len(self.inputs) for other operators 110 | 111 | if len(inputs) < len(self.inputs): 112 | raise ValueError("Not enough inputs. Expected %i, got %i." 
% 113 | (len(self.inputs), len(inputs))) 114 | if isinstance(inputs, dict): 115 | inputs = [inputs[b.name] for b in self.inputs] 116 | 117 | for i, (input_array, input_binding) in enumerate(zip(inputs, self.inputs)): 118 | input_array = check_input_validity(i, input_array, input_binding) 119 | input_binding_array = input_binding.device_buffer 120 | input_binding_array.set_async(input_array, self.stream) 121 | 122 | num_io = self.engine.num_io_tensors 123 | for i in range(num_io): 124 | tensor_name = self.engine.get_tensor_name(i) 125 | if i < len(inputs) and self.engine.is_shape_inference_io(tensor_name): 126 | self.context.set_tensor_address(tensor_name, inputs[i].ctypes.data) 127 | else: 128 | self.context.set_tensor_address(tensor_name, self.binding_addrs[i]) 129 | 130 | self.context.execute_async_v3(self.stream.handle) 131 | 132 | results = [output.get_async(self.stream) 133 | for output in self.outputs] 134 | self.stream.synchronize() 135 | return results 136 | 137 | def run_no_dma(self): 138 | self.context.execute_async_v3(self.stream.handle) 139 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | import os 4 | import sys 5 | import onnx_tensorrt 6 | from setuptools import setup, find_packages 7 | 8 | def no_publish(): 9 | blacklist = ['register'] 10 | for cmd in blacklist: 11 | if cmd in sys.argv: 12 | raise RuntimeError("Command \"{}\" blacklisted".format(cmd)) 13 | 14 | 15 | REQUIRED_PACKAGES = [ 16 | "pycuda", 17 | "numpy", 18 | "onnx" 19 | ] 20 | 21 | def main(): 22 | no_publish() 23 | setup( 24 | name="onnx_tensorrt", 25 | version=onnx_tensorrt.__version__, 26 | description="ONNX-TensorRT - TensorRT backend for running ONNX models", 27 | long_description=open("README.md", "r", encoding="utf-8").read(), 28 | url="https://github.com/onnx/onnx-tensorrt", 29 | author="NVIDIA", 30 | author_email="svc_tensorrt@nvidia.com", 31 | classifiers=[ 32 | 'Intended Audience :: Developers', 33 | 'Programming Language :: Python :: 3', 34 | ], 35 | install_requires=REQUIRED_PACKAGES, 36 | packages=find_packages(), 37 | zip_safe=True, 38 | ) 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /toposort.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include <unordered_map> 8 | #include <vector> 9 | 10 | #include <iostream> 11 | using std::cout; 12 | using std::cerr; 13 | using std::endl; 14 | 15 | namespace 16 | { 17 | 18 | enum NodeState 19 | { 20 | NODE_UNVISITED, 21 | NODE_ACTIVE, 22 | NODE_VISITED 23 | }; 24 | 25 | template <class Container> 26 | bool get_post_order(size_t node_idx, Container const& nodes, std::unordered_map<std::string, size_t> const& node_map, 27 | std::vector<NodeState>* node_states, std::vector<size_t>* order) 28 | { 29 | NodeState& node_state = node_states->at(node_idx); 30 | if (node_state == NODE_ACTIVE) 31 | { 32 | // Cycle detected! 33 | cerr << "ERROR: Graph contains a cycle" << endl; 34 | return false; 35 | } 36 | else if (node_state == NODE_VISITED) 37 | { 38 | return true; 39 | } 40 | else 41 | { 42 | node_state = NODE_ACTIVE; 43 | // TODO: This .Get().input() is highly specific to protobuf, should 44 | // generalise it somehow. 45 | for (auto const& input : nodes.Get(node_idx).input()) 46 | { 47 | if (!node_map.count(input)) 48 | { 49 | // Input node not found in graph! 
50 | // cerr << "ERROR: Input node not found in graph: " 51 | // << input << endl; 52 | // return false; 53 | continue; // Skip missing input edges 54 | } 55 | size_t input_node_idx = node_map.at(input); 56 | if (!get_post_order(input_node_idx, nodes, node_map, node_states, order)) 57 | { 58 | return false; 59 | } 60 | } 61 | node_state = NODE_VISITED; 62 | order->push_back(node_idx); 63 | } 64 | return true; 65 | } 66 | 67 | } // anonymous namespace 68 | 69 | template <class Container> 70 | bool toposort(Container const& nodes, std::vector<size_t>* order) 71 | { 72 | std::unordered_map<std::string, size_t> node_map; 73 | for (size_t i = 0; i < (size_t) nodes.size(); ++i) 74 | { 75 | // TODO: This .Get().input() is highly specific to protobuf, should 76 | // generalise it somehow. 77 | for (auto const& output : nodes.Get(i).output()) 78 | { 79 | // Empty output strings mean null outputs, do not register them. 80 | if (output.empty()) 81 | { 82 | continue; 83 | } 84 | if (!node_map.emplace(output, i).second) 85 | { 86 | // Output name appears more than once in graph! 87 | cerr << "ERROR: Output name is not unique: " << output << endl; 88 | return false; 89 | } 90 | } 91 | } 92 | order->reserve(nodes.size()); 93 | std::vector<NodeState> node_states(nodes.size(), NODE_UNVISITED); 94 | for (size_t i = 0; i < (size_t) nodes.size(); ++i) 95 | { 96 | if (!get_post_order(i, nodes, node_map, &node_states, order)) 97 | { 98 | return false; 99 | } 100 | } 101 | return true; 102 | } 103 | -------------------------------------------------------------------------------- /weightUtils.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "weightUtils.hpp" 6 | #include "bfloat16.hpp" 7 | #include "half.h" 8 | #include <cstring> // For std::memcpy 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace onnx2trt 18 | { 19 | 20 | char const* getDtypeName(int32_t onnxDtype) 21 | { 22 | switch (onnxDtype) 23 | { 24 | case ::ONNX_NAMESPACE::TensorProto::FLOAT: return "FLOAT"; 25 | case ::ONNX_NAMESPACE::TensorProto::UINT8: return "UINT8"; 26 | case ::ONNX_NAMESPACE::TensorProto::INT8: return "INT8"; 27 | case ::ONNX_NAMESPACE::TensorProto::UINT16: return "UINT16"; 28 | case ::ONNX_NAMESPACE::TensorProto::INT16: return "INT16"; 29 | case ::ONNX_NAMESPACE::TensorProto::INT32: return "INT32"; 30 | case ::ONNX_NAMESPACE::TensorProto::INT64: return "INT64"; 31 | case ::ONNX_NAMESPACE::TensorProto::STRING: return "STRING"; 32 | case ::ONNX_NAMESPACE::TensorProto::BOOL: return "BOOL"; 33 | case ::ONNX_NAMESPACE::TensorProto::FLOAT16: return "FLOAT16"; 34 | case ::ONNX_NAMESPACE::TensorProto::BFLOAT16: return "BFLOAT16"; 35 | case ::ONNX_NAMESPACE::TensorProto::DOUBLE: return "DOUBLE"; 36 | case ::ONNX_NAMESPACE::TensorProto::UINT32: return "UINT32"; 37 | case ::ONNX_NAMESPACE::TensorProto::UINT64: return "UINT64"; 38 | case ::ONNX_NAMESPACE::TensorProto::COMPLEX64: return "COMPLEX64"; 39 | case ::ONNX_NAMESPACE::TensorProto::COMPLEX128: return "COMPLEX128"; 40 | case ::ONNX_NAMESPACE::TensorProto::FLOAT4E2M1: return "FLOAT4E2M1"; 41 | default: return ""; 42 | } 43 | } 44 | 45 | int32_t getDtypeSizeBits(int32_t onnxDtype) 46 | { 47 | switch (onnxDtype) 48 | { 49 | case ::ONNX_NAMESPACE::TensorProto::FLOAT16: return 16; 50 | case ::ONNX_NAMESPACE::TensorProto::BFLOAT16: return 16; 51 | case ::ONNX_NAMESPACE::TensorProto::FLOAT: return 32; 52 | case ::ONNX_NAMESPACE::TensorProto::DOUBLE: return 64; 53 | case 
::ONNX_NAMESPACE::TensorProto::COMPLEX64: return 64; 54 | case ::ONNX_NAMESPACE::TensorProto::COMPLEX128: return 128; 55 | case ::ONNX_NAMESPACE::TensorProto::UINT8: return 8; 56 | case ::ONNX_NAMESPACE::TensorProto::INT8: return 8; 57 | case ::ONNX_NAMESPACE::TensorProto::UINT16: return 16; 58 | case ::ONNX_NAMESPACE::TensorProto::INT16: return 16; 59 | case ::ONNX_NAMESPACE::TensorProto::UINT32: return 32; 60 | // Booleans are stored in int32 tensors in ONNX 61 | case ::ONNX_NAMESPACE::TensorProto::BOOL: return 8; 62 | case ::ONNX_NAMESPACE::TensorProto::INT32: return 32; 63 | case ::ONNX_NAMESPACE::TensorProto::UINT64: return 64; 64 | case ::ONNX_NAMESPACE::TensorProto::INT64: return 64; 65 | case ::ONNX_NAMESPACE::TensorProto::FLOAT8E4M3FN: return 8; 66 | case ::ONNX_NAMESPACE::TensorProto::INT4: return 4; 67 | case ::ONNX_NAMESPACE::TensorProto::FLOAT4E2M1: return 4; 68 | default: return -1; 69 | } 70 | } 71 | 72 | size_t getTensorOrWeightsSizeBytes(int64_t count, int32_t onnxDtype) 73 | { 74 | 75 | int32_t dTypeSize = getDtypeSizeBits(onnxDtype); 76 | 77 | if (dTypeSize == -1 78 | || static_cast<size_t>(count) > std::numeric_limits<size_t>::max() / static_cast<size_t>(dTypeSize)) 79 | { 80 | throw std::runtime_error("Size of weights exceeds maximum!"); 81 | } 82 | 83 | int64_t sizeInBits = count * dTypeSize; 84 | if (sizeInBits % 8 != 0) 85 | { 86 | // This padding is specific to INT4/FP4, since these are currently the only sub-byte data types 87 | // we support. Different data types may need different padding. 88 | assert( 89 | onnxDtype == ::ONNX_NAMESPACE::TensorProto::INT4 || onnxDtype == ::ONNX_NAMESPACE::TensorProto::FLOAT4E2M1); 90 | sizeInBits += 4; 91 | } 92 | assert(sizeInBits % 8 == 0); 93 | return static_cast<size_t>(sizeInBits / 8); 94 | } 95 | 96 | int64_t volume(nvinfer1::Dims const& dims) 97 | { 98 | std::for_each( 99 | dims.d, dims.d + dims.nbDims, [](int32_t d) { assert(d >= 0 && "volume makes no sense for dynamic shapes"); }); 100 | return std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies<int64_t>{}); 101 | } 102 | 103 | std::string normalizePath(std::string const& path) 104 | { 105 | std::vector<std::string> normPath; 106 | auto addToPath = [&normPath](std::string s) { 107 | // Ignore all extra slashes, and current directory paths 108 | if (s == "/" || s == "./") 109 | { 110 | return; 111 | } 112 | // Push back to normPath under the following circumstances 113 | // 1. Current string is not "../" or 114 | // 2. "../" if it's the first string or 115 | // 3. "../" is the previous string in normPath 116 | if (s != "../" || normPath.empty() || (!normPath.empty() && normPath.back() == "../")) 117 | { 118 | normPath.push_back(s); 119 | } 120 | // Remove previous entry since "../" was encountered. 121 | else 122 | { 123 | normPath.pop_back(); 124 | } 125 | }; 126 | 127 | size_t i = 0; 128 | size_t n = path.size(); 129 | std::string sep = "/"; 130 | 131 | // Loop through path, split on all path separator tokens, and append to normPath if applicable. 
132 | while (i < n) 133 | { 134 | auto slashPos = path.find(sep, i); 135 | if (slashPos == std::string::npos) 136 | { 137 | addToPath(path.substr(i, n - i)); 138 | break; 139 | } 140 | else 141 | { 142 | addToPath(path.substr(i, slashPos - i + 1)); 143 | i = slashPos + 1; 144 | } 145 | } 146 | 147 | // Build final output string 148 | std::string out; 149 | for (auto s : normPath) 150 | { 151 | out += s; 152 | } 153 | return out; 154 | } 155 | 156 | std::string const& generateUniqueName( 157 | std::set<std::string>& namesSet, int64_t& suffixCounter, std::string const& basename) 158 | { 159 | std::string candidate = basename; 160 | 161 | while (namesSet.find(candidate) != namesSet.end()) 162 | { 163 | candidate = basename + "_" + std::to_string(suffixCounter); 164 | ++suffixCounter; 165 | } 166 | 167 | namesSet.insert(candidate); 168 | // Return reference to newly inserted string to avoid any c_str()'s going out of scope 169 | return *namesSet.find(candidate); 170 | } 171 | 172 | } // namespace onnx2trt 173 | -------------------------------------------------------------------------------- /weightUtils.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | #include "ShapedWeights.hpp" 7 | #include "bfloat16.hpp" 8 | #include "half.h" 9 | #include 10 | #include 11 | #include 12 | 13 | // Subset of helper functions that deal exclusively with weights to be shared across IParser and IParserRefitter classes. 14 | // Define weightLog Macros here to ensure that an ImporterCtx class is not needed to log. 15 | 16 | namespace onnx2trt 17 | { 18 | 19 | // Return the name of an ONNX data enum. 20 | char const* getDtypeName(int32_t onnxDtype); 21 | 22 | // Return the size in bits of an ONNX data type. 23 | int32_t getDtypeSizeBits(int32_t onnxDtype); 24 | 25 | // Return the size in bytes of a tensor/weights object, handling sub-byte padding. 26 | size_t getTensorOrWeightsSizeBytes(int64_t count, int32_t onnxDtype); 27 | 28 | // Find the corresponding ONNX data type of a built-in data type. 
29 | template <typename T> 30 | ShapedWeights::DataType getShapedWeightsDataType() 31 | { 32 | static std::unordered_map<std::type_index, ShapedWeights::DataType> const tMap({ 33 | {std::type_index(typeid(bool)), ::ONNX_NAMESPACE::TensorProto::BOOL}, 34 | {std::type_index(typeid(int8_t)), ::ONNX_NAMESPACE::TensorProto::INT8}, 35 | {std::type_index(typeid(uint8_t)), ::ONNX_NAMESPACE::TensorProto::UINT8}, 36 | {std::type_index(typeid(int16_t)), ::ONNX_NAMESPACE::TensorProto::INT16}, 37 | {std::type_index(typeid(uint16_t)), ::ONNX_NAMESPACE::TensorProto::UINT16}, 38 | {std::type_index(typeid(int32_t)), ::ONNX_NAMESPACE::TensorProto::INT32}, 39 | {std::type_index(typeid(uint32_t)), ::ONNX_NAMESPACE::TensorProto::UINT32}, 40 | {std::type_index(typeid(int64_t)), ::ONNX_NAMESPACE::TensorProto::INT64}, 41 | {std::type_index(typeid(uint64_t)), ::ONNX_NAMESPACE::TensorProto::UINT64}, 42 | {std::type_index(typeid(float)), ::ONNX_NAMESPACE::TensorProto::FLOAT}, 43 | {std::type_index(typeid(double)), ::ONNX_NAMESPACE::TensorProto::DOUBLE}, 44 | {std::type_index(typeid(half_float::half)), ::ONNX_NAMESPACE::TensorProto::FLOAT16}, 45 | {std::type_index(typeid(BFloat16)), ::ONNX_NAMESPACE::TensorProto::BFLOAT16}, 46 | // TRT-22989: Add fp8 and int4 support 47 | }); 48 | 49 | if (tMap.find(std::type_index(typeid(T))) != tMap.end()) 50 | { 51 | return tMap.at(std::type_index(typeid(T))); 52 | } 53 | return ::ONNX_NAMESPACE::TensorProto::UNDEFINED; 54 | } 55 | 56 | // Return the volume of a Dims object 57 | int64_t volume(nvinfer1::Dims const& dims); 58 | 59 | // Normalize the slashes in a string representing a filepath. 60 | std::string normalizePath(std::string const& path); 61 | 62 | // Generate a unique name for a given weight or tensor name (passed as the |basename|) 63 | std::string const& generateUniqueName( 64 | std::set<std::string>& namesSet, int64_t& suffixCounter, std::string const& basename); 65 | 66 | } // namespace onnx2trt 67 | --------------------------------------------------------------------------------
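
A minimal usage sketch for the Python backend listed above (onnx_tensorrt/backend.py). This only uses the entry points exported by the package (prepare and run); the model path "model.onnx", the device string, and the (1, 3, 224, 224) float32 input shape are illustrative placeholders and are not taken from this repository.

import numpy as np
import onnx
import onnx_tensorrt.backend as backend

# Load a serialized ONNX model and build a TensorRT engine for it.
# For models with dynamic input shapes, the engine build is deferred to run().
model = onnx.load("model.onnx")  # placeholder path
engine = backend.prepare(model, device="CUDA:0")

# Inputs are NumPy arrays given in graph-input order; the result is a named
# tuple whose fields are the graph's output names.
data = np.random.random_sample((1, 3, 224, 224)).astype(np.float32)  # placeholder shape
outputs = engine.run(data)
print(outputs[0].shape)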