├── .gitignore
├── CMakeLists.txt
├── README.md
├── check_models.py
├── example1.cpp
├── example2.cpp
├── example3-load.cpp
├── example3-save.cpp
├── example4.cpp
├── example5.cpp
├── example6.cpp
├── example7.cpp
├── example7prof.cpp
├── gen_models.py
├── model1.onnx
└── model2.onnx

/.gitignore:
--------------------------------------------------------------------------------
 1 | # CMake
 2 | build/
 3 | # CLion + QtCreator + KDevelop
 4 | build*/
 5 | cmake-build-*/
 6 | *.user
 7 | .idea/
 8 | .kdev4/
 9 | *.kdev4
10 | *.iml
11 | .vscode/
12 | 
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
 1 | cmake_minimum_required(VERSION 3.11)
 2 | project(trt-cpp-min)
 3 | set(CMAKE_CXX_STANDARD 14)
 4 | file(COPY model1.onnx DESTINATION .)
 5 | file(COPY model2.onnx DESTINATION .)
 6 | 
 7 | # CUDA
 8 | find_package(CUDA REQUIRED)
 9 | include_directories(${CUDA_INCLUDE_DIRS})
10 | message("CUDA_TOOLKIT_ROOT_DIR = ${CUDA_TOOLKIT_ROOT_DIR}")
11 | message("CUDA_INCLUDE_DIRS = ${CUDA_INCLUDE_DIRS}")
12 | message("CUDA_LIBRARIES = ${CUDA_LIBRARIES}")
13 | 
14 | 
15 | link_libraries(nvinfer nvonnxparser ${CUDA_LIBRARIES})
16 | 
17 | add_executable(example1 example1.cpp)
18 | add_executable(example2 example2.cpp)
19 | add_executable(example3-save example3-save.cpp)
20 | add_executable(example3-load example3-load.cpp)
21 | add_executable(example4 example4.cpp)
22 | add_executable(example5 example5.cpp)
23 | add_executable(example6 example6.cpp)
24 | add_executable(example7 example7.cpp)
25 | add_executable(example7prof example7prof.cpp)
26 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | TensorRT 7 C++ (almost) minimal examples
 2 | ====
 3 | 
 4 | By Oleksiy Grechnyev, IT-JIM, Mar-Apr 2020.
 5 | 
 6 | ### Introduction
 7 | 
 8 | `example1` is a minimal C++ TensorRT 7 example, much simpler than Nvidia examples. I create a trivial neural network
 9 | of a single Linear layer (3D -> 2D output) in PyTorch, convert it to ONNX, and run it in C++ TensorRT 7. Requires CUDA and
10 | TensorRT 7 (`libnvinfer`, `libnvonnxparser`) installed on your system. Other examples are not much harder.
11 | 
12 | Note : These examples are for TensorRT 7+ only (see discussion below on TensorRT 6). A lot has changed in this version, especially compared to TensorRT 5 !
13 | ONNX with dynamic batch size is now difficult.
14 | You must set the optimization profile, the min/max/opt input sizes, and finally the actual input size (in the context).
15 | Here I use `model1.onnx` with fixed batch size in `example1`, and `model2.onnx` with dynamic batch size in `example2`.
16 | 
17 | `model1`, `model2` weights and biases:
18 | w=[[1., 2., 3.], [4., 5., 6.]]
19 | b=[-1., -2.]
20 | 
21 | For example, inferring for x=[0.5, -0.5, 1.0] should give y=[1.5, 3.5].
22 | 
23 | ### Experiments with TensorRT 6:
24 | 
25 | I tried to run this with TensorRT 6 in docker and discovered the following issues:
26 | 1. The parser does not like ONNX generated with PyTorch > 1.2, so I re-generated the models with PyTorch 1.2
27 | 2. The code does not run without an extra line `config->setMaxWorkspaceSize(...);`
28 | 3. At this point, examples 1, 4, 5 work fine, but not 2, 3 (parsing ONNX with dynamic batch size)
29 | 4. However, now `example1` can infer `model2.onnx` (only with batch_size = 1), which did not work on TensorRT 7
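For reference, the dynamic batch size workflow mentioned in the Introduction boils down to three steps. Here is a condensed sketch taken from `example2.cpp` further below (the builder, config, network and execution context are created as usual; error handling omitted):

```cpp
// 1. Create an optimization profile with min/opt/max input shapes (here all equal to the actual batch)
IOptimizationProfile *profile = builder->createOptimizationProfile();
profile->setDimensions("input", OptProfileSelector::kMIN, Dims2{batchSize, 3});
profile->setDimensions("input", OptProfileSelector::kMAX, Dims2{batchSize, 3});
profile->setDimensions("input", OptProfileSelector::kOPT, Dims2{batchSize, 3});
// 2. Register the profile in the builder config before building the engine
config->addOptimizationProfile(profile);
// 3. After creating the execution context, set the actual input size
context->setBindingDimensions(0, Dims2(batchSize, 3));
```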
30 | 
31 | My investigation showed that TensorRT 6 internally has all the dynamic dimension infrastructure
32 | (dim=-1, optimization profiles), but the ONNX parser cannot parse an ONNX network with a dynamic dimension!
33 | It just throws away the batch dimension (it is removed, not set to 1). As a result, you can infer such a network
34 | as in `example1`, and only with batch_size = 1.
35 | 
36 | Update: This was with the "explicit batch" (`kEXPLICIT_BATCH`) option in the model definition. What does this mean?
37 | Apparently, this option means that the network has an explicit batch dimension (which can be 1 or -1 or something else).
38 | 
39 | * TensorRT 7 : Without `kEXPLICIT_BATCH`, ONNX cannot be parsed
40 | * TensorRT 6 : With `kEXPLICIT_BATCH`, the ONNX parser does not support dynamic dimensions, and even without them it tends to misbehave
41 | for many networks. However, with TensorRT 6 you can parse ONNX without `kEXPLICIT_BATCH`. This works fine in TensorRT 6, but not in 7!
42 | 
43 | ### Examples
44 | 
45 | * `gen_models.py` A Python 3 script to create `model1.onnx` and `model2.onnx`. Requires `torch`
46 | * `check_models.py` A Python 3 script to check and test `model1.onnx` and `model2.onnx`. Requires `numpy`, `onnx`, `onnxruntime`
47 | * `example1` A minimal C++ example, runs `model1.onnx` (with a fixed batch size of 1)
48 | * `example2` Runs `model2.onnx` (with dynamic batch size)
49 | * `example3` Serialization: like `example2`, but split into save and load parts
50 | * `example4` Creates a simple network in-place (no ONNX parsing)
51 | * `example5` Another in-place network with a FullyConnected layer; I also tried INT8 quantization (but it seems to fail for this layer). FP16 works fine though.
52 | * `example6` Convolution layer example
53 | * `example7` Finally succeeded with INT8 using a conv->relu->conv->relu network
54 | 
--------------------------------------------------------------------------------
/check_models.py:
--------------------------------------------------------------------------------
 1 | # By Oleksiy Grechnyev, IT-JIM on 3/30/20.
 2 | # Load ONNX model and check
 3 | # Infer it with onnxruntime
 4 | 
 5 | import numpy as np
 6 | import onnx
 7 | import onnxruntime
 8 | 
 9 | #import caffe2.python.onnx.backend as backend
10 | 
11 | def check_onnx(file_name):
12 |     """Check the ONNX model"""
13 |     print('\nChecking ONNX ' + file_name + ' ...')
14 |     onnx_model = onnx.load(file_name)
15 |     # print('onnx_model =', onnx_model)
16 |     onnx.checker.check_model(onnx_model)
17 |     print(onnx.helper.printable_graph(onnx_model.graph))
18 | 
19 | 
20 | def infer_onnx(file_name):
21 |     """Infer with onnxruntime"""
22 |     print('\nInferring ONNX with onnxruntime : ' + file_name + ' ...')
23 |     sess = onnxruntime.InferenceSession(file_name)
24 |     print('sess =', sess)
25 |     input, output = sess.get_inputs()[0], sess.get_outputs()[0]
26 |     print('input =', input)
27 |     print('output =', output)
28 |     x = np.array([[0.5, -0.5, 1.0]], dtype='float32')
29 |     y = sess.run([output.name], {input.name: x})
30 |     print('y =', y) # [[1.5, 3.5]]
31 | 
32 | 
33 | def infer_caffe2():
34 |     print('\nInferring ONNX with caffe2 ...')
35 |     onnx_model = onnx.load('model1.onnx')
36 |     rep = backend.prepare(onnx_model, device='CUDA:0') # Does not work !
37 | print('rep =', rep) 38 | 39 | if __name__ == '__main__': 40 | check_onnx('model1.onnx') 41 | check_onnx('model2.onnx') 42 | infer_onnx('model1.onnx') 43 | infer_onnx('model2.onnx') 44 | # infer_caffe2() 45 | -------------------------------------------------------------------------------- /example1.cpp: -------------------------------------------------------------------------------- 1 | // By Oleksiy Grechnyev, IT-JIM 2 | // Example 1: An (almost) minimal TensorRT C++ inference example 3 | // This one uses model1.onnx with fixed batch size (1) 4 | // Batch size at inference must be the same ! 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | 15 | #include 16 | 17 | //====================================================================================================================== 18 | 19 | class Logger : public nvinfer1::ILogger { 20 | public: 21 | void log(Severity severity, const char *msg) override { 22 | using namespace std; 23 | string s; 24 | switch (severity) { 25 | case Severity::kINTERNAL_ERROR: 26 | s = "INTERNAL_ERROR"; 27 | break; 28 | case Severity::kERROR: 29 | s = "ERROR"; 30 | break; 31 | case Severity::kWARNING: 32 | s = "WARNING"; 33 | break; 34 | case Severity::kINFO: 35 | s = "INFO"; 36 | break; 37 | case Severity::kVERBOSE: 38 | s = "VERBOSE"; 39 | break; 40 | } 41 | cerr << s << ": " << msg << endl; 42 | } 43 | }; 44 | //====================================================================================================================== 45 | 46 | /// Using unique_ptr with Destroy is optional, but beats calling destroy() for everything 47 | /// Borrowed from the NVidia tutorial, nice C++ skills ! 48 | template 49 | struct Destroy { 50 | void operator()(T *t) const { 51 | t->destroy(); 52 | } 53 | }; 54 | 55 | //====================================================================================================================== 56 | 57 | /// Parse onnx file and create a TRT engine 58 | nvinfer1::ICudaEngine *createCudaEngine(const std::string &onnxFileName, nvinfer1::ILogger &logger) { 59 | using namespace std; 60 | using namespace nvinfer1; 61 | 62 | unique_ptr> builder{createInferBuilder(logger)}; 63 | unique_ptr> network{ 64 | builder->createNetworkV2(1U << (unsigned) NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)}; 65 | unique_ptr> parser{ 66 | nvonnxparser::createParser(*network, logger)}; 67 | 68 | if (!parser->parseFromFile(onnxFileName.c_str(), static_cast(ILogger::Severity::kINFO))) 69 | throw runtime_error("ERROR: could not parse ONNX model " + onnxFileName + " !"); 70 | 71 | // Modern version with config 72 | unique_ptr> config(builder->createBuilderConfig()); 73 | // This is needed for TensorRT 6, not needed by 7 ! 
74 | config->setMaxWorkspaceSize(64*1024*1024); 75 | return builder->buildEngineWithConfig(*network, *config); 76 | } 77 | 78 | //====================================================================================================================== 79 | /// Run a single inference 80 | void launchInference(nvinfer1::IExecutionContext *context, cudaStream_t stream, std::vector const &inputTensor, 81 | std::vector &outputTensor, void **bindings, int batchSize) { 82 | 83 | int inputId = 0, outputId = 1; // Here I assume input=0, output=1 for the current network 84 | 85 | // Infer synchronously as an alternative, no stream needed 86 | // cudaMemcpy(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice); 87 | // bool res = context->executeV2(bindings); 88 | // cudaMemcpy(outputTensor.data(), bindings[outputId], outputTensor.size() * sizeof(float), cudaMemcpyDeviceToHost); 89 | 90 | // Infer asynchronously, in a proper cuda way ! 91 | cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, 92 | stream); 93 | context->enqueueV2(bindings, stream, nullptr); 94 | cudaMemcpyAsync(outputTensor.data(), bindings[outputId], outputTensor.size() * sizeof(float), 95 | cudaMemcpyDeviceToHost, stream); 96 | } 97 | 98 | //====================================================================================================================== 99 | int main() { 100 | using namespace std; 101 | using namespace nvinfer1; 102 | 103 | // Parse model, create engine 104 | Logger logger; 105 | logger.log(ILogger::Severity::kINFO, "C++ TensorRT (almost) minimal example1 !!! "); 106 | logger.log(ILogger::Severity::kINFO, "Creating engine ..."); 107 | unique_ptr> engine(createCudaEngine("model1.onnx", logger)); 108 | 109 | if (!engine) 110 | throw runtime_error("Engine creation failed !"); 111 | 112 | // Optional : Print all bindings : name + dims + dtype 113 | cout << "=============\nBindings :\n"; 114 | int n = engine->getNbBindings(); 115 | for (int i = 0; i < n; ++i) { 116 | Dims d = engine->getBindingDimensions(i); 117 | cout << i << " : " << engine->getBindingName(i) << " : dims="; 118 | for (int j = 0; j < d.nbDims; ++j) { 119 | cout << d.d[j]; 120 | if (j < d.nbDims - 1) 121 | cout << "x"; 122 | } 123 | cout << " , dtype=" << (int) engine->getBindingDataType(i) << " "; 124 | cout << (engine->bindingIsInput(i) ? "IN" : "OUT") << endl; 125 | } 126 | cout << "=============\n\n"; 127 | 128 | // Create context 129 | logger.log(ILogger::Severity::kINFO, "Creating context ..."); 130 | unique_ptr> context(engine->createExecutionContext()); 131 | 132 | // Create data structures for the inference 133 | cudaStream_t stream; 134 | cudaStreamCreate(&stream); 135 | vector inputTensor{0.5, -0.5, 1.0}; 136 | vector outputTensor(2, -4.9); 137 | void *bindings[2]{0}; 138 | int batchSize = 1; 139 | // Alloc cuda memory for IO tensors 140 | for (int i = 0; i < engine->getNbBindings(); ++i) { 141 | Dims dims{engine->getBindingDimensions(i)}; 142 | size_t size = accumulate(dims.d, dims.d + dims.nbDims, batchSize, multiplies()); 143 | // Create CUDA buffer for Tensor. 144 | cudaMalloc(&bindings[i], size * sizeof(float)); 145 | } 146 | 147 | // Run the inference ! 148 | cout << "Running the inference !" 
<< endl; 149 | launchInference(context.get(), stream, inputTensor, outputTensor, bindings, batchSize); 150 | cudaStreamSynchronize(stream); 151 | // Must be [1.5, 3.5] 152 | cout << "y = [" << outputTensor[0] << ", " << outputTensor[1] << "]" << endl; 153 | 154 | cudaStreamDestroy(stream); 155 | cudaFree(bindings[0]); 156 | cudaFree(bindings[1]); 157 | return 0; 158 | } 159 | -------------------------------------------------------------------------------- /example2.cpp: -------------------------------------------------------------------------------- 1 | // By Oleksiy Grechnyev, IT-JIM 2 | // Example 2 : Batch inference for model2.onnx with dynamic batch size 3 | // I use here batch of 2 4 | // This is TensorRT 7.0 API, things were easier in older TensorRT ! 5 | // Also contains rudimentary "print network" example 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | #include 17 | 18 | //====================================================================================================================== 19 | 20 | class Logger : public nvinfer1::ILogger { 21 | public: 22 | void log(Severity severity, const char *msg) override { 23 | using namespace std; 24 | string s; 25 | switch (severity) { 26 | case Severity::kINTERNAL_ERROR: 27 | s = "INTERNAL_ERROR"; 28 | break; 29 | case Severity::kERROR: 30 | s = "ERROR"; 31 | break; 32 | case Severity::kWARNING: 33 | s = "WARNING"; 34 | break; 35 | case Severity::kINFO: 36 | s = "INFO"; 37 | break; 38 | case Severity::kVERBOSE: 39 | s = "VERBOSE"; 40 | break; 41 | } 42 | cerr << s << ": " << msg << endl; 43 | } 44 | }; 45 | //====================================================================================================================== 46 | 47 | /// Using unique_ptr with Destroy is optional, but beats calling destroy() for everything 48 | /// Borrowed from the NVidia tutorial, nice C++ skills ! 
49 | template 50 | struct Destroy { 51 | void operator()(T *t) const { 52 | t->destroy(); 53 | } 54 | }; 55 | 56 | //====================================================================================================================== 57 | /// Optional : Print dimensions as string 58 | std::string printDim(const nvinfer1::Dims & d) { 59 | using namespace std; 60 | ostringstream oss; 61 | for (int j = 0; j < d.nbDims; ++j) { 62 | oss << d.d[j]; 63 | if (j < d.nbDims - 1) 64 | oss << "x"; 65 | } 66 | return oss.str(); 67 | } 68 | //====================================================================================================================== 69 | /// Optional : Print layers of the network 70 | void printNetwork(const nvinfer1::INetworkDefinition &net) { 71 | using namespace std; 72 | using namespace nvinfer1; 73 | cout << "\n=============\nNetwork info :" << endl; 74 | 75 | cout << "\nInputs : " << endl; 76 | for (int i = 0; i < net.getNbInputs(); ++i) { 77 | ITensor * inp = net.getInput(i); 78 | cout << "input" << i << " , dtype=" << (int)inp->getType() << " , dims=" << printDim(inp->getDimensions()) << endl; 79 | } 80 | 81 | cout << "\nLayers : " << endl; 82 | cout << "Number of layers : " << net.getNbLayers() << endl; 83 | for (int i = 0; i < net.getNbLayers(); ++i) { 84 | ILayer *l = net.getLayer(i); 85 | cout << "layer" << i << " , name=" << l->getName() << " , type=" << (int)l->getType() << " , IN "; 86 | for (int j = 0; j < l->getNbInputs(); ++j) 87 | cout << printDim(l->getInput(j)->getDimensions()) << " "; 88 | cout << ", OUT "; 89 | for (int j = 0; j < l->getNbOutputs(); ++j) 90 | cout << printDim(l->getOutput(j)->getDimensions()) << " "; 91 | cout << endl; 92 | } 93 | 94 | cout << "\nOutputs : " << endl; 95 | for (int i = 0; i < net.getNbOutputs(); ++i) { 96 | ITensor * outp = net.getOutput(i); 97 | cout << "input" << i << " , dtype=" << (int)outp->getType() << " , dims=" << printDim(outp->getDimensions()) << endl; 98 | } 99 | 100 | cout << "=============\n" << endl; 101 | } 102 | //====================================================================================================================== 103 | 104 | /// Parse onnx file and create a TRT engine 105 | nvinfer1::ICudaEngine *createCudaEngine(const std::string &onnxFileName, nvinfer1::ILogger &logger, int batchSize) { 106 | using namespace std; 107 | using namespace nvinfer1; 108 | 109 | unique_ptr> builder{createInferBuilder(logger)}; 110 | unique_ptr> network{ 111 | builder->createNetworkV2(1U << (unsigned) NetworkDefinitionCreationFlag::kEXPLICIT_BATCH) 112 | }; 113 | unique_ptr> parser{ 114 | nvonnxparser::createParser(*network, logger)}; 115 | 116 | if (!parser->parseFromFile(onnxFileName.c_str(), static_cast(ILogger::Severity::kINFO))) 117 | throw runtime_error("ERROR: could not parse ONNX model " + onnxFileName + " !"); 118 | 119 | // Optional : print network info 120 | printNetwork(*network); 121 | 122 | // Create Optimization profile and set the batch size 123 | IOptimizationProfile *profile = builder->createOptimizationProfile(); 124 | profile->setDimensions("input", OptProfileSelector::kMIN, Dims2{batchSize, 3}); 125 | profile->setDimensions("input", OptProfileSelector::kMAX, Dims2{batchSize, 3}); 126 | profile->setDimensions("input", OptProfileSelector::kOPT, Dims2{batchSize, 3}); 127 | 128 | // Build engine 129 | unique_ptr> config(builder->createBuilderConfig()); 130 | // This is needed for TensorRT 6, not needed by 7 ! 
131 | config->setMaxWorkspaceSize(64*1024*1024); 132 | config->addOptimizationProfile(profile); 133 | return builder->buildEngineWithConfig(*network, *config); 134 | } 135 | 136 | //====================================================================================================================== 137 | /// Run a single inference 138 | void launchInference(nvinfer1::IExecutionContext *context, cudaStream_t stream, std::vector const &inputTensor, 139 | std::vector &outputTensor, void **bindings, int batchSize) { 140 | 141 | int inputId = 0, outputId = 1; // Here I assume input=0, output=1 for the current network 142 | 143 | // Infer asynchronously, in a proper cuda way ! 144 | using namespace std; 145 | cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, 146 | stream); 147 | context->enqueueV2(bindings, stream, nullptr); 148 | cudaMemcpyAsync(outputTensor.data(), bindings[outputId], outputTensor.size() * sizeof(float), 149 | cudaMemcpyDeviceToHost, stream); 150 | } 151 | 152 | //====================================================================================================================== 153 | int main() { 154 | using namespace std; 155 | using namespace nvinfer1; 156 | 157 | // Parse model, create engine 158 | Logger logger; 159 | logger.log(ILogger::Severity::kINFO, "C++ TensorRT example2 !!! "); 160 | logger.log(ILogger::Severity::kINFO, "Creating engine ..."); 161 | int batchSize = 2; 162 | unique_ptr> engine(createCudaEngine("model2.onnx", logger, batchSize)); 163 | 164 | if (!engine) 165 | throw runtime_error("Engine creation failed !"); 166 | 167 | // Optional : Print all bindings : name + dims + dtype 168 | cout << "=============\nBindings :\n"; 169 | int n = engine->getNbBindings(); 170 | for (int i = 0; i < n; ++i) { 171 | Dims d = engine->getBindingDimensions(i); 172 | cout << i << " : " << engine->getBindingName(i) << " : dims=" << printDim(d); 173 | cout << " , dtype=" << (int) engine->getBindingDataType(i) << " "; 174 | cout << (engine->bindingIsInput(i) ? "IN" : "OUT") << endl; 175 | } 176 | cout << "=============\n\n"; 177 | 178 | // Create context 179 | logger.log(ILogger::Severity::kINFO, "Creating context ..."); 180 | unique_ptr> context(engine->createExecutionContext()); 181 | // Very important, you must set batch size here, otherwise you get zero output ! 182 | context->setBindingDimensions(0, Dims2(batchSize, 3)); 183 | 184 | // Create data structures for the inference 185 | cudaStream_t stream; 186 | cudaStreamCreate(&stream); 187 | vector inputTensor{0.5, -0.5, 1.0, 0.0, 0.0, 0.0}; 188 | vector outputTensor(2 * batchSize, -4.9); 189 | void *bindings[2]{0}; 190 | // Alloc cuda memory for IO tensors 191 | size_t sizes[] = {inputTensor.size(), outputTensor.size()}; 192 | for (int i = 0; i < engine->getNbBindings(); ++i) { 193 | // Create CUDA buffer for Tensor. 194 | cudaMalloc(&bindings[i], sizes[i] * sizeof(float)); 195 | } 196 | 197 | // Run the inference ! 198 | cout << "Running the inference !" 
<< endl; 199 | launchInference(context.get(), stream, inputTensor, outputTensor, bindings, batchSize); 200 | cudaStreamSynchronize(stream); 201 | // Must be [ [1.5, 3.5], [-1,-2] ] 202 | cout << "y = ["; 203 | for (int i = 0; i < batchSize; ++i) { 204 | cout << " [" << outputTensor.at(2 * i) << ", " << outputTensor.at(2 * i + 1) << "]"; 205 | if (i < batchSize - 1) 206 | cout << ", "; 207 | } 208 | cout << " ]" << endl; 209 | 210 | cudaStreamDestroy(stream); 211 | cudaFree(bindings[0]); 212 | cudaFree(bindings[1]); 213 | return 0; 214 | } 215 | -------------------------------------------------------------------------------- /example3-load.cpp: -------------------------------------------------------------------------------- 1 | // By Oleksiy Grechnyev, IT-JIM 2 | // Example 3-save : Like example 2, divided into 'save' and load 'parts' 3 | // I use here batch of 2 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | //====================================================================================================================== 17 | 18 | class Logger : public nvinfer1::ILogger { 19 | public: 20 | void log(Severity severity, const char *msg) override { 21 | using namespace std; 22 | string s; 23 | switch (severity) { 24 | case Severity::kINTERNAL_ERROR: 25 | s = "INTERNAL_ERROR"; 26 | break; 27 | case Severity::kERROR: 28 | s = "ERROR"; 29 | break; 30 | case Severity::kWARNING: 31 | s = "WARNING"; 32 | break; 33 | case Severity::kINFO: 34 | s = "INFO"; 35 | break; 36 | case Severity::kVERBOSE: 37 | s = "VERBOSE"; 38 | break; 39 | } 40 | cerr << s << ": " << msg << endl; 41 | } 42 | }; 43 | //====================================================================================================================== 44 | 45 | /// Using unique_ptr with Destroy is optional, but beats calling destroy() for everything 46 | /// Borrowed from the NVidia tutorial, nice C++ skills ! 47 | template 48 | struct Destroy { 49 | void operator()(T *t) const { 50 | t->destroy(); 51 | } 52 | }; 53 | 54 | //====================================================================================================================== 55 | /// Run a single inference 56 | void launchInference(nvinfer1::IExecutionContext *context, cudaStream_t stream, std::vector const &inputTensor, 57 | std::vector &outputTensor, void **bindings, int batchSize) { 58 | 59 | int inputId = 0, outputId = 1; // Here I assume input=0, output=1 for the current network 60 | 61 | // Infer asynchronously, in a proper cuda way ! 62 | using namespace std; 63 | cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, 64 | stream); 65 | context->enqueueV2(bindings, stream, nullptr); 66 | cudaMemcpyAsync(outputTensor.data(), bindings[outputId], outputTensor.size() * sizeof(float), 67 | cudaMemcpyDeviceToHost, stream); 68 | } 69 | 70 | //====================================================================================================================== 71 | int main() { 72 | using namespace std; 73 | using namespace nvinfer1; 74 | 75 | Logger logger; 76 | logger.log(ILogger::Severity::kINFO, "C++ TensorRT example3-load !!! 
"); 77 | 78 | // Load file, create engine 79 | logger.log(ILogger::Severity::kINFO, "Loading engine from example3.engine..."); 80 | int batchSize = 2; 81 | 82 | vector buffer; 83 | { 84 | ifstream in("example3.engine", ios::binary | ios::ate); 85 | if (!in) 86 | throw runtime_error("Cannot open example3.engine"); 87 | streamsize ss = in.tellg(); 88 | in.seekg(0, ios::beg); 89 | cout << "Input file size = " << ss << endl; 90 | buffer.resize(ss); 91 | if (0 == ss || !in.read(buffer.data(), ss)) 92 | throw runtime_error("Cannot read example3.engine"); 93 | } 94 | 95 | unique_ptr> runtime(createInferRuntime(logger)); 96 | unique_ptr> engine(runtime->deserializeCudaEngine(buffer.data(), buffer.size())); 97 | if (!engine) 98 | throw runtime_error("Deserialize error !"); 99 | 100 | // Optional : Print all bindings : name + dims + dtype 101 | cout << "=============\nBindings :\n"; 102 | int n = engine->getNbBindings(); 103 | for (int i = 0; i < n; ++i) { 104 | Dims d = engine->getBindingDimensions(i); 105 | cout << i << " : " << engine->getBindingName(i) << " : dims="; 106 | for (int j = 0; j < d.nbDims; ++j) { 107 | cout << d.d[j]; 108 | if (j < d.nbDims - 1) 109 | cout << "x"; 110 | } 111 | cout << " , dtype=" << (int) engine->getBindingDataType(i) << " "; 112 | cout << (engine->bindingIsInput(i) ? "IN" : "OUT") << endl; 113 | } 114 | cout << "=============\n\n"; 115 | 116 | // Create context 117 | logger.log(ILogger::Severity::kINFO, "Creating context ..."); 118 | unique_ptr> context(engine->createExecutionContext()); 119 | // Very important, you must set batch size here, otherwise you get zero output ! 120 | context->setBindingDimensions(0, Dims2(batchSize, 3)); 121 | 122 | // Create data structures for the inference 123 | cudaStream_t stream; 124 | cudaStreamCreate(&stream); 125 | vector inputTensor{0.5, -0.5, 1.0, 0.0, 0.0, 0.0}; 126 | vector outputTensor(2 * batchSize, -4.9); 127 | void *bindings[2]{0}; 128 | // Alloc cuda memory for IO tensors 129 | size_t sizes[] = {inputTensor.size(), outputTensor.size()}; 130 | for (int i = 0; i < engine->getNbBindings(); ++i) { 131 | // Create CUDA buffer for Tensor. 132 | cudaMalloc(&bindings[i], sizes[i] * sizeof(float)); 133 | } 134 | 135 | // Run the inference ! 136 | cout << "Running the inference !" 
<< endl; 137 | launchInference(context.get(), stream, inputTensor, outputTensor, bindings, batchSize); 138 | cudaStreamSynchronize(stream); 139 | // Must be [ [1.5, 3.5], [-1,-2] ] 140 | cout << "y = ["; 141 | for (int i = 0; i < batchSize; ++i) { 142 | cout << " [" << outputTensor.at(2 * i) << ", " << outputTensor.at(2 * i + 1) << "]"; 143 | if (i < batchSize - 1) 144 | cout << ", "; 145 | } 146 | cout << " ]" << endl; 147 | 148 | cudaStreamDestroy(stream); 149 | cudaFree(bindings[0]); 150 | cudaFree(bindings[1]); 151 | return 0; 152 | } 153 | -------------------------------------------------------------------------------- /example3-save.cpp: -------------------------------------------------------------------------------- 1 | // By Oleksiy Grechnyev, IT-JIM 2 | // Example 3-save : Like example 2, divided into 'save' and load 'parts' 3 | // I use here batch of 2 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | //====================================================================================================================== 17 | 18 | class Logger : public nvinfer1::ILogger { 19 | public: 20 | void log(Severity severity, const char *msg) override { 21 | using namespace std; 22 | string s; 23 | switch (severity) { 24 | case Severity::kINTERNAL_ERROR: 25 | s = "INTERNAL_ERROR"; 26 | break; 27 | case Severity::kERROR: 28 | s = "ERROR"; 29 | break; 30 | case Severity::kWARNING: 31 | s = "WARNING"; 32 | break; 33 | case Severity::kINFO: 34 | s = "INFO"; 35 | break; 36 | case Severity::kVERBOSE: 37 | s = "VERBOSE"; 38 | break; 39 | } 40 | cerr << s << ": " << msg << endl; 41 | } 42 | }; 43 | //====================================================================================================================== 44 | 45 | /// Using unique_ptr with Destroy is optional, but beats calling destroy() for everything 46 | /// Borrowed from the NVidia tutorial, nice C++ skills ! 47 | template 48 | struct Destroy { 49 | void operator()(T *t) const { 50 | t->destroy(); 51 | } 52 | }; 53 | 54 | //====================================================================================================================== 55 | 56 | /// Parse onnx file and create a TRT engine 57 | nvinfer1::ICudaEngine *createCudaEngine(const std::string &onnxFileName, nvinfer1::ILogger &logger, int batchSize) { 58 | using namespace std; 59 | using namespace nvinfer1; 60 | 61 | unique_ptr> builder{createInferBuilder(logger)}; 62 | unique_ptr> network{ 63 | builder->createNetworkV2(1U << (unsigned) NetworkDefinitionCreationFlag::kEXPLICIT_BATCH) 64 | }; 65 | unique_ptr> parser{ 66 | nvonnxparser::createParser(*network, logger)}; 67 | 68 | if (!parser->parseFromFile(onnxFileName.c_str(), static_cast(ILogger::Severity::kINFO))) 69 | throw runtime_error("ERROR: could not parse ONNX model " + onnxFileName + " !"); 70 | 71 | // Create Optimization profile and set the batch size 72 | IOptimizationProfile *profile = builder->createOptimizationProfile(); 73 | profile->setDimensions("input", OptProfileSelector::kMIN, Dims2{batchSize, 3}); 74 | profile->setDimensions("input", OptProfileSelector::kMAX, Dims2{batchSize, 3}); 75 | profile->setDimensions("input", OptProfileSelector::kOPT, Dims2{batchSize, 3}); 76 | 77 | // Build engine 78 | unique_ptr> config(builder->createBuilderConfig()); 79 | // This is needed for TensorRT 6, not needed by 7 ! 
80 | config->setMaxWorkspaceSize(64*1024*1024); 81 | config->addOptimizationProfile(profile); 82 | return builder->buildEngineWithConfig(*network, *config); 83 | } 84 | 85 | //====================================================================================================================== 86 | int main() { 87 | using namespace std; 88 | using namespace nvinfer1; 89 | 90 | // Parse model, create engine 91 | Logger logger; 92 | logger.log(ILogger::Severity::kINFO, "C++ TensorRT example3-save !!! "); 93 | logger.log(ILogger::Severity::kINFO, "Creating engine ..."); 94 | int batchSize = 2; 95 | unique_ptr> engine(createCudaEngine("model2.onnx", logger, batchSize)); 96 | 97 | if (!engine) 98 | throw runtime_error("Engine creation failed !"); 99 | 100 | // Optional : Print all bindings : name + dims + dtype 101 | cout << "=============\nBindings :\n"; 102 | int n = engine->getNbBindings(); 103 | for (int i = 0; i < n; ++i) { 104 | Dims d = engine->getBindingDimensions(i); 105 | cout << i << " : " << engine->getBindingName(i) << " : dims="; 106 | for (int j = 0; j < d.nbDims; ++j) { 107 | cout << d.d[j]; 108 | if (j < d.nbDims - 1) 109 | cout << "x"; 110 | } 111 | cout << " , dtype=" << (int) engine->getBindingDataType(i) << " "; 112 | cout << (engine->bindingIsInput(i) ? "IN" : "OUT") << endl; 113 | } 114 | cout << "=============\n\n"; 115 | 116 | // Write engine to disk 117 | unique_ptr> serializedEngine(engine->serialize()); 118 | cout << "\nSerialized engine : size = " << serializedEngine->size() << ", dtype = " << (int) serializedEngine->type() 119 | << endl; 120 | 121 | ofstream out("example3.engine", ios::binary); 122 | out.write((char *)serializedEngine->data(), serializedEngine->size()); 123 | 124 | return 0; 125 | } 126 | -------------------------------------------------------------------------------- /example4.cpp: -------------------------------------------------------------------------------- 1 | // By Oleksiy Grechnyev, IT-JIM 2 | // Example 4 : Here I construct network in-place, no ONNX parsing ! 3 | // I use here dynamic batches, a batch of 2 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | 13 | #include 14 | 15 | //====================================================================================================================== 16 | 17 | class Logger : public nvinfer1::ILogger { 18 | public: 19 | void log(Severity severity, const char *msg) override { 20 | using namespace std; 21 | string s; 22 | switch (severity) { 23 | case Severity::kINTERNAL_ERROR: 24 | s = "INTERNAL_ERROR"; 25 | break; 26 | case Severity::kERROR: 27 | s = "ERROR"; 28 | break; 29 | case Severity::kWARNING: 30 | s = "WARNING"; 31 | break; 32 | case Severity::kINFO: 33 | s = "INFO"; 34 | break; 35 | case Severity::kVERBOSE: 36 | s = "VERBOSE"; 37 | break; 38 | } 39 | cerr << s << ": " << msg << endl; 40 | } 41 | }; 42 | //====================================================================================================================== 43 | 44 | /// Using unique_ptr with Destroy is optional, but beats calling destroy() for everything 45 | /// Borrowed from the NVidia tutorial, nice C++ skills ! 
46 | template 47 | struct Destroy { 48 | void operator()(T *t) const { 49 | t->destroy(); 50 | } 51 | }; 52 | 53 | //====================================================================================================================== 54 | /// Optional : Print dimensions as string 55 | std::string printDim(const nvinfer1::Dims & d) { 56 | using namespace std; 57 | ostringstream oss; 58 | for (int j = 0; j < d.nbDims; ++j) { 59 | oss << d.d[j]; 60 | if (j < d.nbDims - 1) 61 | oss << "x"; 62 | } 63 | return oss.str(); 64 | } 65 | //====================================================================================================================== 66 | /// Optional : Print layers of the network 67 | void printNetwork(const nvinfer1::INetworkDefinition &net) { 68 | using namespace std; 69 | using namespace nvinfer1; 70 | cout << "\n=============\nNetwork info :" << endl; 71 | 72 | cout << "\nInputs : " << endl; 73 | for (int i = 0; i < net.getNbInputs(); ++i) { 74 | ITensor * inp = net.getInput(i); 75 | cout << "input" << i << " , dtype=" << (int)inp->getType() << " , dims=" << printDim(inp->getDimensions()) << endl; 76 | } 77 | 78 | cout << "\nLayers : " << endl; 79 | cout << "Number of layers : " << net.getNbLayers() << endl; 80 | for (int i = 0; i < net.getNbLayers(); ++i) { 81 | ILayer *l = net.getLayer(i); 82 | cout << "layer" << i << " , name=" << l->getName() << " , type=" << (int)l->getType() << " , IN "; 83 | for (int j = 0; j < l->getNbInputs(); ++j) 84 | cout << printDim(l->getInput(j)->getDimensions()) << " "; 85 | cout << ", OUT "; 86 | for (int j = 0; j < l->getNbOutputs(); ++j) 87 | cout << printDim(l->getOutput(j)->getDimensions()) << " "; 88 | cout << endl; 89 | } 90 | 91 | cout << "\nOutputs : " << endl; 92 | for (int i = 0; i < net.getNbOutputs(); ++i) { 93 | ITensor * outp = net.getOutput(i); 94 | cout << "input" << i << " , dtype=" << (int)outp->getType() << " , dims=" << printDim(outp->getDimensions()) << endl; 95 | } 96 | 97 | cout << "=============\n" << endl; 98 | } 99 | //====================================================================================================================== 100 | 101 | /// Create model and create a TRT engine 102 | nvinfer1::ICudaEngine *createCudaEngine(nvinfer1::ILogger &logger, int batchSize) { 103 | using namespace std; 104 | using namespace nvinfer1; 105 | 106 | unique_ptr> builder{createInferBuilder(logger)}; 107 | unique_ptr> network{ 108 | builder->createNetworkV2(1U << (unsigned) NetworkDefinitionCreationFlag::kEXPLICIT_BATCH) 109 | }; 110 | 111 | // Weights + bias 112 | vector w0 = {1., 2., 3., 4., 5., 6.}; 113 | vector b0 = {-1., -2.}; 114 | Weights w{DataType::kFLOAT, w0.data(), (int64_t) w0.size()}; 115 | Weights b{DataType::kFLOAT, b0.data(), (int64_t) b0.size()}; 116 | 117 | // Note, so hard to make a simple linear layer 118 | // But they have convolutions, FC etc for images only !!! WTF ??? 
119 | ITensor *input = network->addInput("goblin_input", DataType::kFLOAT, Dims2(-1, 3)); 120 | IConstantLayer * const1 = network->addConstant(Dims2(2, 3), w); 121 | IMatrixMultiplyLayer* mm = network->addMatrixMultiply(*input, MatrixOperation::kNONE, *const1->getOutput(0), MatrixOperation::kTRANSPOSE); 122 | IConstantLayer * const2 = network->addConstant(Dims2(1, 2), b); 123 | IElementWiseLayer* ew = network->addElementWise(*mm->getOutput(0), *const2->getOutput(0), ElementWiseOperation::kSUM); 124 | ITensor *output = ew->getOutput(0); 125 | output->setName("goblin_output"); 126 | network->markOutput(*output); 127 | 128 | printNetwork(*network); 129 | 130 | // Create Optimization profile and set the batch size 131 | IOptimizationProfile *profile = builder->createOptimizationProfile(); 132 | profile->setDimensions("goblin_input", OptProfileSelector::kMIN, Dims2{batchSize, 3}); 133 | profile->setDimensions("goblin_input", OptProfileSelector::kMAX, Dims2{batchSize, 3}); 134 | profile->setDimensions("goblin_input", OptProfileSelector::kOPT, Dims2{batchSize, 3}); 135 | 136 | // Build engine 137 | unique_ptr> config(builder->createBuilderConfig()); 138 | // This is needed for TensorRT 6, not needed by 7 ! 139 | config->setMaxWorkspaceSize(64*1024*1024); 140 | config->addOptimizationProfile(profile); 141 | return builder->buildEngineWithConfig(*network, *config); 142 | } 143 | 144 | //====================================================================================================================== 145 | /// Run a single inference 146 | void launchInference(nvinfer1::IExecutionContext *context, cudaStream_t stream, std::vector const &inputTensor, 147 | std::vector &outputTensor, void **bindings, int batchSize) { 148 | 149 | int inputId = 0, outputId = 1; // Here I assume input=0, output=1 for the current network 150 | 151 | // Infer asynchronously, in a proper cuda way ! 152 | using namespace std; 153 | cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, 154 | stream); 155 | context->enqueueV2(bindings, stream, nullptr); 156 | cudaMemcpyAsync(outputTensor.data(), bindings[outputId], outputTensor.size() * sizeof(float), 157 | cudaMemcpyDeviceToHost, stream); 158 | } 159 | 160 | //====================================================================================================================== 161 | int main() { 162 | using namespace std; 163 | using namespace nvinfer1; 164 | 165 | // Parse model, create engine 166 | Logger logger; 167 | logger.log(ILogger::Severity::kINFO, "C++ TensorRT example4 !!! "); 168 | logger.log(ILogger::Severity::kINFO, "Creating engine ..."); 169 | int batchSize = 2; 170 | unique_ptr> engine(createCudaEngine(logger, batchSize)); 171 | 172 | if (!engine) 173 | throw runtime_error("Engine creation failed !"); 174 | 175 | // Optional : Print all bindings : name + dims + dtype 176 | cout << "=============\nBindings :\n"; 177 | int n = engine->getNbBindings(); 178 | for (int i = 0; i < n; ++i) { 179 | Dims d = engine->getBindingDimensions(i); 180 | cout << i << " : " << engine->getBindingName(i) << " : dims=" << printDim(d); 181 | cout << " , dtype=" << (int) engine->getBindingDataType(i) << " "; 182 | cout << (engine->bindingIsInput(i) ? 
"IN" : "OUT") << endl; 183 | } 184 | cout << "=============\n\n"; 185 | 186 | // Create context 187 | logger.log(ILogger::Severity::kINFO, "Creating context ..."); 188 | unique_ptr> context(engine->createExecutionContext()); 189 | // Very important, you must set batch size here, otherwise you get zero output ! 190 | context->setBindingDimensions(0, Dims2(batchSize, 3)); 191 | 192 | // Create data structures for the inference 193 | cudaStream_t stream; 194 | cudaStreamCreate(&stream); 195 | vector inputTensor{0.5, -0.5, 1.0, 0.0, 0.0, 0.0}; 196 | vector outputTensor(2 * batchSize, -4.9); 197 | void *bindings[2]{0}; 198 | // Alloc cuda memory for IO tensors 199 | size_t sizes[] = {inputTensor.size(), outputTensor.size()}; 200 | for (int i = 0; i < engine->getNbBindings(); ++i) { 201 | // Create CUDA buffer for Tensor. 202 | cudaMalloc(&bindings[i], sizes[i] * sizeof(float)); 203 | } 204 | 205 | // Run the inference ! 206 | cout << "Running the inference !" << endl; 207 | launchInference(context.get(), stream, inputTensor, outputTensor, bindings, batchSize); 208 | cudaStreamSynchronize(stream); 209 | // Must be [ [1.5, 3.5], [-1,-2] ] 210 | cout << "y = ["; 211 | for (int i = 0; i < batchSize; ++i) { 212 | cout << " [" << outputTensor.at(2 * i) << ", " << outputTensor.at(2 * i + 1) << "]"; 213 | if (i < batchSize - 1) 214 | cout << ", "; 215 | } 216 | cout << " ]" << endl; 217 | 218 | cudaStreamDestroy(stream); 219 | cudaFree(bindings[0]); 220 | cudaFree(bindings[1]); 221 | return 0; 222 | } 223 | -------------------------------------------------------------------------------- /example5.cpp: -------------------------------------------------------------------------------- 1 | // By Oleksiy Grechnyev, IT-JIM 2 | // Example 5 : Here I construct network in-place, no ONNX parsing ! 3 | // Here I use a single Fully Connected layer 4 | // I use here dynamic batches, a batch of 2 5 | // I also tried int8, but it was not selected 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | #include 17 | 18 | //====================================================================================================================== 19 | 20 | class Logger : public nvinfer1::ILogger { 21 | public: 22 | void log(Severity severity, const char *msg) override { 23 | using namespace std; 24 | string s; 25 | switch (severity) { 26 | case Severity::kINTERNAL_ERROR: 27 | s = "INTERNAL_ERROR"; 28 | break; 29 | case Severity::kERROR: 30 | s = "ERROR"; 31 | break; 32 | case Severity::kWARNING: 33 | s = "WARNING"; 34 | break; 35 | case Severity::kINFO: 36 | s = "INFO"; 37 | break; 38 | case Severity::kVERBOSE: 39 | s = "VERBOSE"; 40 | break; 41 | } 42 | cerr << s << ": " << msg << endl; 43 | } 44 | }; 45 | //====================================================================================================================== 46 | 47 | /// Using unique_ptr with Destroy is optional, but beats calling destroy() for everything 48 | /// Borrowed from the NVidia tutorial, nice C++ skills ! 
49 | template 50 | struct Destroy { 51 | void operator()(T *t) const { 52 | t->destroy(); 53 | } 54 | }; 55 | 56 | //====================================================================================================================== 57 | /// Optional : Print dimensions as string 58 | std::string printDim(const nvinfer1::Dims & d) { 59 | using namespace std; 60 | ostringstream oss; 61 | for (int j = 0; j < d.nbDims; ++j) { 62 | oss << d.d[j]; 63 | if (j < d.nbDims - 1) 64 | oss << "x"; 65 | } 66 | return oss.str(); 67 | } 68 | //====================================================================================================================== 69 | /// Optional : Print layers of the network 70 | void printNetwork(const nvinfer1::INetworkDefinition &net) { 71 | using namespace std; 72 | using namespace nvinfer1; 73 | cout << "\n=============\nNetwork info :" << endl; 74 | 75 | cout << "\nInputs : " << endl; 76 | for (int i = 0; i < net.getNbInputs(); ++i) { 77 | ITensor * inp = net.getInput(i); 78 | cout << "input" << i << " , dtype=" << (int)inp->getType() << " , dims=" << printDim(inp->getDimensions()) << endl; 79 | } 80 | 81 | cout << "\nLayers : " << endl; 82 | cout << "Number of layers : " << net.getNbLayers() << endl; 83 | for (int i = 0; i < net.getNbLayers(); ++i) { 84 | ILayer *l = net.getLayer(i); 85 | cout << "layer" << i << " , name=" << l->getName() << " , type=" << (int)l->getType() << " , IN "; 86 | for (int j = 0; j < l->getNbInputs(); ++j) 87 | cout << printDim(l->getInput(j)->getDimensions()) << " "; 88 | cout << ", OUT "; 89 | for (int j = 0; j < l->getNbOutputs(); ++j) 90 | cout << printDim(l->getOutput(j)->getDimensions()) << " "; 91 | cout << endl; 92 | } 93 | 94 | cout << "\nOutputs : " << endl; 95 | for (int i = 0; i < net.getNbOutputs(); ++i) { 96 | ITensor * outp = net.getOutput(i); 97 | cout << "input" << i << " , dtype=" << (int)outp->getType() << " , dims=" << printDim(outp->getDimensions()) << endl; 98 | } 99 | 100 | cout << "=============\n" << endl; 101 | } 102 | //====================================================================================================================== 103 | 104 | /// Create model and create a TRT engine 105 | nvinfer1::ICudaEngine *createCudaEngine(nvinfer1::ILogger &logger, int batchSize) { 106 | using namespace std; 107 | using namespace nvinfer1; 108 | 109 | unique_ptr> builder{createInferBuilder(logger)}; 110 | unique_ptr> network{ 111 | builder->createNetworkV2(1U << (unsigned) NetworkDefinitionCreationFlag::kEXPLICIT_BATCH) 112 | }; 113 | 114 | // Weights + bias 115 | vector w0 = {1., 2., 3., 4., 5., 6.}; 116 | // I used this version to test FP16, which gives lower precision 117 | // vector w0 = {1., 2., 3., 4., 5., sqrt(37.123456789f)}; 118 | vector b0 = {-1., -2.}; 119 | Weights w{DataType::kFLOAT, w0.data(), (int64_t) w0.size()}; 120 | Weights b{DataType::kFLOAT, b0.data(), (int64_t) b0.size()}; 121 | 122 | // A single 4D FC layer 123 | ITensor *input = network->addInput("goblin_input", DataType::kFLOAT, Dims4(-1, 1, 1, 3)); 124 | IFullyConnectedLayer *fc = network->addFullyConnected(*input, 2, w, b); 125 | ITensor *output = fc->getOutput(0); 126 | output->setName("goblin_output"); 127 | network->markOutput(*output); 128 | 129 | printNetwork(*network); 130 | 131 | 132 | // Are fancy types available ? 
133 | cout << "platformHasFastFp16 = " << builder->platformHasFastFp16() << endl; 134 | cout << "platformHasFastInt8 = " << builder->platformHasFastInt8() << endl; 135 | 136 | // Create Optimization profile and set the batch size 137 | IOptimizationProfile *profile = builder->createOptimizationProfile(); 138 | profile->setDimensions("goblin_input", OptProfileSelector::kMIN, Dims4{batchSize, 1, 1, 3}); 139 | profile->setDimensions("goblin_input", OptProfileSelector::kMAX, Dims4{batchSize, 1, 1, 3}); 140 | profile->setDimensions("goblin_input", OptProfileSelector::kOPT, Dims4{batchSize, 1, 1, 3}); 141 | 142 | // Set up the config 143 | unique_ptr> config(builder->createBuilderConfig()); 144 | // This is needed for TensorRT 6, not needed by 7 ! 145 | config->setMaxWorkspaceSize(64*1024*1024); 146 | config->addOptimizationProfile(profile); 147 | 148 | // Int8 quantization with the explicit range 149 | config->setFlag(BuilderFlag::kINT8); 150 | // config->setFlag(BuilderFlag::kFP16); 151 | config->setFlag(BuilderFlag::kSTRICT_TYPES); 152 | 153 | // Set the dynamic range for all layers and input 154 | float minRange = -17., maxRange = 17.; 155 | cout << "layers = " << network->getNbLayers() << endl; 156 | for (int i = 0; i < network->getNbLayers(); ++i) { 157 | ILayer *layer = network->getLayer(i); 158 | ITensor *tensor = layer->getOutput(0); 159 | tensor->setDynamicRange(minRange, maxRange); 160 | // layer->setPrecision(DataType::kINT8); 161 | // layer->setOutputType(0, DataType::kINT8); 162 | } 163 | network->getInput(0)->setDynamicRange(minRange, maxRange); 164 | 165 | // Build engine 166 | return builder->buildEngineWithConfig(*network, *config); 167 | } 168 | 169 | //====================================================================================================================== 170 | /// Run a single inference 171 | void launchInference(nvinfer1::IExecutionContext *context, cudaStream_t stream, std::vector const &inputTensor, 172 | std::vector &outputTensor, void **bindings, int batchSize) { 173 | 174 | int inputId = 0, outputId = 1; // Here I assume input=0, output=1 for the current network 175 | 176 | // Infer asynchronously, in a proper cuda way ! 177 | using namespace std; 178 | cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, 179 | stream); 180 | context->enqueueV2(bindings, stream, nullptr); 181 | cudaMemcpyAsync(outputTensor.data(), bindings[outputId], outputTensor.size() * sizeof(float), 182 | cudaMemcpyDeviceToHost, stream); 183 | } 184 | 185 | //====================================================================================================================== 186 | int main() { 187 | using namespace std; 188 | using namespace nvinfer1; 189 | 190 | // Parse model, create engine 191 | Logger logger; 192 | logger.log(ILogger::Severity::kINFO, "C++ TensorRT example5 !!! 
"); 193 | logger.log(ILogger::Severity::kINFO, "Creating engine ..."); 194 | int batchSize = 2; 195 | unique_ptr> engine(createCudaEngine(logger, batchSize)); 196 | 197 | if (!engine) 198 | throw runtime_error("Engine creation failed !"); 199 | 200 | // Optional : Print all bindings : name + dims + dtype 201 | cout << "=============\nBindings :\n"; 202 | int n = engine->getNbBindings(); 203 | for (int i = 0; i < n; ++i) { 204 | Dims d = engine->getBindingDimensions(i); 205 | cout << i << " : " << engine->getBindingName(i) << " : dims=" << printDim(d); 206 | cout << " , dtype=" << (int) engine->getBindingDataType(i) << " "; 207 | cout << (engine->bindingIsInput(i) ? "IN" : "OUT") << endl; 208 | } 209 | cout << "=============\n\n"; 210 | 211 | // Create context 212 | logger.log(ILogger::Severity::kINFO, "Creating context ..."); 213 | unique_ptr> context(engine->createExecutionContext()); 214 | // Very important, you must set batch size here, otherwise you get zero output ! 215 | context->setBindingDimensions(0, Dims4(batchSize, 1, 1, 3)); 216 | 217 | // Create data structures for the inference 218 | cudaStream_t stream; 219 | cudaStreamCreate(&stream); 220 | vector inputTensor{0.5, -0.5, 1.0, 0.0, 0.0, 0.0}; 221 | vector outputTensor(2 * batchSize, -4.9); 222 | void *bindings[2]{0}; 223 | // Alloc cuda memory for IO tensors 224 | size_t sizes[] = {inputTensor.size(), outputTensor.size()}; 225 | for (int i = 0; i < engine->getNbBindings(); ++i) { 226 | // Create CUDA buffer for Tensor. 227 | cudaMalloc(&bindings[i], sizes[i] * sizeof(float)); 228 | } 229 | 230 | // Run the inference ! 231 | cout << "Running the inference !" << endl; 232 | launchInference(context.get(), stream, inputTensor, outputTensor, bindings, batchSize); 233 | cudaStreamSynchronize(stream); 234 | // Must be [ [1.5, 3.5], [-1,-2] ] 235 | cout << "y = ["; 236 | cout.precision(20); 237 | for (int i = 0; i < batchSize; ++i) { 238 | cout << " [" << outputTensor.at(2 * i) << ", " << outputTensor.at(2 * i + 1) << "]"; 239 | if (i < batchSize - 1) 240 | cout << ", "; 241 | } 242 | cout << " ]" << endl; 243 | 244 | cudaStreamDestroy(stream); 245 | cudaFree(bindings[0]); 246 | cudaFree(bindings[1]); 247 | return 0; 248 | } 249 | -------------------------------------------------------------------------------- /example6.cpp: -------------------------------------------------------------------------------- 1 | // By Oleksiy Grechnyev, IT-JIM 2 | // Example 6 : Here I construct network in-place, no ONNX parsing ! 
3 | // Here I use a single Convolution layer 4 | // I use here dynamic batches, a batch of 1 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | #include 16 | 17 | //====================================================================================================================== 18 | 19 | class Logger : public nvinfer1::ILogger { 20 | public: 21 | void log(Severity severity, const char *msg) override { 22 | using namespace std; 23 | string s; 24 | switch (severity) { 25 | case Severity::kINTERNAL_ERROR: 26 | s = "INTERNAL_ERROR"; 27 | break; 28 | case Severity::kERROR: 29 | s = "ERROR"; 30 | break; 31 | case Severity::kWARNING: 32 | s = "WARNING"; 33 | break; 34 | case Severity::kINFO: 35 | s = "INFO"; 36 | break; 37 | case Severity::kVERBOSE: 38 | s = "VERBOSE"; 39 | break; 40 | } 41 | cerr << s << ": " << msg << endl; 42 | } 43 | }; 44 | //====================================================================================================================== 45 | 46 | /// Using unique_ptr with Destroy is optional, but beats calling destroy() for everything 47 | /// Borrowed from the NVidia tutorial, nice C++ skills ! 48 | template 49 | struct Destroy { 50 | void operator()(T *t) const { 51 | t->destroy(); 52 | } 53 | }; 54 | 55 | //====================================================================================================================== 56 | /// Optional : Print dimensions as string 57 | std::string printDim(const nvinfer1::Dims &d) { 58 | using namespace std; 59 | ostringstream oss; 60 | for (int j = 0; j < d.nbDims; ++j) { 61 | oss << d.d[j]; 62 | if (j < d.nbDims - 1) 63 | oss << "x"; 64 | } 65 | return oss.str(); 66 | } 67 | 68 | //====================================================================================================================== 69 | /// Optional : Print layers of the network 70 | void printNetwork(const nvinfer1::INetworkDefinition &net) { 71 | using namespace std; 72 | using namespace nvinfer1; 73 | cout << "\n=============\nNetwork info :" << endl; 74 | 75 | cout << "\nInputs : " << endl; 76 | for (int i = 0; i < net.getNbInputs(); ++i) { 77 | ITensor *inp = net.getInput(i); 78 | cout << "input" << i << " , dtype=" << (int) inp->getType() << " , dims=" << printDim(inp->getDimensions()) 79 | << endl; 80 | } 81 | 82 | cout << "\nLayers : " << endl; 83 | cout << "Number of layers : " << net.getNbLayers() << endl; 84 | for (int i = 0; i < net.getNbLayers(); ++i) { 85 | ILayer *l = net.getLayer(i); 86 | cout << "layer" << i << " , name=" << l->getName() << " , type=" << (int) l->getType() << " , IN "; 87 | for (int j = 0; j < l->getNbInputs(); ++j) 88 | cout << printDim(l->getInput(j)->getDimensions()) << " "; 89 | cout << ", OUT "; 90 | for (int j = 0; j < l->getNbOutputs(); ++j) 91 | cout << printDim(l->getOutput(j)->getDimensions()) << " "; 92 | cout << endl; 93 | } 94 | 95 | cout << "\nOutputs : " << endl; 96 | for (int i = 0; i < net.getNbOutputs(); ++i) { 97 | ITensor *outp = net.getOutput(i); 98 | cout << "input" << i << " , dtype=" << (int) outp->getType() << " , dims=" << printDim(outp->getDimensions()) 99 | << endl; 100 | } 101 | 102 | cout << "=============\n" << endl; 103 | } 104 | //====================================================================================================================== 105 | 106 | /// Create model and create a TRT engine 107 | nvinfer1::ICudaEngine *createCudaEngine(nvinfer1::ILogger &logger, int batchSize) { 108 | using namespace std; 109 | using 
namespace nvinfer1; 110 | 111 | unique_ptr> builder{createInferBuilder(logger)}; 112 | unique_ptr> network{ 113 | builder->createNetworkV2(1U << (unsigned) NetworkDefinitionCreationFlag::kEXPLICIT_BATCH) 114 | }; 115 | 116 | // Weights + bias 117 | float q = 1.0f / 9; 118 | vector w0 = {q, q, q, q, q, q, q, q, q}; // 3x3 box filter 119 | vector b0 = {0}; 120 | Weights w{DataType::kFLOAT, w0.data(), (int64_t) w0.size()}; 121 | Weights b{DataType::kFLOAT, b0.data(), (int64_t) b0.size()}; 122 | 123 | // A single 4D FC layer 124 | ITensor *input = network->addInput("goblin_input", DataType::kFLOAT, Dims4(-1, 1, 10, 10)); 125 | IConvolutionLayer *conv = network->addConvolutionNd(*input, 1, Dims2(3, 3), w, b); 126 | ITensor *output = conv->getOutput(0); 127 | output->setName("goblin_output"); 128 | network->markOutput(*output); 129 | 130 | printNetwork(*network); 131 | 132 | 133 | // Are fancy types available ? 134 | cout << "platformHasFastFp16 = " << builder->platformHasFastFp16() << endl; 135 | cout << "platformHasFastInt8 = " << builder->platformHasFastInt8() << endl; 136 | 137 | // Create Optimization profile and set the batch size 138 | IOptimizationProfile *profile = builder->createOptimizationProfile(); 139 | profile->setDimensions("goblin_input", OptProfileSelector::kMIN, Dims4{batchSize, 1, 10, 10}); 140 | profile->setDimensions("goblin_input", OptProfileSelector::kMAX, Dims4{batchSize, 1, 10, 10}); 141 | profile->setDimensions("goblin_input", OptProfileSelector::kOPT, Dims4{batchSize, 1, 10, 10}); 142 | 143 | // Set up the config 144 | unique_ptr> config(builder->createBuilderConfig()); 145 | // This is needed for TensorRT 6, not needed by 7 ! 146 | config->setMaxWorkspaceSize(64 * 1024 * 1024); 147 | config->addOptimizationProfile(profile); 148 | 149 | // Int8 quantization with the explicit range 150 | config->setFlag(BuilderFlag::kINT8); 151 | // config->setFlag(BuilderFlag::kFP16); 152 | config->setFlag(BuilderFlag::kSTRICT_TYPES); 153 | 154 | // Set the dynamic range for all layers and input 155 | float minRange = -17., maxRange = 17.; 156 | cout << "layers = " << network->getNbLayers() << endl; 157 | for (int i = 0; i < network->getNbLayers(); ++i) { 158 | ILayer *layer = network->getLayer(i); 159 | ITensor *tensor = layer->getOutput(0); 160 | tensor->setDynamicRange(minRange, maxRange); 161 | layer->setPrecision(DataType::kINT8); 162 | layer->setOutputType(0, DataType::kINT8); 163 | } 164 | network->getInput(0)->setDynamicRange(minRange, maxRange); 165 | 166 | // Build engine 167 | return builder->buildEngineWithConfig(*network, *config); 168 | } 169 | 170 | //====================================================================================================================== 171 | /// Run a single inference 172 | void launchInference(nvinfer1::IExecutionContext *context, cudaStream_t stream, std::vector const &inputTensor, 173 | std::vector &outputTensor, void **bindings, int batchSize) { 174 | 175 | int inputId = 0, outputId = 1; // Here I assume input=0, output=1 for the current network 176 | 177 | // Infer asynchronously, in a proper cuda way ! 
178 |     using namespace std;
179 |     cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice,
180 |                     stream);
181 |     context->enqueueV2(bindings, stream, nullptr);
182 |     cudaMemcpyAsync(outputTensor.data(), bindings[outputId], outputTensor.size() * sizeof(float),
183 |                     cudaMemcpyDeviceToHost, stream);
184 | }
185 | 
186 | //======================================================================================================================
187 | int main() {
188 |     using namespace std;
189 |     using namespace nvinfer1;
190 | 
191 |     // Build the network and create the engine
192 |     Logger logger;
193 |     logger.log(ILogger::Severity::kINFO, "C++ TensorRT example6 !!! ");
194 |     logger.log(ILogger::Severity::kINFO, "Creating engine ...");
195 |     int batchSize = 1;
196 |     unique_ptr<ICudaEngine, Destroy<ICudaEngine>> engine(createCudaEngine(logger, batchSize));
197 | 
198 |     if (!engine)
199 |         throw runtime_error("Engine creation failed !");
200 | 
201 |     // Optional : Print all bindings : name + dims + dtype
202 |     cout << "=============\nBindings :\n";
203 |     int n = engine->getNbBindings();
204 |     for (int i = 0; i < n; ++i) {
205 |         Dims d = engine->getBindingDimensions(i);
206 |         cout << i << " : " << engine->getBindingName(i) << " : dims=" << printDim(d);
207 |         cout << " , dtype=" << (int) engine->getBindingDataType(i) << " ";
208 |         cout << (engine->bindingIsInput(i) ? "IN" : "OUT") << endl;
209 |     }
210 |     cout << "=============\n\n";
211 | 
212 | 
213 |     // Create context
214 |     logger.log(ILogger::Severity::kINFO, "Creating context ...");
215 |     unique_ptr<IExecutionContext, Destroy<IExecutionContext>> context(engine->createExecutionContext());
216 |     // Very important : you must set the actual input dimensions (including batch size) here, otherwise you get zero output !
217 |     context->setBindingDimensions(0, Dims4(batchSize, 1, 10, 10));
218 | 
219 |     // Create data structures for the inference
220 |     cudaStream_t stream;
221 |     cudaStreamCreate(&stream);
222 |     vector<float> inputTensor(10 * 10 * batchSize, 3.1);
223 |     vector<float> outputTensor(8 * 8 * batchSize, -4.9);
224 |     for (int iy = 0; iy < 10; ++iy) {
225 |         for (int ix = 0; ix < 10; ++ix) {
226 |             inputTensor[iy * 10 + ix] = (ix + iy) % 2;
227 |         }
228 |     }
229 |     cout << "input = " << endl;
230 |     for (int iy = 0; iy < 10; ++iy) {
231 |         for (int ix = 0; ix < 10; ++ix) {
232 |             cout << inputTensor[iy * 10 + ix] << " ";
233 |         }
234 |         cout << endl;
235 |     }
236 | 
237 |     void *bindings[2]{0};
238 |     // Allocate CUDA memory for the IO tensors
239 |     size_t sizes[] = {inputTensor.size(), outputTensor.size()};
240 |     for (int i = 0; i < engine->getNbBindings(); ++i) {
241 |         // Create CUDA buffer for Tensor.
242 |         cudaMalloc(&bindings[i], sizes[i] * sizeof(float));
243 |     }
244 | 
245 |     // Run the inference !
246 |     cout << "Running the inference !"
<< endl; 247 | launchInference(context.get(), stream, inputTensor, outputTensor, bindings, batchSize); 248 | cudaStreamSynchronize(stream); 249 | 250 | cout << "output = " << endl; 251 | for (int iy = 0; iy < 8; ++iy) { 252 | for (int ix = 0; ix < 8; ++ix) { 253 | cout << outputTensor[iy*8 + ix] << " "; 254 | } 255 | cout << endl; 256 | } 257 | 258 | cudaStreamDestroy(stream); 259 | cudaFree(bindings[0]); 260 | cudaFree(bindings[1]); 261 | return 0; 262 | } 263 | -------------------------------------------------------------------------------- /example7.cpp: -------------------------------------------------------------------------------- 1 | // By Oleksiy Grechnyev, IT-JIM 2 | // Example 7 : Finally I succeeded with int8 (if your GPU supports it) 3 | // It seems only combinations conv+relu are available 4 | // Also, for some reason, one layer is not enough, so I used conv->relu->conv->relu network 5 | // Update : added inference 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | 16 | #include 17 | 18 | constexpr bool USE_INT8 = true; 19 | 20 | //====================================================================================================================== 21 | 22 | class Logger : public nvinfer1::ILogger { 23 | public: 24 | void log(Severity severity, const char *msg) override { 25 | using namespace std; 26 | string s; 27 | switch (severity) { 28 | case Severity::kINTERNAL_ERROR: 29 | s = "INTERNAL_ERROR"; 30 | break; 31 | case Severity::kERROR: 32 | s = "ERROR"; 33 | break; 34 | case Severity::kWARNING: 35 | s = "WARNING"; 36 | break; 37 | case Severity::kINFO: 38 | s = "INFO"; 39 | break; 40 | case Severity::kVERBOSE: 41 | s = "VERBOSE"; 42 | break; 43 | } 44 | cerr << s << ": " << msg << endl; 45 | } 46 | }; 47 | //====================================================================================================================== 48 | 49 | /// Using unique_ptr with Destroy is optional, but beats calling destroy() for everything 50 | /// Borrowed from the NVidia tutorial, nice C++ skills ! 
51 | template 52 | struct Destroy { 53 | void operator()(T *t) const { 54 | t->destroy(); 55 | } 56 | }; 57 | 58 | //====================================================================================================================== 59 | /// Optional : Print dimensions as string 60 | std::string printDim(const nvinfer1::Dims &d) { 61 | using namespace std; 62 | ostringstream oss; 63 | for (int j = 0; j < d.nbDims; ++j) { 64 | oss << d.d[j]; 65 | if (j < d.nbDims - 1) 66 | oss << "x"; 67 | } 68 | return oss.str(); 69 | } 70 | 71 | //====================================================================================================================== 72 | /// Optional : Print layers of the network 73 | void printNetwork(const nvinfer1::INetworkDefinition &net) { 74 | using namespace std; 75 | using namespace nvinfer1; 76 | cout << "\n=============\nNetwork info :" << endl; 77 | 78 | cout << "\nInputs : " << endl; 79 | for (int i = 0; i < net.getNbInputs(); ++i) { 80 | ITensor *inp = net.getInput(i); 81 | cout << "input" << i << " , dtype=" << (int) inp->getType() << " , dims=" << printDim(inp->getDimensions()) 82 | << endl; 83 | } 84 | 85 | cout << "\nLayers : " << endl; 86 | cout << "Number of layers : " << net.getNbLayers() << endl; 87 | for (int i = 0; i < net.getNbLayers(); ++i) { 88 | ILayer *l = net.getLayer(i); 89 | cout << "layer" << i << " , name=" << l->getName() << " , type=" << (int) l->getType() << " , IN "; 90 | for (int j = 0; j < l->getNbInputs(); ++j) 91 | cout << "(" << int(l->getInput(j)->getType()) << ") " << printDim(l->getInput(j)->getDimensions()) << " "; 92 | cout << ", OUT "; 93 | for (int j = 0; j < l->getNbOutputs(); ++j) 94 | cout << "(" << int(l->getOutput(j)->getType()) << ") " << printDim(l->getOutput(j)->getDimensions()) << " "; 95 | cout << endl; 96 | switch (l->getType()) { 97 | case (LayerType::kCONVOLUTION): { 98 | IConvolutionLayer *lc = static_cast(l); 99 | cout << "CONVOLUTION : "; 100 | cout << "ker=" << printDim(lc->getKernelSizeNd()); 101 | cout << ", stride=" << printDim(lc->getStrideNd()); 102 | cout << ", padding=" << printDim(lc->getPaddingNd()); 103 | cout << ", groups=" << lc->getNbGroups(); 104 | Weights w = lc->getKernelWeights(); 105 | cout << ", weights=" << w.count << ":" << int(w.type); 106 | cout << endl; 107 | } 108 | break; 109 | case (LayerType::kSCALE): { 110 | IScaleLayer *ls = static_cast(l); 111 | cout << "SCALE: "; 112 | cout << "mode=" << int(ls->getMode()); 113 | Weights ws = ls->getScale(); 114 | cout << ", scale=" << ws.count << ":" << int(ws.type); 115 | Weights wp = ls->getPower(); 116 | cout << ", power=" << wp.count << ":" << int(wp.type); 117 | Weights wf = ls->getShift(); 118 | cout << ", shift=" << wf.count << ":" << int(wf.type); 119 | cout << endl; 120 | } 121 | break; 122 | } 123 | } 124 | 125 | cout << "\nOutputs : " << endl; 126 | for (int i = 0; i < net.getNbOutputs(); ++i) { 127 | ITensor *outp = net.getOutput(i); 128 | cout << "input" << i << " , dtype=" << (int) outp->getType() << " , dims=" << printDim(outp->getDimensions()) 129 | << endl; 130 | } 131 | 132 | cout << "=============\n" << endl; 133 | } 134 | //====================================================================================================================== 135 | 136 | /// Create model and create a TRT engine 137 | nvinfer1::ICudaEngine *createCudaEngine(nvinfer1::ILogger &logger, int batchSize) { 138 | using namespace std; 139 | using namespace nvinfer1; 140 | 141 | unique_ptr> builder{createInferBuilder(logger)}; 142 | unique_ptr> 
network{
143 |         builder->createNetworkV2(1U << (unsigned) NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)
144 |     };
145 | 
146 |     // Network
147 |     ITensor *input = network->addInput("goblin_input", DataType::kFLOAT, Dims4(1, 3, 224, 224));
148 | 
149 |     // conv1
150 |     vector<float> wwC1(7 * 7 * 3 * 64, 0.0123); // 7x7 conv filter, constant weights
151 |     vector<float> bbC1(64, 0.5);
152 |     Weights wC1{DataType::kFLOAT, wwC1.data(), (int64_t) wwC1.size()};
153 |     Weights bC1{DataType::kFLOAT, bbC1.data(), (int64_t) bbC1.size()};
154 |     IConvolutionLayer *conv1 = network->addConvolutionNd(*input, 64, Dims2(7, 7), wC1, bC1);
155 |     conv1->setStrideNd(Dims2(2, 2));
156 |     conv1->setPaddingNd(Dims2(3, 3));
157 |     // relu1
158 |     IActivationLayer *relu1 = network->addActivation(*conv1->getOutput(0), ActivationType::kRELU);
159 | 
160 |     // conv2
161 |     vector<float> wwC2(3 * 3 * 64 * 128, 0.01231); // 3x3 conv filter, constant weights
162 |     vector<float> bbC2(128, 0.4);
163 |     Weights wC2{DataType::kFLOAT, wwC2.data(), (int64_t) wwC2.size()};
164 |     Weights bC2{DataType::kFLOAT, bbC2.data(), (int64_t) bbC2.size()};
165 |     IConvolutionLayer *conv2 = network->addConvolutionNd(*relu1->getOutput(0), 128, Dims2(3, 3), wC2, bC2);
166 |     conv1->setStrideNd(Dims2(2, 2));
167 |     conv1->setPaddingNd(Dims2(1, 1));
168 |     // relu2
169 |     IActivationLayer *relu2 = network->addActivation(*conv2->getOutput(0), ActivationType::kRELU);
170 | 
171 |     ITensor *output = relu2->getOutput(0);
172 |     output->setName("goblin_output");
173 |     network->markOutput(*output);
174 | 
175 |     printNetwork(*network);
176 | 
177 |     // Are fancy types available ?
178 |     cout << "platformHasFastFp16 = " << builder->platformHasFastFp16() << endl;
179 |     cout << "platformHasFastInt8 = " << builder->platformHasFastInt8() << endl;
180 | 
181 |     // Set up the config
182 |     unique_ptr<IBuilderConfig, Destroy<IBuilderConfig>> config(builder->createBuilderConfig());
183 |     // This is needed for TensorRT 6, not needed by 7 !
184 |     config->setMaxWorkspaceSize(1024 * 1024 * 1024);
185 | 
186 |     if (USE_INT8) {
187 |         // INT8 quantization with an explicit dynamic range
188 |         config->setFlag(BuilderFlag::kINT8);
189 |         config->setFlag(BuilderFlag::kSTRICT_TYPES);
190 | 
191 |         // Set the dynamic range for the input and for every layer output
192 |         float minRange = -17., maxRange = 17.;
193 |         cout << "layers = " << network->getNbLayers() << endl;
194 |         for (int i = 0; i < network->getNbLayers(); ++i) {
195 |             ILayer *layer = network->getLayer(i);
196 |             ITensor *tensor = layer->getOutput(0);
197 |             tensor->setDynamicRange(minRange, maxRange);
198 |             layer->setPrecision(DataType::kINT8);
199 |             layer->setOutputType(0, DataType::kINT8);
200 |         }
201 |         network->getInput(0)->setDynamicRange(minRange, maxRange);
202 |     }
203 | 
204 |     // Build engine
205 |     return builder->buildEngineWithConfig(*network, *config);
206 | }
207 | //======================================================================================================================
208 | /// Run a single inference
209 | void launchInference(nvinfer1::IExecutionContext *context, cudaStream_t stream, std::vector<float> const &inputTensor,
210 |                      std::vector<float> &outputTensor, void **bindings, int batchSize) {
211 | 
212 |     int inputId = 0, outputId = 1; // Here I assume input=0, output=1 for the current network
213 | 
214 |     // Infer asynchronously, in a proper CUDA way !
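    // A note on the sizes used below in main() (an observation about the code as written) :
    // the second pair of setStrideNd()/setPaddingNd() calls above targets conv1 again (not conv2),
    // overriding conv1's padding to 1. With a 224x224 input, conv1 then yields 110x110, and conv2
    // (default stride 1, no padding) yields 108x108, hence the 128*108*108 output buffer in main().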
215 | using namespace std; 216 | cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, 217 | stream); 218 | context->enqueueV2(bindings, stream, nullptr); 219 | cudaMemcpyAsync(outputTensor.data(), bindings[outputId], outputTensor.size() * sizeof(float), 220 | cudaMemcpyDeviceToHost, stream); 221 | } 222 | //====================================================================================================================== 223 | int main() { 224 | using namespace std; 225 | using namespace nvinfer1; 226 | 227 | // Parse model, create engine 228 | Logger logger; 229 | logger.log(ILogger::Severity::kINFO, "C++ TensorRT example7 !!! "); 230 | logger.log(ILogger::Severity::kINFO, "Creating engine ..."); 231 | int batchSize = 1; 232 | unique_ptr> engine(createCudaEngine(logger, batchSize)); 233 | 234 | if (!engine) 235 | throw runtime_error("Engine creation failed !"); 236 | 237 | // Optional : Print all bindings : name + dims + dtype 238 | cout << "=============\nBindings :\n"; 239 | int n = engine->getNbBindings(); 240 | for (int i = 0; i < n; ++i) { 241 | Dims d = engine->getBindingDimensions(i); 242 | cout << i << " : " << engine->getBindingName(i) << " : dims=" << printDim(d); 243 | cout << " , dtype=" << (int) engine->getBindingDataType(i) << " "; 244 | cout << (engine->bindingIsInput(i) ? "IN" : "OUT") << endl; 245 | } 246 | cout << "=============\n\n"; 247 | 248 | // Create context 249 | logger.log(ILogger::Severity::kINFO, "Creating context ..."); 250 | unique_ptr> context(engine->createExecutionContext()); 251 | 252 | // Create data structures for the inference 253 | cudaStream_t stream; 254 | cudaStreamCreate(&stream); 255 | vector inputTensor(3*224*224*batchSize, 3.1); 256 | vector outputTensor(128*108*108 * batchSize, -4.9); 257 | for (int iy = 0; iy < 224; ++iy) { 258 | for (int ix = 0; ix < 224; ++ix) { 259 | for (int j = 0; j < 3; ++j) { 260 | inputTensor[iy*224*3 + ix*3 + j] = (ix + iy) % 2; 261 | } 262 | } 263 | } 264 | cout << "input = " << endl; 265 | for (int iy = 0; iy < 10; ++iy) { 266 | for (int ix = 0; ix < 10; ++ix) { 267 | cout << inputTensor[iy*224*3 + ix*3] << " "; 268 | } 269 | cout << endl; 270 | } 271 | 272 | void *bindings[2]{0}; 273 | // Alloc cuda memory for IO tensors 274 | size_t sizes[] = {inputTensor.size(), outputTensor.size()}; 275 | for (int i = 0; i < engine->getNbBindings(); ++i) { 276 | // Create CUDA buffer for Tensor. 277 | cudaMalloc(&bindings[i], sizes[i] * sizeof(float)); 278 | } 279 | 280 | // Run the inference ! 281 | cout << "Running the inference !" 
<< endl;
282 |     launchInference(context.get(), stream, inputTensor, outputTensor, bindings, batchSize);
283 |     cudaStreamSynchronize(stream);
284 | 
285 |     cout << "output = " << endl;
286 |     for (int iy = 0; iy < 8; ++iy) {
287 |         for (int ix = 0; ix < 8; ++ix) {
288 |             cout << outputTensor[iy*108*128 + ix*128] << " ";
289 |         }
290 |         cout << endl;
291 |     }
292 | 
293 |     cudaStreamDestroy(stream);
294 |     cudaFree(bindings[0]);
295 |     cudaFree(bindings[1]);
296 | 
297 |     return 0;
298 | }
299 | 
--------------------------------------------------------------------------------
/example7prof.cpp:
--------------------------------------------------------------------------------
1 | // By Oleksiy Grechnyev, IT-JIM
2 | // Example 7prof : Like example 7, but with profiling
3 | 
4 | #include <iostream>
5 | #include <sstream>
6 | #include <string>
7 | #include <vector>
8 | #include <memory>
9 | #include <stdexcept>
10 | 
11 | #include <NvInfer.h>
12 | 
13 | #include <cuda_runtime.h>
14 | 
15 | constexpr bool USE_INT8 = false;
16 | 
17 | //======================================================================================================================
18 | class GoblinProfiler : public nvinfer1::IProfiler {
19 | public:
20 |     virtual void reportLayerTime(const char *layerName, float ms) {
21 |         // Note : this is a minimal output for demo purposes only!
22 |         // In real life, it's better to record all the data and print it after profiling,
23 |         // like in the NVidia example sampleNMT :
24 |         // cout is a relatively heavy operation which can disturb the timing
25 |         using namespace std;
26 |         cout << "PROF : " << layerName << ":" << ms << endl;
27 |     }
28 | };
29 | //======================================================================================================================
30 | 
31 | class Logger : public nvinfer1::ILogger {
32 | public:
33 |     void log(Severity severity, const char *msg) override {
34 |         using namespace std;
35 |         string s;
36 |         switch (severity) {
37 |             case Severity::kINTERNAL_ERROR:
38 |                 s = "INTERNAL_ERROR";
39 |                 break;
40 |             case Severity::kERROR:
41 |                 s = "ERROR";
42 |                 break;
43 |             case Severity::kWARNING:
44 |                 s = "WARNING";
45 |                 break;
46 |             case Severity::kINFO:
47 |                 s = "INFO";
48 |                 break;
49 |             case Severity::kVERBOSE:
50 |                 s = "VERBOSE";
51 |                 break;
52 |         }
53 |         cerr << s << ": " << msg << endl;
54 |     }
55 | };
56 | //======================================================================================================================
57 | 
58 | /// Using unique_ptr with Destroy is optional, but beats calling destroy() for everything
59 | /// Borrowed from the NVidia tutorial, nice C++ skills !
60 | template 61 | struct Destroy { 62 | void operator()(T *t) const { 63 | t->destroy(); 64 | } 65 | }; 66 | 67 | //====================================================================================================================== 68 | /// Optional : Print dimensions as string 69 | std::string printDim(const nvinfer1::Dims &d) { 70 | using namespace std; 71 | ostringstream oss; 72 | for (int j = 0; j < d.nbDims; ++j) { 73 | oss << d.d[j]; 74 | if (j < d.nbDims - 1) 75 | oss << "x"; 76 | } 77 | return oss.str(); 78 | } 79 | 80 | //====================================================================================================================== 81 | /// Optional : Print layers of the network 82 | void printNetwork(const nvinfer1::INetworkDefinition &net) { 83 | using namespace std; 84 | using namespace nvinfer1; 85 | cout << "\n=============\nNetwork info :" << endl; 86 | 87 | cout << "\nInputs : " << endl; 88 | for (int i = 0; i < net.getNbInputs(); ++i) { 89 | ITensor *inp = net.getInput(i); 90 | cout << "input" << i << " , dtype=" << (int) inp->getType() << " , dims=" << printDim(inp->getDimensions()) 91 | << endl; 92 | } 93 | 94 | cout << "\nLayers : " << endl; 95 | cout << "Number of layers : " << net.getNbLayers() << endl; 96 | for (int i = 0; i < net.getNbLayers(); ++i) { 97 | ILayer *l = net.getLayer(i); 98 | cout << "layer" << i << " , name=" << l->getName() << " , type=" << (int) l->getType() << " , IN "; 99 | for (int j = 0; j < l->getNbInputs(); ++j) 100 | cout << "(" << int(l->getInput(j)->getType()) << ") " << printDim(l->getInput(j)->getDimensions()) << " "; 101 | cout << ", OUT "; 102 | for (int j = 0; j < l->getNbOutputs(); ++j) 103 | cout << "(" << int(l->getOutput(j)->getType()) << ") " << printDim(l->getOutput(j)->getDimensions()) << " "; 104 | cout << endl; 105 | switch (l->getType()) { 106 | case (LayerType::kCONVOLUTION): { 107 | IConvolutionLayer *lc = static_cast(l); 108 | cout << "CONVOLUTION : "; 109 | cout << "ker=" << printDim(lc->getKernelSizeNd()); 110 | cout << ", stride=" << printDim(lc->getStrideNd()); 111 | cout << ", padding=" << printDim(lc->getPaddingNd()); 112 | cout << ", groups=" << lc->getNbGroups(); 113 | Weights w = lc->getKernelWeights(); 114 | cout << ", weights=" << w.count << ":" << int(w.type); 115 | cout << endl; 116 | } 117 | break; 118 | case (LayerType::kSCALE): { 119 | IScaleLayer *ls = static_cast(l); 120 | cout << "SCALE: "; 121 | cout << "mode=" << int(ls->getMode()); 122 | Weights ws = ls->getScale(); 123 | cout << ", scale=" << ws.count << ":" << int(ws.type); 124 | Weights wp = ls->getPower(); 125 | cout << ", power=" << wp.count << ":" << int(wp.type); 126 | Weights wf = ls->getShift(); 127 | cout << ", shift=" << wf.count << ":" << int(wf.type); 128 | cout << endl; 129 | } 130 | break; 131 | } 132 | } 133 | 134 | cout << "\nOutputs : " << endl; 135 | for (int i = 0; i < net.getNbOutputs(); ++i) { 136 | ITensor *outp = net.getOutput(i); 137 | cout << "input" << i << " , dtype=" << (int) outp->getType() << " , dims=" << printDim(outp->getDimensions()) 138 | << endl; 139 | } 140 | 141 | cout << "=============\n" << endl; 142 | } 143 | //====================================================================================================================== 144 | 145 | /// Create model and create a TRT engine 146 | nvinfer1::ICudaEngine *createCudaEngine(nvinfer1::ILogger &logger, int batchSize) { 147 | using namespace std; 148 | using namespace nvinfer1; 149 | 150 | unique_ptr> builder{createInferBuilder(logger)}; 151 | 
unique_ptr> network{ 152 | builder->createNetworkV2(1U << (unsigned) NetworkDefinitionCreationFlag::kEXPLICIT_BATCH) 153 | }; 154 | 155 | // Network 156 | ITensor *input = network->addInput("goblin_input", DataType::kFLOAT, Dims4(1, 3, 224, 224)); 157 | 158 | // conv1 159 | vector wwC1(7 * 7 * 3 * 64, 0.0123); // 3x3 box filter 160 | vector bbC1(64, 0.5); 161 | Weights wC1{DataType::kFLOAT, wwC1.data(), (int64_t) wwC1.size()}; 162 | Weights bC1{DataType::kFLOAT, bbC1.data(), (int64_t) bbC1.size()}; 163 | IConvolutionLayer *conv1 = network->addConvolutionNd(*input, 64, Dims2(7, 7), wC1, bC1); 164 | conv1->setStrideNd(Dims2(2, 2)); 165 | conv1->setPaddingNd(Dims2(3, 3)); 166 | // relu1 167 | IActivationLayer *relu1 = network->addActivation(*conv1->getOutput(0), ActivationType::kRELU); 168 | 169 | // conv2 170 | vector wwC2(3 * 3 * 64 * 128, 0.01231); // 3x3 box filter 171 | vector bbC2(128, 0.4); 172 | Weights wC2{DataType::kFLOAT, wwC2.data(), (int64_t) wwC2.size()}; 173 | Weights bC2{DataType::kFLOAT, bbC2.data(), (int64_t) bbC2.size()}; 174 | IConvolutionLayer *conv2 = network->addConvolutionNd(*relu1->getOutput(0), 128, Dims2(3, 3), wC2, bC2); 175 | conv1->setStrideNd(Dims2(2, 2)); 176 | conv1->setPaddingNd(Dims2(1, 1)); 177 | // relu2 178 | IActivationLayer *relu2 = network->addActivation(*conv2->getOutput(0), ActivationType::kRELU); 179 | 180 | ITensor *output = relu2->getOutput(0); 181 | output->setName("goblin_output"); 182 | network->markOutput(*output); 183 | 184 | printNetwork(*network); 185 | 186 | // Are fancy types available ? 187 | cout << "platformHasFastFp16 = " << builder->platformHasFastFp16() << endl; 188 | cout << "platformHasFastInt8 = " << builder->platformHasFastInt8() << endl; 189 | 190 | // Set up the config 191 | unique_ptr> config(builder->createBuilderConfig()); 192 | // This is needed for TensorRT 6, not needed by 7 ! 193 | config->setMaxWorkspaceSize(1024 * 1024 * 1024); 194 | 195 | if (USE_INT8) { 196 | // Int8 quantization with the explicit range 197 | config->setFlag(BuilderFlag::kINT8); 198 | config->setFlag(BuilderFlag::kSTRICT_TYPES); 199 | 200 | // Set the dynamic range for all layers and input 201 | float minRange = -17., maxRange = 17.; 202 | cout << "layers = " << network->getNbLayers() << endl; 203 | for (int i = 0; i < network->getNbLayers(); ++i) { 204 | ILayer *layer = network->getLayer(i); 205 | ITensor *tensor = layer->getOutput(0); 206 | tensor->setDynamicRange(minRange, maxRange); 207 | layer->setPrecision(DataType::kINT8); 208 | layer->setOutputType(0, DataType::kINT8); 209 | } 210 | network->getInput(0)->setDynamicRange(minRange, maxRange); 211 | } 212 | 213 | // Build engine 214 | return builder->buildEngineWithConfig(*network, *config); 215 | } 216 | //====================================================================================================================== 217 | /// Run a single inference 218 | void launchInference(nvinfer1::IExecutionContext *context, cudaStream_t stream, std::vector const &inputTensor, 219 | std::vector &outputTensor, void **bindings, int batchSize) { 220 | 221 | int inputId = 0, outputId = 1; // Here I assume input=0, output=1 for the current network 222 | 223 | // Infer asynchronously, in a proper cuda way ! 
224 | using namespace std; 225 | cudaMemcpyAsync(bindings[inputId], inputTensor.data(), inputTensor.size() * sizeof(float), cudaMemcpyHostToDevice, 226 | stream); 227 | context->enqueueV2(bindings, stream, nullptr); 228 | cudaMemcpyAsync(outputTensor.data(), bindings[outputId], outputTensor.size() * sizeof(float), 229 | cudaMemcpyDeviceToHost, stream); 230 | } 231 | //====================================================================================================================== 232 | int main() { 233 | using namespace std; 234 | using namespace nvinfer1; 235 | 236 | // Parse model, create engine 237 | Logger logger; 238 | logger.log(ILogger::Severity::kINFO, "C++ TensorRT example7 !!! "); 239 | logger.log(ILogger::Severity::kINFO, "Creating engine ..."); 240 | int batchSize = 1; 241 | unique_ptr> engine(createCudaEngine(logger, batchSize)); 242 | 243 | if (!engine) 244 | throw runtime_error("Engine creation failed !"); 245 | 246 | // Optional : Print all bindings : name + dims + dtype 247 | cout << "=============\nBindings :\n"; 248 | int n = engine->getNbBindings(); 249 | for (int i = 0; i < n; ++i) { 250 | Dims d = engine->getBindingDimensions(i); 251 | cout << i << " : " << engine->getBindingName(i) << " : dims=" << printDim(d); 252 | cout << " , dtype=" << (int) engine->getBindingDataType(i) << " "; 253 | cout << (engine->bindingIsInput(i) ? "IN" : "OUT") << endl; 254 | } 255 | cout << "=============\n\n"; 256 | 257 | // Create context 258 | logger.log(ILogger::Severity::kINFO, "Creating context ..."); 259 | unique_ptr> context(engine->createExecutionContext()); 260 | 261 | // Add profiler to context 262 | GoblinProfiler profiler; 263 | context->setProfiler(&profiler); 264 | 265 | // Create data structures for the inference 266 | cudaStream_t stream; 267 | cudaStreamCreate(&stream); 268 | vector inputTensor(3*224*224*batchSize, 3.1); 269 | vector outputTensor(128*108*108 * batchSize, -4.9); 270 | for (int iy = 0; iy < 224; ++iy) { 271 | for (int ix = 0; ix < 224; ++ix) { 272 | for (int j = 0; j < 3; ++j) { 273 | inputTensor[iy*224*3 + ix*3 + j] = (ix + iy) % 2; 274 | } 275 | } 276 | } 277 | cout << "input = " << endl; 278 | for (int iy = 0; iy < 10; ++iy) { 279 | for (int ix = 0; ix < 10; ++ix) { 280 | cout << inputTensor[iy*224*3 + ix*3] << " "; 281 | } 282 | cout << endl; 283 | } 284 | 285 | void *bindings[2]{0}; 286 | // Alloc cuda memory for IO tensors 287 | size_t sizes[] = {inputTensor.size(), outputTensor.size()}; 288 | for (int i = 0; i < engine->getNbBindings(); ++i) { 289 | // Create CUDA buffer for Tensor. 290 | cudaMalloc(&bindings[i], sizes[i] * sizeof(float)); 291 | } 292 | 293 | // Run the inference ! 294 | cout << "Running the inference !" << endl; 295 | launchInference(context.get(), stream, inputTensor, outputTensor, bindings, batchSize); 296 | cudaStreamSynchronize(stream); 297 | 298 | cout << "output = " << endl; 299 | for (int iy = 0; iy < 8; ++iy) { 300 | for (int ix = 0; ix < 8; ++ix) { 301 | cout << outputTensor[iy*108*128 + ix*128] << " "; 302 | } 303 | cout << endl; 304 | } 305 | 306 | cudaStreamDestroy(stream); 307 | cudaFree(bindings[0]); 308 | cudaFree(bindings[1]); 309 | 310 | return 0; 311 | } 312 | -------------------------------------------------------------------------------- /gen_models.py: -------------------------------------------------------------------------------- 1 | # By Olekisy Grechnyev, IT-JIM on 3/30/20. 
2 | # Generate a trivial model in PyTorch
3 | # Convert to ONNX
4 | 
5 | import torch
6 | import numpy as np
7 | 
8 | def main():
9 |     """Create a very simple ONNX model"""
10 |     model = torch.nn.Linear(3, 2)
11 | 
12 |     w, b = model.state_dict()['weight'], model.state_dict()['bias']
13 |     # w, b = model.weight, model.bias
14 |     with torch.no_grad():
15 |         w.copy_(torch.tensor([[1., 2., 3.], [4., 5., 6.]]))
16 |         b.copy_(torch.tensor([-1., -2.]))
17 |     model.cuda()
18 |     print('w = ', w, w.dtype)
19 |     print('b = ', b, b.dtype)
20 | 
21 |     # Check the standard result, should be [1.5, 3.5]
22 |     x = torch.tensor([0.5, -0.5, 1.0], device='cuda')
23 |     y = model(x)
24 |     print('x =', x)
25 |     print('y =', y)
26 | 
27 |     # Export to ONNX
28 |     model.eval()
29 |     x = torch.randn(1, 3, requires_grad=True, device='cuda')
30 | 
31 |     # Export model1.onnx with batch_size=1
32 |     print('\nExporting model1.onnx ...')
33 |     torch.onnx.export(model,
34 |                       x,
35 |                       'model1.onnx',
36 |                       opset_version=9,
37 |                       verbose=True,
38 |                       export_params=True,
39 |                       input_names=['input'],
40 |                       output_names=['output'],
41 |                       # dynamic_axes={'input': {0: 'batch_size'},  # variable length axes
42 |                       #               'output': {0: 'batch_size'}}
43 |                       )
44 | 
45 |     # Export model2.onnx with dynamic batch_size
46 |     print('\nExporting model2.onnx ...')
47 |     torch.onnx.export(model,
48 |                       x,
49 |                       'model2.onnx',
50 |                       # opset_version=11,
51 |                       verbose=True,
52 |                       export_params=True,
53 |                       input_names=['input'],
54 |                       output_names=['output'],
55 |                       dynamic_axes={'input': {0: 'batch_size'},  # variable length axes
56 |                                     'output': {0: 'batch_size'}}
57 |                       )
58 | 
59 | 
60 | if __name__ == '__main__':
61 |     main()
62 | 
--------------------------------------------------------------------------------
/model1.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/agrechnev/trt-cpp-min/d52d90cfe101f4e426c212bfc813aa77c30d74ec/model1.onnx
--------------------------------------------------------------------------------
/model2.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/agrechnev/trt-cpp-min/d52d90cfe101f4e426c212bfc813aa77c30d74ec/model2.onnx
--------------------------------------------------------------------------------
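
For a quick sanity check of example6, the following stand-alone CPU sketch (not part of the repository or its CMakeLists.txt; it only reuses the 10x10 checkerboard input, the 3x3 box filter with weight 1/9 and zero bias, and the valid-convolution 8x8 output from example6.cpp) computes the reference result. It prints an 8x8 grid alternating between 4/9 = 0.444 and 5/9 = 0.556, which is what the TensorRT engine should return up to the INT8 quantization error implied by the [-17, 17] dynamic range.

// cpu_check.cpp (hypothetical file name, not in the repo) : CPU reference for example6
#include <iostream>
#include <vector>

int main() {
    const int H = 10, W = 10, K = 3;            // input and kernel sizes from example6.cpp
    std::vector<float> in(H * W);
    for (int iy = 0; iy < H; ++iy)
        for (int ix = 0; ix < W; ++ix)
            in[iy * W + ix] = (ix + iy) % 2;    // the same checkerboard input as example6

    const int OH = H - K + 1, OW = W - K + 1;   // valid convolution : 8x8 output
    for (int oy = 0; oy < OH; ++oy) {
        for (int ox = 0; ox < OW; ++ox) {
            float acc = 0.f;                    // 3x3 box filter (weights 1/9), zero bias
            for (int ky = 0; ky < K; ++ky)
                for (int kx = 0; kx < K; ++kx)
                    acc += in[(oy + ky) * W + (ox + kx)] / 9.0f;
            std::cout << acc << " ";
        }
        std::cout << std::endl;
    }
    return 0;
}

Built as a plain C++ program (no CUDA or TensorRT needed), its output can be compared directly with the 8x8 grid that example6 prints after inference.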