├── CMakeLists.txt
├── README.md
├── builder.py
├── compile.sh
├── flattenConcatCustom.cpp
├── flattenConcatCustom.h
└── load_trt_engine.cpp

/CMakeLists.txt:
--------------------------------------------------------------------------------
# Builds libflatten_concat.so, the shared library that carries the
# FlattenConcatCustom TensorRT plugin.
cmake_minimum_required(VERSION 3.14)
# A project() call is required so that the C++ toolchain is configured
# explicitly instead of through CMake's implicit fallback.
project(flatten_concat LANGUAGES CXX)

# Install locations. They are cache entries so they can be overridden with
# -DTRT_HOME=... / -DCUDA_LIB=... instead of editing this file.
set(TRT_HOME "path-to-tensorrt/TensorRT-6.0.1.5" CACHE PATH "TensorRT installation root")
set(CUDA_LIB "path-to-cuda-9.0-lib" CACHE PATH "Directory holding the CUDA 9.0 static libraries")

# Request C++11 through CMake's standard machinery (emits -std=c++11) rather
# than injecting the raw flag into every target.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

# The plugin consists of a single translation unit, so list it explicitly.
set(SRC_FILES "${CMAKE_CURRENT_SOURCE_DIR}/flattenConcatCustom.cpp")
add_library(flatten_concat SHARED ${SRC_FILES})

# Header search paths: this directory (flattenConcatCustom.h) and TensorRT.
target_include_directories(flatten_concat PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})
target_include_directories(flatten_concat PUBLIC ${TRT_HOME}/include)
target_include_directories(flatten_concat PUBLIC ${TRT_HOME})
target_include_directories(flatten_concat PUBLIC ${TRT_HOME}/cuda/include)

# System libraries, then the static TensorRT / CUDA archives.
target_link_libraries(flatten_concat -ldl -lpthread -lrt)
target_link_libraries(flatten_concat ${TRT_HOME}/lib/libnvparsers_static.a)
target_link_libraries(flatten_concat ${TRT_HOME}/lib/libnvinfer_static.a)
target_link_libraries(flatten_concat ${TRT_HOME}/lib/libnvinfer_plugin_static.a)
target_link_libraries(flatten_concat path-to-conda/anaconda3/lib/libstdc++.so)
target_link_libraries(flatten_concat ${CUDA_LIB}/libcudnn_v9.a ${CUDA_LIB}/libcublas_v9.a)
target_link_libraries(flatten_concat ${CUDA_LIB}/libculibos_v9.a)
target_link_libraries(flatten_concat ${CUDA_LIB}/libcudart_v9.a)


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TensorRT-Custom-Plugin
This repository describes:

(1) how to add a custom TensorRT plugin in c++,

(2) how to build and serialize network with the custom plugin in python

(3) how to load and forward the network in c++.

## Add custom TensorRT plugin in c++
We follow the [flattenConcat plugin](https://github.com/NVIDIA/TensorRT/tree/release/6.0/plugin/flattenConcat) that ships with TensorRT to create our own flattenConcat plugin.

Since a flattenConcat plugin is already registered in TensorRT, we renamed the class to `FlattenConcatCustom` to avoid a name clash.
The corresponding source code is in `flattenConcatCustom.cpp` and `flattenConcatCustom.h`.
We use [CMakeLists.txt](https://github.com/YirongMao/TensorRT-Custom-Plugin/blob/master/CMakeLists.txt) to build the shared library `libflatten_concat.so`.


## Build network and serialize engine in python
Please follow [builder.py](https://github.com/YirongMao/TensorRT-Custom-Plugin/blob/master/builder.py).


```python
# You should configure the path to libnvinfer_plugin.so
nvinfer = ctypes.CDLL("/path-to-tensorrt/TensorRT-6.0.1.5/lib/libnvinfer_plugin.so", mode=ctypes.RTLD_GLOBAL)
print('load nvinfer')
pg = ctypes.CDLL("./libflatten_concat.so", mode=ctypes.RTLD_GLOBAL)
print('load customed plugin')

#TensorRT Initialization
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(TRT_LOGGER, "")
plg_registry = trt.get_plugin_registry()
# to call the constructor@https://github.com/YirongMao/TensorRT-Custom-Plugin/blob/master/flattenConcatCustom.cpp#L36
plg_creator = plg_registry.get_plugin_creator("FlattenConcatCustom", "1", "")
print(plg_creator)

axis_pf = trt.PluginField("axis", np.array([1], np.int32), trt.PluginFieldType.INT32)
batch_pf = trt.PluginField("ignoreBatch", np.array([0], np.int32), trt.PluginFieldType.INT32)

pfc = trt.PluginFieldCollection([axis_pf, batch_pf])
fn = plg_creator.create_plugin("FlattenConcatCustom1", pfc)
print(fn)

network = builder.create_network()
input_1 = network.add_input(name="input_1", dtype=trt.float32, shape=(4, 2, 2))
input_2 = network.add_input(name="input_2", dtype=trt.float32, shape=(2, 2, 2))
inputs = [input_1, input_2] 48 | # to call configurePlugin@https://github.com/YirongMao/TensorRT-Custom-Plugin/blob/master/flattenConcatCustom.cpp#L258 49 | emb_layer = network.add_plugin_v2(inputs, fn) 50 | ``` 51 | 52 | ## Load network in c++ 53 | Please follow load_trt_engine.cpp. To load the engine with custom plugin, its header *.h file should be included. 54 | 55 | ## env. requirements 56 | TensorRT-6.0.1.5, cuda-9.0 57 | 58 | ## Contacts 59 | If you encounter any problem, be free to create an issue. 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /builder.py: -------------------------------------------------------------------------------- 1 | import tensorrt as trt 2 | import ctypes 3 | import numpy as np 4 | import torch 5 | 6 | nvinfer = ctypes.CDLL("/path-to-tensorrt/TensorRT-6.0.1.5/lib/libnvinfer_plugin.so", mode=ctypes.RTLD_GLOBAL) 7 | print('load nvinfer') 8 | pg = ctypes.CDLL("./libflatten_concat.so", mode=ctypes.RTLD_GLOBAL) 9 | print('load customed plugin') 10 | 11 | #TensorRT Initialization 12 | TRT_LOGGER = trt.Logger(trt.Logger.INFO) 13 | trt.init_libnvinfer_plugins(TRT_LOGGER, "") 14 | plg_registry = trt.get_plugin_registry() 15 | plg_creator = plg_registry.get_plugin_creator("FlattenConcatCustom", "1", "") 16 | print(plg_creator) 17 | 18 | axis_pf = trt.PluginField("axis", np.array([1], np.int32), trt.PluginFieldType.INT32) 19 | batch_pf = trt.PluginField("ignoreBatch", np.array([0], np.int32), trt.PluginFieldType.INT32) 20 | 21 | pfc = trt.PluginFieldCollection([axis_pf, batch_pf]) 22 | fn = plg_creator.create_plugin("FlattenConcatCustom1", pfc) 23 | print(fn) 24 | 25 | builder = trt.Builder(TRT_LOGGER) 26 | builder.max_batch_size = 10 27 | builder.max_workspace_size = 5000 * (1024 * 1024) 28 | builder.strict_type_constraints = False 29 | network = builder.create_network() 30 | 31 | input_1 = network.add_input(name="input_1", dtype=trt.float32, shape=(4, 2, 2)) 32 | input_2 = 
network.add_input(name="input_2", dtype=trt.float32, shape=(2, 2, 2)) 33 | 34 | inputs = [input_1, input_2] 35 | emb_layer = network.add_plugin_v2(inputs, fn) 36 | print(emb_layer) 37 | embeddings = emb_layer.get_output(0) 38 | network.mark_output(embeddings) 39 | embeddings_shape = embeddings.shape 40 | engine = builder.build_cuda_engine(network) 41 | serialized_engine = engine.serialize() 42 | # TRT_LOGGER.log(TRT_LOGGER.INFO, "Saving the engine....") 43 | with open('flattenconcat.engine', 'wb') as fout: 44 | fout.write(serialized_engine) 45 | print('engine serialized') 46 | -------------------------------------------------------------------------------- /compile.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | rm -rf load_trt_engine 4 | g++ load_trt_engine.cpp flattenConcatCustom.cpp -o load_trt_engine\ 5 | -I path-to-tensorrt/TensorRT-6.0.1.5/include \ 6 | -I path-to-tensorrt/TensorRT-6.0.1.5 \ 7 | -I path-to-cuda-9.0/cuda/include \ 8 | -std=c++11 -ldl -lpthread -lrt\ 9 | -Lpath-to-tensorrt/TensorRT-6.0.1.5/lib \ 10 | -lnvparsers_static -lnvinfer_static -lnvinfer_plugin_static \ 11 | -Lpath-to-conda/anaconda3/lib \ 12 | -lstdc++ \ 13 | -Lpath-to-cuda-0.9/cuda/lib -lcudnn_v9 -lcublas_v9 -lculibos_v9 -lcudart_v9 14 | -------------------------------------------------------------------------------- /flattenConcatCustom.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 
6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | #include "flattenConcatCustom.h" 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include 24 | #include 25 | 26 | using namespace nvinfer1; 27 | using nvinfer1::plugin::FlattenConcatCustom; 28 | using nvinfer1::plugin::FlattenConcatCustomPluginCreator; 29 | 30 | static const char* FlattenConcatCustom_PLUGIN_VERSION{"1"}; 31 | static const char* FlattenConcatCustom_PLUGIN_NAME{"FlattenConcatCustom"}; 32 | 33 | PluginFieldCollection FlattenConcatCustomPluginCreator::mFC{}; 34 | std::vector FlattenConcatCustomPluginCreator::mPluginAttributes; 35 | 36 | FlattenConcatCustom::FlattenConcatCustom(int concatAxis, bool ignoreBatch) 37 | : mIgnoreBatch(ignoreBatch) 38 | , mConcatAxisID(concatAxis) 39 | { 40 | ASSERT(mConcatAxisID == 1 || mConcatAxisID == 2 || mConcatAxisID == 3); 41 | } 42 | 43 | FlattenConcatCustom::FlattenConcatCustom( 44 | int concatAxis, bool ignoreBatch, int numInputs, int outputConcatAxis, const int* inputConcatAxis, size_t* copySize) 45 | : mIgnoreBatch(ignoreBatch) 46 | , mConcatAxisID(concatAxis) 47 | , mOutputConcatAxis(outputConcatAxis) 48 | , mNumInputs(numInputs) 49 | { 50 | ASSERT(mConcatAxisID == 1 || mConcatAxisID == 2 || mConcatAxisID == 3); 51 | 52 | // Allocate memory for mInputConcatAxis, mCopySize members 53 | LOG_ERROR(cudaMallocHost((void**) &mInputConcatAxis, mNumInputs * sizeof(int))); 54 | LOG_ERROR(cudaMallocHost((void**) &mCopySize, mNumInputs * sizeof(size_t))); 55 | 56 | // Perform deep copy 57 | if (copySize != nullptr) 58 | { 59 | for 
(int i = 0; i < mNumInputs; i++) 60 | { 61 | mCopySize[i] = static_cast(copySize[i]); 62 | } 63 | } 64 | 65 | for (int i = 0; i < mNumInputs; ++i) 66 | { 67 | mInputConcatAxis[i] = inputConcatAxis[i]; 68 | } 69 | 70 | // Create cublas context 71 | LOG_ERROR(cublasCreate(&mCublas)); 72 | } 73 | 74 | FlattenConcatCustom::FlattenConcatCustom(const void* data, size_t length) 75 | { 76 | const char *d = reinterpret_cast(data), *a = d; 77 | mIgnoreBatch = read(d); 78 | mConcatAxisID = read(d); 79 | ASSERT(mConcatAxisID == 1 || mConcatAxisID == 2 || mConcatAxisID == 3); 80 | mOutputConcatAxis = read(d); 81 | mNumInputs = read(d); 82 | LOG_ERROR(cudaMallocHost((void**) &mInputConcatAxis, mNumInputs * sizeof(int))); 83 | LOG_ERROR(cudaMallocHost((void**) &mCopySize, mNumInputs * sizeof(int))); 84 | 85 | std::for_each(mInputConcatAxis, mInputConcatAxis + mNumInputs, [&](int& inp) { inp = read(d); }); 86 | 87 | mCHW = read(d); 88 | 89 | std::for_each(mCopySize, mCopySize + mNumInputs, [&](size_t& inp) { inp = read(d); }); 90 | 91 | ASSERT(d == a + length); 92 | } 93 | 94 | FlattenConcatCustom::~FlattenConcatCustom() 95 | { 96 | if (mInputConcatAxis) 97 | { 98 | LOG_ERROR(cudaFreeHost(mInputConcatAxis)); 99 | } 100 | if (mCopySize) 101 | { 102 | LOG_ERROR(cudaFreeHost(mCopySize)); 103 | } 104 | } 105 | 106 | int FlattenConcatCustom::getNbOutputs() const 107 | { 108 | return 1; 109 | } 110 | 111 | Dims FlattenConcatCustom::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) 112 | { 113 | ASSERT(nbInputDims >= 1); 114 | ASSERT(index == 0); 115 | 116 | mNumInputs = nbInputDims; 117 | LOG_ERROR(cudaMallocHost((void**) &mInputConcatAxis, nbInputDims * sizeof(int))); 118 | int outputConcatAxis = 0; 119 | 120 | for (int i = 0; i < nbInputDims; ++i) 121 | { 122 | int flattenInput = 0; 123 | ASSERT(inputs[i].nbDims == 3); 124 | if (mConcatAxisID != 1) 125 | { 126 | ASSERT(inputs[i].d[0] == inputs[0].d[0]); 127 | } 128 | if (mConcatAxisID != 2) 129 | { 130 | 
ASSERT(inputs[i].d[1] == inputs[0].d[1]); 131 | } 132 | if (mConcatAxisID != 3) 133 | { 134 | ASSERT(inputs[i].d[2] == inputs[0].d[2]); 135 | } 136 | flattenInput = inputs[i].d[0] * inputs[i].d[1] * inputs[i].d[2]; 137 | outputConcatAxis += flattenInput; 138 | } 139 | 140 | return DimsCHW(mConcatAxisID == 1 ? outputConcatAxis : 1, mConcatAxisID == 2 ? outputConcatAxis : 1, 141 | mConcatAxisID == 3 ? outputConcatAxis : 1); 142 | } 143 | 144 | int FlattenConcatCustom::initialize() 145 | { 146 | return STATUS_SUCCESS; 147 | } 148 | 149 | void FlattenConcatCustom::terminate() 150 | { 151 | LOG_ERROR(cublasDestroy(mCublas)); 152 | } 153 | 154 | size_t FlattenConcatCustom::getWorkspaceSize(int) const 155 | { 156 | return 0; 157 | } 158 | 159 | int FlattenConcatCustom::enqueue(int batchSize, const void* const* inputs, void** outputs, void*, cudaStream_t stream) 160 | { 161 | int numConcats = 1; 162 | ASSERT(mConcatAxisID != 0); 163 | // mCHW is the first input tensor 164 | numConcats = std::accumulate(mCHW.d, mCHW.d + mConcatAxisID - 1, 1, std::multiplies()); 165 | 166 | LOG_ERROR(cublasSetStream(mCublas, stream)); 167 | 168 | // Num concats will be proportional to number of samples in a batch 169 | if (!mIgnoreBatch) 170 | { 171 | numConcats *= batchSize; 172 | } 173 | 174 | auto* output = reinterpret_cast(outputs[0]); 175 | int offset = 0; 176 | for (int i = 0; i < mNumInputs; ++i) 177 | { 178 | const auto* input = reinterpret_cast(inputs[i]); 179 | float* inputTemp; 180 | LOG_ERROR(cudaMalloc(&inputTemp, mCopySize[i] * batchSize)); 181 | LOG_ERROR(cudaMemcpyAsync(inputTemp, input, mCopySize[i] * batchSize, cudaMemcpyDeviceToDevice, stream)); 182 | 183 | for (int n = 0; n < numConcats; ++n) 184 | { 185 | LOG_ERROR(cublasScopy(mCublas, mInputConcatAxis[i], inputTemp + n * mInputConcatAxis[i], 1, 186 | output + (n * mOutputConcatAxis + offset), 1)); 187 | } 188 | LOG_ERROR(cudaFree(inputTemp)); 189 | offset += mInputConcatAxis[i]; 190 | } 191 | 192 | return 0; 193 | } 194 
| 195 | size_t FlattenConcatCustom::getSerializationSize() const 196 | { 197 | return sizeof(bool) + sizeof(int) * (3 + mNumInputs) + sizeof(nvinfer1::Dims) + (sizeof(mCopySize) * mNumInputs); 198 | } 199 | 200 | void FlattenConcatCustom::serialize(void* buffer) const 201 | { 202 | char *d = reinterpret_cast(buffer), *a = d; 203 | write(d, mIgnoreBatch); 204 | write(d, mConcatAxisID); 205 | write(d, mOutputConcatAxis); 206 | write(d, mNumInputs); 207 | for (int i = 0; i < mNumInputs; ++i) 208 | { 209 | write(d, mInputConcatAxis[i]); 210 | } 211 | write(d, mCHW); 212 | for (int i = 0; i < mNumInputs; ++i) 213 | { 214 | write(d, mCopySize[i]); 215 | } 216 | ASSERT(d == a + getSerializationSize()); 217 | } 218 | 219 | // Attach the plugin object to an execution context and grant the plugin the access to some context resource. 220 | void FlattenConcatCustom::attachToContext( 221 | cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) 222 | { 223 | } 224 | 225 | // Detach the plugin object from its execution context. 226 | void FlattenConcatCustom::detachFromContext() {} 227 | 228 | // Return true if output tensor is broadcast across a batch. 229 | bool FlattenConcatCustom::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const 230 | { 231 | return false; 232 | } 233 | 234 | // Return true if plugin can use input that is broadcast across batch without replication. 
235 | bool FlattenConcatCustom::canBroadcastInputAcrossBatch(int inputIndex) const 236 | { 237 | return false; 238 | } 239 | 240 | // Set plugin namespace 241 | void FlattenConcatCustom::setPluginNamespace(const char* pluginNamespace) 242 | { 243 | mPluginNamespace = pluginNamespace; 244 | } 245 | 246 | const char* FlattenConcatCustom::getPluginNamespace() const 247 | { 248 | return mPluginNamespace; 249 | } 250 | 251 | // Return the DataType of the plugin output at the requested index 252 | DataType FlattenConcatCustom::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const 253 | { 254 | ASSERT(index < 3); 255 | return DataType::kFLOAT; 256 | } 257 | 258 | void FlattenConcatCustom::configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, 259 | const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast, 260 | const bool* outputIsBroadcast, PluginFormat floatFormat, int maxBatchSize) 261 | { 262 | ASSERT(nbOutputs == 1); 263 | mCHW = inputDims[0]; 264 | mNumInputs = nbInputs; 265 | ASSERT(inputDims[0].nbDims == 3); 266 | 267 | if (mInputConcatAxis == nullptr) 268 | { 269 | LOG_ERROR(cudaMallocHost((void**) &mInputConcatAxis, mNumInputs * sizeof(int))); 270 | } 271 | 272 | for (int i = 0; i < nbInputs; ++i) 273 | { 274 | int flattenInput = 0; 275 | ASSERT(inputDims[i].nbDims == 3); 276 | if (mConcatAxisID != 1) 277 | { 278 | ASSERT(inputDims[i].d[0] == inputDims[0].d[0]); 279 | } 280 | if (mConcatAxisID != 2) 281 | { 282 | ASSERT(inputDims[i].d[1] == inputDims[0].d[1]); 283 | } 284 | if (mConcatAxisID != 3) 285 | { 286 | ASSERT(inputDims[i].d[2] == inputDims[0].d[2]); 287 | } 288 | flattenInput = inputDims[i].d[0] * inputDims[i].d[1] * inputDims[i].d[2]; 289 | mInputConcatAxis[i] = flattenInput; 290 | mOutputConcatAxis += mInputConcatAxis[i]; 291 | } 292 | 293 | for (int i = 0; i < nbInputs; ++i) 294 | { 295 | mCopySize[i] = inputDims[i].d[0] * inputDims[i].d[1] * 
inputDims[i].d[2] * sizeof(float); 296 | } 297 | } 298 | 299 | bool FlattenConcatCustom::supportsFormat(DataType type, PluginFormat format) const 300 | { 301 | return (type == DataType::kFLOAT && format == PluginFormat::kNCHW); 302 | } 303 | const char* FlattenConcatCustom::getPluginType() const 304 | { 305 | return "FlattenConcatCustom"; 306 | } 307 | 308 | const char* FlattenConcatCustom::getPluginVersion() const 309 | { 310 | return "1"; 311 | } 312 | 313 | void FlattenConcatCustom::destroy() 314 | { 315 | delete this; 316 | } 317 | 318 | IPluginV2Ext* FlattenConcatCustom::clone() const 319 | { 320 | auto* plugin 321 | = new FlattenConcatCustom(mConcatAxisID, mIgnoreBatch, mNumInputs, mOutputConcatAxis, mInputConcatAxis, mCopySize); 322 | plugin->setPluginNamespace(mPluginNamespace); 323 | return plugin; 324 | } 325 | 326 | FlattenConcatCustomPluginCreator::FlattenConcatCustomPluginCreator() 327 | { 328 | mPluginAttributes.emplace_back(PluginField("axis", nullptr, PluginFieldType::kINT32, 1)); 329 | mPluginAttributes.emplace_back(PluginField("ignoreBatch", nullptr, PluginFieldType::kINT32, 1)); 330 | 331 | mFC.nbFields = mPluginAttributes.size(); 332 | mFC.fields = mPluginAttributes.data(); 333 | } 334 | 335 | const char* FlattenConcatCustomPluginCreator::getPluginName() const 336 | { 337 | return FlattenConcatCustom_PLUGIN_NAME; 338 | } 339 | 340 | const char* FlattenConcatCustomPluginCreator::getPluginVersion() const 341 | { 342 | return FlattenConcatCustom_PLUGIN_VERSION; 343 | } 344 | 345 | const PluginFieldCollection* FlattenConcatCustomPluginCreator::getFieldNames() 346 | { 347 | return &mFC; 348 | } 349 | 350 | IPluginV2Ext* FlattenConcatCustomPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) 351 | { 352 | const PluginField* fields = fc->fields; 353 | for (int i = 0; i < fc->nbFields; ++i) 354 | { 355 | const char* attrName = fields[i].name; 356 | if (!strcmp(attrName, "axis")) 357 | { 358 | ASSERT(fields[i].type == 
PluginFieldType::kINT32); 359 | mConcatAxisID = *(static_cast(fields[i].data)); 360 | } 361 | if (!strcmp(attrName, "ignoreBatch")) 362 | { 363 | ASSERT(fields[i].type == PluginFieldType::kINT32); 364 | mIgnoreBatch = *(static_cast(fields[i].data)); 365 | } 366 | } 367 | 368 | auto* plugin = new FlattenConcatCustom(mConcatAxisID, mIgnoreBatch); 369 | plugin->setPluginNamespace(mNamespace.c_str()); 370 | return plugin; 371 | } 372 | 373 | IPluginV2Ext* FlattenConcatCustomPluginCreator::deserializePlugin( 374 | const char* name, const void* serialData, size_t serialLength) 375 | { 376 | // This object will be deleted when the network is destroyed, which will 377 | // call Concat::destroy() 378 | IPluginV2Ext* plugin = new FlattenConcatCustom(serialData, serialLength); 379 | plugin->setPluginNamespace(mNamespace.c_str()); 380 | return plugin; 381 | } 382 | REGISTER_TENSORRT_PLUGIN(FlattenConcatCustomPluginCreator); 383 | 384 | 385 | -------------------------------------------------------------------------------- /flattenConcatCustom.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | #ifndef TRT_FlattenConcatCustom_PLUGIN_H 17 | #define TRT_FlattenConcatCustom_PLUGIN_H 18 | 19 | #include "NvInferPlugin.h" 20 | #include "plugin.h" 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | 28 | #define LOG_ERROR(status) \ 29 | do \ 30 | { \ 31 | auto ret = (status); \ 32 | if (ret != 0) \ 33 | { \ 34 | std::cout << "Cuda failure: " << ret << std::endl; \ 35 | abort(); \ 36 | } \ 37 | } while (0) 38 | 39 | namespace nvinfer1 40 | { 41 | namespace plugin 42 | { 43 | class FlattenConcatCustom : public IPluginV2Ext 44 | { 45 | public: 46 | FlattenConcatCustom(int concatAxis, bool ignoreBatch); 47 | 48 | FlattenConcatCustom(int concatAxis, bool ignoreBatch, int numInputs, int outputConcatAxis, const int* inputConcatAxis, 49 | size_t* copySize); 50 | 51 | FlattenConcatCustom(const void* data, size_t length); 52 | 53 | ~FlattenConcatCustom() override; 54 | 55 | FlattenConcatCustom() = delete; 56 | 57 | int getNbOutputs() const override; 58 | 59 | Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override; 60 | 61 | int initialize() override; 62 | 63 | void terminate() override; 64 | 65 | size_t getWorkspaceSize(int) const override; 66 | 67 | int enqueue( 68 | int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override; 69 | 70 | DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const override; 71 | 72 | size_t getSerializationSize() const override; 73 | 74 | void serialize(void* buffer) const override; 75 | 76 | bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const override; 77 | 78 | bool canBroadcastInputAcrossBatch(int inputIndex) const override; 79 | 80 | void configurePlugin(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, 81 | const DataType* inputTypes, const DataType* outputTypes, const bool* inputIsBroadcast, 82 | const 
bool* outputIsBroadcast, PluginFormat floatFormat, int maxBatchSize) override; 83 | 84 | bool supportsFormat(DataType type, PluginFormat format) const override; 85 | 86 | void detachFromContext() override; 87 | 88 | const char* getPluginType() const override; 89 | 90 | const char* getPluginVersion() const override; 91 | 92 | void destroy() override; 93 | 94 | void attachToContext( 95 | cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) override; 96 | IPluginV2Ext* clone() const override; 97 | 98 | void setPluginNamespace(const char* pluginNamespace) override; 99 | 100 | const char* getPluginNamespace() const override; 101 | 102 | private: 103 | Weights copyToDevice(const void* hostData, size_t count); 104 | 105 | void serializeFromDevice(char*& hostBuffer, Weights deviceWeights) const; 106 | 107 | Weights deserializeToDevice(const char*& hostBuffer, size_t count); 108 | 109 | size_t* mCopySize = nullptr; 110 | bool mIgnoreBatch{false}; 111 | int mConcatAxisID{0}, mOutputConcatAxis{0}, mNumInputs{0}; 112 | int* mInputConcatAxis = nullptr; 113 | nvinfer1::Dims mCHW; 114 | const char* mPluginNamespace; 115 | cublasHandle_t mCublas; 116 | }; 117 | 118 | class FlattenConcatCustomPluginCreator : public BaseCreator 119 | { 120 | public: 121 | FlattenConcatCustomPluginCreator(); 122 | 123 | ~FlattenConcatCustomPluginCreator() override = default; 124 | 125 | const char* getPluginName() const override; 126 | 127 | const char* getPluginVersion() const override; 128 | 129 | const PluginFieldCollection* getFieldNames() override; 130 | 131 | IPluginV2Ext* createPlugin(const char* name, const PluginFieldCollection* fc) override; 132 | 133 | IPluginV2Ext* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override; 134 | 135 | private: 136 | static PluginFieldCollection mFC; 137 | bool mIgnoreBatch{false}; 138 | int mConcatAxisID; 139 | static std::vector mPluginAttributes; 140 | }; 141 | 142 | } // namespace plugin 
143 | } // namespace nvinfer1 144 | 145 | #endif // TRT_FlattenConcatCustom_PLUGIN_H 146 | -------------------------------------------------------------------------------- /load_trt_engine.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "NvInfer.h" 6 | #include "NvUtils.h" 7 | #include "flattenConcatCustom.h" 8 | 9 | using namespace std; 10 | using namespace nvinfer1; 11 | 12 | class Logger : public nvinfer1::ILogger 13 | { 14 | public: 15 | Logger(Severity severity = Severity::kINFO) 16 | : reportableSeverity(severity) 17 | { 18 | } 19 | 20 | void log(Severity severity, const char* msg) override 21 | { 22 | // suppress messages with severity enum value greater than the reportable 23 | if (severity > reportableSeverity) 24 | return; 25 | 26 | switch (severity) 27 | { 28 | case Severity::kINTERNAL_ERROR: break; 29 | case Severity::kERROR: break; 30 | case Severity::kWARNING: break; 31 | case Severity::kINFO: break; 32 | default:break; 33 | } 34 | std::cerr << msg << std::endl; 35 | } 36 | 37 | Severity reportableSeverity; 38 | }; 39 | 40 | static Logger gLogger; 41 | 42 | int main(int argc, char ** argv){ 43 | std::cout<<"To read engine" << endl; 44 | std::ifstream ins; 45 | ins.open(argv[1], std::ofstream::binary); 46 | ins.seekg(0, std::ios::end); 47 | int slength = ins.tellg(); 48 | ins.seekg(0, std::ios::beg); 49 | char* sbuffer = new char[slength]; 50 | ins.read(sbuffer, sizeof(char)*slength); 51 | ins.close(); 52 | std::cout << "Read engine file " << argv[1] << ", size " << slength << endl; 53 | IRuntime* runtime = createInferRuntime(gLogger); 54 | ICudaEngine* engine = runtime->deserializeCudaEngine(sbuffer, slength, nullptr); 55 | std::cout<< "engine pointer " << engine << "batch size " << engine->getMaxBatchSize()<> inputs_shape; 58 | vector> outputs_shape; 59 | int nb_bindings = engine->getNbBindings(); 60 | std::cout<<"nb bindings" << nb_bindings << 
endl; 61 | for (int i = 0; i < nb_bindings; i++) { 62 | std::cout<<"index binding " << i << endl; 63 | auto dims = engine->getBindingDimensions(i); 64 | if (engine->bindingIsInput(i)) { 65 | vector shapes; 66 | for (int j = 0; j < dims.nbDims; ++j) { 67 | shapes.push_back(dims.d[j]); 68 | //std::cout< shapes; 73 | for (int j = 0; j < dims.nbDims; ++j) { 74 | shapes.push_back(dims.d[j]); 75 | } 76 | outputs_shape.push_back(shapes); 77 | } 78 | } 79 | std::cout << "input shapes" << endl; 80 | for(auto vec: inputs_shape){ 81 | for(auto v: vec){ 82 | std::cout << v << " "; 83 | } 84 | std::cout << endl; 85 | } 86 | std::cout << "output shapes" << endl; 87 | for(auto vec: outputs_shape){ 88 | for(auto v: vec){ 89 | std::cout << v << " "; 90 | } 91 | std::cout << endl; 92 | } 93 | return 0; 94 | 95 | } 96 | --------------------------------------------------------------------------------