├── .gitignore ├── CMakeLists.txt ├── FunctionLoadLog.md ├── README.md ├── main.cpp ├── p_test.py └── plugin ├── UpsampleKernel.cu ├── UpsamplePlugin.cpp ├── UpsamplePlugin.h └── UpsmapleKernel.h /.gitignore: -------------------------------------------------------------------------------- 1 | build/* 2 | .vscode/* -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # We need cmake >= 3.8, since 3.8 introduced CUDA as a first class language 2 | cmake_minimum_required(VERSION 3.8 FATAL_ERROR) 3 | project(UpsamplePlugin LANGUAGES CXX CUDA) 4 | 5 | # Enable all compile warnings 6 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-long-long -pedantic -Werror") 7 | 8 | # Sets variable to a value if variable is unset. 9 | macro(set_ifndef var val) 10 | if (NOT ${var}) 11 | set(${var} ${val}) 12 | endif() 13 | message(STATUS "Configurable variable ${var} set to ${${var}}") 14 | endmacro() 15 | 16 | # -------- CONFIGURATION -------- 17 | set_ifndef(TRT_LIB /usr/lib/x86_64-linux-gnu) 18 | set_ifndef(TRT_INCLUDE /usr/include/x86_64-linux-gnu) 19 | 20 | # Find dependencies: 21 | message("\nThe following variables are derived from the values of the previous variables unless provided explicitly:\n") 22 | 23 | # TensorRT's nvinfer lib 24 | find_library(_NVINFER_LIB nvinfer HINTS ${TRT_LIB} PATH_SUFFIXES lib lib64) 25 | set_ifndef(NVINFER_LIB ${_NVINFER_LIB}) 26 | 27 | # -------- BUILDING -------- 28 | 29 | # Add include directories 30 | include_directories(${CUDA_INC_DIR} ${TRT_INCLUDE} ${CMAKE_SOURCE_DIR}/plugin/) 31 | 32 | # Define clip plugin library target 33 | add_library(upsampleplugin MODULE 34 | ${CMAKE_SOURCE_DIR}/plugin/UpsampleKernel.cu 35 | ${CMAKE_SOURCE_DIR}/plugin/UpsamplePlugin.cpp 36 | ${CMAKE_SOURCE_DIR}/plugin/UpsmapleKernel.h 37 | ${CMAKE_SOURCE_DIR}/plugin/UpsamplePlugin.h 38 | ) 39 | 40 | # Use C++11 41 | 
target_compile_features(upsampleplugin PUBLIC cxx_std_11) 42 | 43 | # Link TensorRT's nvinfer lib 44 | target_link_libraries(upsampleplugin PRIVATE ${NVINFER_LIB}) 45 | 46 | # We need to explicitly state that we need all CUDA files 47 | # to be built with -dc as the member functions will be called by 48 | # other libraries and executables (in our case, Python inference scripts) 49 | set_target_properties(upsampleplugin PROPERTIES 50 | CUDA_SEPARABLE_COMPILATION ON 51 | ) 52 | -------------------------------------------------------------------------------- /FunctionLoadLog.md: -------------------------------------------------------------------------------- 1 | Load .so need to run some Creator's functions, as follow: 2 | ``` 3 | UpsamplePluginCreator::UpsamplePluginCreator 4 | UpsamplePluginCreator::getPluginName 5 | UpsamplePluginCreator::getPluginName 6 | UpsamplePluginCreator::getPluginNamespace 7 | UpsamplePluginCreator::getPluginVersion 8 | ``` 9 | 10 | when you build a engine, tensorrt will call those function: 11 | 12 | ``` 13 | UpsamplePluginCreator::getPluginName 14 | UpsamplePluginCreator::createPlugin 15 | UpsamplePlugin::UpsamplePlugin1 16 | UpsamplePlugin::getNbOutputs 17 | UpsamplePlugin::UpsamplePlugin1 18 | UpsamplePlugin::getNbOutputs 19 | UpsamplePlugin::getOutputDimensions 20 | UpsamplePlugin::UpsamplePlugin1 21 | UpsamplePlugin::supportsFormat 22 | UpsamplePlugin::supportsFormat 23 | UpsamplePlugin::supportsFormat 24 | UpsamplePlugin::supportsFormat 25 | UpsamplePlugin::supportsFormat 26 | UpsamplePlugin::UpsamplePlugin1 27 | UpsamplePlugin::supportsFormat 28 | UpsamplePlugin::configureWithFormat 29 | UpsamplePlugin::initialize 30 | UpsamplePlugin::enqueue 31 | ``` 32 | 33 | when you save engine to a file, tensorrt will call those functions: 34 | 35 | ``` 36 | UpsamplePluginCreator::getPluginName 37 | UpsamplePluginCreator::createPlugin 38 | UpsamplePlugin::UpsamplePlugin1 39 | UpsamplePlugin::getNbOutputs 40 | UpsamplePlugin::UpsamplePlugin1 41 
| UpsamplePlugin::getNbOutputs 42 | UpsamplePlugin::getOutputDimensions 43 | UpsamplePlugin::UpsamplePlugin1 44 | UpsamplePlugin::supportsFormat 45 | UpsamplePlugin::supportsFormat 46 | UpsamplePlugin::supportsFormat 47 | UpsamplePlugin::supportsFormat 48 | UpsamplePlugin::supportsFormat 49 | UpsamplePlugin::UpsamplePlugin1 50 | UpsamplePlugin::supportsFormat 51 | UpsamplePlugin::configureWithFormat 52 | UpsamplePlugin::initialize 53 | UpsamplePlugin::getPluginType 54 | UpsamplePlugin::getPluginVersion 55 | UpsamplePlugin::getSerializationSize 56 | UpsamplePlugin::serialize 57 | UpsamplePlugin::getSerializationSize 58 | UpsamplePlugin::getSerializationSize 59 | ``` 60 | 61 | when you load a engine, tensorrt will load those functions: 62 | ``` 63 | UpsamplePluginCreator::getPluginVersion 64 | UpsamplePluginCreator::getPluginNamespace 65 | UpsamplePluginCreator::deserializePlugin 66 | UpsamplePlugin::UpsamplePlugin2 67 | UpsamplePlugin::initialize 68 | UpsamplePlugin::enqueue 69 | ``` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | UpsamplingPlugin 2 | 3 | This Plugin only support bilinear upsample for now. 4 | 5 | # How? 6 | 7 | ## 如何创建一个Layer Plugin? 
8 | 9 | 一个Layer Plugin的实现必须包含两个class。一个是Plugin的实体,一个是Plugin的Creator。 10 | 11 | 具体要怎么实现,下面具体叙述。 12 | 13 | + Creator类,继承自IPluginCreator。 14 | + 构造函数:一般是初始化两个参数,mFC和mPluginAttributes。 15 | + getPluginName和getPluginVersion函数: 注册Creator所对应的Plugin的名称和版本。 16 | + getFieldNames函数:返回mFC变量。 17 | + createPlugin函数: 解析PluginFieldCollection,获取Plugin需要的参数。构建Plugin对象。 18 | + deserializePlugin函数: 从模型文件中载入engine的时候,构建Plugin所需要调用的函数。这个函数实现直接新建Plugin对象就行。 19 | + setPluginNamespace和getPluginNamespace,没有发现在什么阶段有调用这两个函数。 20 | + Plugin类,继承自IPluginV2。 21 | + 两个构造函数,一个用于在构建网络的时候,构建Plugin;另外一个是用于在反序列化的时候构建Plugin。(注意的是:如果这个层不支持没有参数输入的话,那么可以执行`UpsamplePlugin() = delete`删除默认构造函数) 22 | + getNbOutputs函数: 通常一个层是单输入单输出,所以这个函数直接`return 1`就行。 23 | + getOutputDimensions函数: 返回一个nvinfer1::Dims对象(或者它的子类),具体的值根据实际设置。 24 | + initialize,暂时没有用到。 25 | + terminate,暂时没有用到。 26 | + getWorkspaceSize,暂时没有用到,不清楚。 27 | + enqueue函数:用于正向传播,通常在这个函数调用cuda kernel。正常执行函数return 0。 28 | + getSerializationSize函数:当模型serialize的时候,需要保存到文件参数所占的具体空间。 29 | + serialize函数: 将Plugin参数保存至文件。 30 | + destroy函数: 删除当前类的this指针 31 | 32 | 33 | 34 | ## 如何让Tensorrt感知新的Layer Plugin? 35 | 36 | 在Plugin实现文件中调用`REGISTER_TENSORRT_PLUGIN`这个宏,用于注册一个Plugin Creator。例如: 37 | ``` 38 | REGISTER_TENSORRT_PLUGIN(UpsamplePluginCreator); 39 | ``` 40 | 41 | 有了这个宏,当在.py文件中调用xxplugin.so的时候就会自动执行这个语句,然后就会在tensorrt中注册UpsamplePluginCreator的信息,可以用于创建新的Plugin,实际的效果就是在`trt.get_plugin_registry().plugin_creator_list`添加了一个`UpsamplePluginCreator`。 42 | 43 | ## Tensorrt如何调用一个Plugin? 44 | 45 | ### python的调用方式: 46 | 47 | 1. 获取tensorrt中的creator列表。代码如下: 48 | ``` 49 | trt.init_libnvinfer_plugins(TRT_LOGGER, '') 50 | PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list 51 | ``` 52 | 2. 
有了上面的列表,就可以根据名字匹配相应的 Plugin Creator,并且传入相应的参数,构建对应的plugin。代码如下: 53 | ``` 54 | def get_upsample_plugin(plugin_name, sacle_factor=2, align_corners=False): 55 | plugin = None 56 | for plugin_creator in PLUGIN_CREATORS: 57 | if plugin_creator.name == plugin_name: 58 | scale_factor_field = trt.PluginField("scaleFactor", np.array([sacle_factor], dtype=np.int8), trt.PluginFieldType.INT8) 59 | align_corners_field = trt.PluginField("alignCorners", np.array([int(align_corners)], dtype=np.int8), trt.PluginFieldType.INT8) 60 | field_collection = trt.PluginFieldCollection([align_corners_field, scale_factor_field]) 61 | plugin = plugin_creator.create_plugin(name=plugin_name, field_collection=field_collection) 62 | return plugin 63 | ``` 64 | + Note: 参数的载入tensorrt使用的是trt.PluginField,第一个参数是名字,第二个是参数的内存地址(buffer类型, 一般用numpy来实现),第三个是类型。名字和类型必须跟你在Creator中使用的一样,不然报错。 65 | 3. 创建好了Plugin,就可以用`network.add_plugin_v2`调用了。代码如下: 66 | ``` 67 | upsample_layer = network.add_plugin_v2(inputs=[inputs], plugin=get_upsample_plugin("UpsamplePlugin", sacle_factor, align_corners)) 68 | ``` 69 | 70 | 71 | ### C++的调用方式: 72 | 73 | TODO 74 | 75 | 76 | 77 | # 附录 78 | 79 | + tensorrt在构建初次build engine以及engine serialize和deserialize的时候,调用的那些plugin的参数,可以参考`FunctionLoadLog.md` 80 | -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "UpsamplePlugin.h" 3 | 4 | int main() 5 | { 6 | 7 | plugin_test(); 8 | 9 | return 0; 10 | } -------------------------------------------------------------------------------- /p_test.py: -------------------------------------------------------------------------------- 1 | import ctypes 2 | import os 3 | 4 | import pycuda.driver as cuda 5 | import pycuda.autoinit 6 | import tensorrt as trt 7 | import numpy as np 8 | 9 | TRT_LOGGER = trt.Logger() 10 | 11 | 12 | CLIP_PLUGIN_LIBRARY = os.path.join( 13 | 
# Absolute path of the plugin library produced by the CMake build.
CLIP_PLUGIN_LIBRARY = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    'build/libupsampleplugin.so'
)

# Loading the shared object executes REGISTER_TENSORRT_PLUGIN(UpsamplePluginCreator),
# which is what makes the creator show up in the TensorRT plugin registry below.
lib = ctypes.cdll.LoadLibrary(CLIP_PLUGIN_LIBRARY)

trt.init_libnvinfer_plugins(TRT_LOGGER, '')
PLUGIN_CREATORS = trt.get_plugin_registry().plugin_creator_list
# Debug aid: uncomment to list every registered creator.
# for s_creator in PLUGIN_CREATORS:
#     print(s_creator.name)


def allocate_buffers(engine):
    """Allocate host and device buffers for binding 0 (input) and binding 1 (output).

    Args:
        engine: a TensorRT engine exposing ``get_binding_shape``.

    Returns:
        Tuple ``(h_input, d_input, h_output, d_output, stream)``: page-locked
        float32 host arrays, matching device allocations, and a CUDA stream.
    """
    input_shape = engine.get_binding_shape(0)
    output_shape = engine.get_binding_shape(1)
    # The original passed the shapes as extra print() arguments, so the "{}"
    # placeholders were printed literally; format them instead.
    print("binding shape: {}, {}".format(input_shape, output_shape))

    # Page-locked host memory (won't be swapped to disk) for inputs/outputs.
    host_dtype = trt.nptype(trt.float32)
    h_input = cuda.pagelocked_empty(trt.volume(input_shape), dtype=host_dtype)
    h_output = cuda.pagelocked_empty(trt.volume(output_shape), dtype=host_dtype)

    # Device memory of the same byte size as the host buffers.
    d_input = cuda.mem_alloc(h_input.nbytes)
    d_output = cuda.mem_alloc(h_output.nbytes)

    # Stream used for the async copies and for inference.
    stream = cuda.Stream()
    return h_input, d_input, h_output, d_output, stream


def get_trt_plugin(plugin_name, sacle_factor=2, align_corners=False):
    """Create a plugin instance from the registered creator named ``plugin_name``.

    Args:
        plugin_name: creator name to look up in ``PLUGIN_CREATORS``.
        sacle_factor: integer upsampling factor (spelling kept so existing
            keyword callers keep working). Sent as a single int8 value.
        align_corners: forwarded to the plugin as a single int8 value (0 or 1).

    Returns:
        The plugin returned by the creator, or ``None`` when no creator with
        that name is registered.
    """
    for plugin_creator in PLUGIN_CREATORS:
        if plugin_creator.name != plugin_name:
            continue
        # Field names and types must match what the C++ creator declares
        # ("scaleFactor" / "alignCorners", both kINT8).
        scale_factor_field = trt.PluginField(
            "scaleFactor", np.array([sacle_factor], dtype=np.int8), trt.PluginFieldType.INT8)
        align_corners_field = trt.PluginField(
            "alignCorners", np.array([int(align_corners)], dtype=np.int8), trt.PluginFieldType.INT8)
        field_collection = trt.PluginFieldCollection([align_corners_field, scale_factor_field])
        # Stop at the first match instead of scanning the rest of the registry.
        return plugin_creator.create_plugin(name=plugin_name, field_collection=field_collection)
    return None


def build_engine():
    """Build an engine containing one input followed by the UpsamplePlugin layer.

    Returns:
        The result of ``builder.build_cuda_engine(network)``.

    Raises:
        RuntimeError: if the UpsamplePlugin creator is not registered or the
            creator fails to produce a plugin.
    """
    plugin = get_trt_plugin("UpsamplePlugin")
    if plugin is None:
        # Fail here with a clear message rather than handing None to add_plugin_v2.
        raise RuntimeError(
            "UpsamplePlugin is not available; was {} loaded?".format(CLIP_PLUGIN_LIBRARY))

    # The Caffe parser the original opened here was never used, so it is not created.
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network:
        builder.max_batch_size = 1
        builder.max_workspace_size = 2**20
        input_layer = network.add_input(name="input_layer", dtype=trt.float32, shape=(1, 13, 3, 3))
        upsample = network.add_plugin_v2(inputs=[input_layer], plugin=plugin)
        upsample.get_output(0).name = "outputs"
        network.mark_output(upsample.get_output(0))

        return builder.build_cuda_engine(network)


def do_inference(context, h_input, d_input, h_output, d_output, stream):
    """Run one asynchronous inference and wait for the result in ``h_output``.

    Args:
        context: TensorRT execution context.
        h_input, h_output: host buffers.
        d_input, d_output: device buffers used as the two bindings.
        stream: CUDA stream on which the copies and execution are enqueued.
    """
    # Transfer input data to the GPU.
    cuda.memcpy_htod_async(d_input, h_input, stream)
    # Run inference.
    context.execute_async(bindings=[int(d_input), int(d_output)], stream_handle=stream.handle)
    # Transfer predictions back from the GPU.
    cuda.memcpy_dtoh_async(h_output, d_output, stream)
    # Block until everything queued on the stream has finished.
    stream.synchronize()
72 | cuda.memcpy_dtoh_async(h_output, d_output, stream) 73 | # Synchronize the stream 74 | stream.synchronize() 75 | 76 | 77 | def main(): 78 | 79 | arr = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, \ 80 | 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, \ 81 | 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, \ 82 | 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, \ 83 | 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, \ 84 | 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, \ 85 | 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, \ 86 | 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, \ 87 | 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, \ 88 | 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, \ 89 | 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, \ 90 | 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, \ 91 | 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0 92 | ]) 93 | print(arr) 94 | 95 | 96 | # with build_engine() as engine: 97 | # # Build an engine, allocate buffers and create a stream. 98 | # # For more information on buffer allocation, refer to the introductory samples. 
99 | # h_input, d_input, h_output, d_output, stream = allocate_buffers(engine) 100 | # np.copyto(h_input, arr) 101 | # # print("debug") 102 | # with engine.create_execution_context() as context: 103 | # do_inference(context, h_input, d_input, h_output, d_output, stream) 104 | # print(h_output) 105 | 106 | # save_engine = os.path.join(os.path.dirname(__file__), "sample.engine") 107 | # with build_engine() as engine: 108 | # with open(save_engine, "wb") as f: 109 | # f.write(engine.serialize()) 110 | 111 | save_engine = os.path.join(os.path.dirname(__file__), "sample.engine") 112 | with open(save_engine, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime: 113 | engine = runtime.deserialize_cuda_engine(f.read()) 114 | h_input, d_input, h_output, d_output, stream = allocate_buffers(engine) 115 | np.copyto(h_input, arr) 116 | # print("debug") 117 | with engine.create_execution_context() as context: 118 | do_inference(context, h_input, d_input, h_output, d_output, stream) 119 | print(h_output) 120 | 121 | if __name__ == "__main__": 122 | main() -------------------------------------------------------------------------------- /plugin/UpsampleKernel.cu: -------------------------------------------------------------------------------- 1 | #include "UpsmapleKernel.h" 2 | 3 | 4 | /** 5 | * @brief caculate the number of cuda kernel for upsample. 
/**
 * @brief Number of CUDA blocks needed to cover a given number of threads
 *        (ceiling division). Cited from 《GPU高性能编程CUDA实战》
 *        (Chinese edition of "CUDA by Example"), pp. 46-47.
 *
 * @param total_thread_num total number of threads wanted for the upsample launch
 * @param max_thread_num   threads per block (a device property)
 * @return number of blocks; 0 when either argument is not positive, so a bad
 *         device query can never cause an integer division by zero here
 */
int get_kernel_num(int total_thread_num, int max_thread_num)
{
    if (total_thread_num <= 0 || max_thread_num <= 0)
    {
        return 0;
    }
    return (total_thread_num + max_thread_num - 1) / max_thread_num;
}

/**
 * @brief Maximum number of threads per block on the current CUDA device.
 *
 * Queries the single attribute that is needed instead of the whole
 * cudaDeviceProp structure, uses the current device rather than a hard-coded
 * device 0, and checks the CUDA return codes (the original returned an
 * uninitialised value when the query failed).
 *
 * @return the device limit, or a conservative 256 if the query fails
 */
int get_max_thread_num()
{
    const int fallback_threads = 256;
    int device = 0;
    int max_threads = 0;
    if (cudaGetDevice(&device) != cudaSuccess ||
        cudaDeviceGetAttribute(&max_threads, cudaDevAttrMaxThreadsPerBlock, device) != cudaSuccess ||
        max_threads <= 0)
    {
        return fallback_threads;
    }
    return max_threads;
}

/**
 * @brief Ratio that maps an output coordinate back to an input coordinate.
 *
 * @param input_size    extent of the input along one axis
 * @param output_size   extent of the output along the same axis
 * @param align_corners when true, the first and last samples of input and
 *                      output coincide, so the ratio is (in-1)/(out-1)
 * @return the scale; 0 for align_corners with output_size <= 1, where the
 *         original expression divided by zero
 */
__host__ __forceinline__ float linear_upsampling_compute_scale(int input_size, int output_size, bool align_corners)
{
    if (align_corners)
    {
        return (output_size > 1)
            ? static_cast<float>(input_size - 1) / static_cast<float>(output_size - 1)
            : 0.0f;
    }
    return static_cast<float>(input_size) / static_cast<float>(output_size);
}

/**
 * @brief Fractional source index for a destination index.
 *
 * @param scale         value from linear_upsampling_compute_scale
 * @param dst_index     index in the output
 * @param intput_size   unused; kept so existing call sites keep compiling
 * @param align_corners selects the corner-aligned or half-pixel mapping
 * @return source position, clamped at 0 from below in the half-pixel case
 */
__device__ __forceinline__ float linear_upsampling_compute_source_index(float scale, int dst_index, int intput_size, bool align_corners)
{
    if (align_corners)
    {
        return scale * dst_index;
    }
    // Float literals keep the arithmetic in single precision; the original
    // 0.5 literals promoted the whole expression to double on the device.
    const float src_idx = scale * (dst_index + 0.5f) - 0.5f;
    return (src_idx >= 0.0f) ? src_idx : 0.0f;
}

/**
 * @brief Flat offset of element (batch, channel, height, width).
 *
 * @param batch_total   stride between consecutive batch items
 * @param channel_total stride between consecutive channel planes
 * @param width         stride between consecutive rows
 * @return batch_idx*batch_total + channel_idx*channel_total + height_idx*width + width_idx
 */
__device__ __forceinline__ int get_index(const int batch_idx, const int channel_idx, const int height_idx, const int width_idx,
                                         const int batch_total, const int channel_total, const int width)
{
    return batch_idx * batch_total
         + channel_idx * channel_total
         + height_idx * width
         + width_idx;
}
such as [batch, channel, height, width] 58 | * @param rate_h 59 | * @param rate_w 60 | * @param inputs 61 | * @param outputs 62 | * @return __global__ BilinearKernel 63 | * @TODO: 64 | * 65 | */ 66 | 67 | 68 | template 69 | __global__ void BilinearKernel( 70 | const int n, 71 | int input_b, 72 | int input_c, 73 | int input_h, 74 | int input_w, 75 | int output_h, 76 | int output_w, 77 | const float rate_h, 78 | const float rate_w, 79 | bool align_corners, 80 | const T* inputs, 81 | T* outputs) 82 | { 83 | 84 | int index = threadIdx.x + blockIdx.x * blockDim.x; 85 | if(index < n) 86 | { 87 | const int w2 = index % output_w; 88 | const int h2 = index / output_w; 89 | 90 | 91 | const float h1r = linear_upsampling_compute_source_index(rate_h, h2, input_h, align_corners); 92 | const int h1 = int(h1r); 93 | const int h1p = (h1 < input_h - 1) ? 1 : 0; 94 | const float h1lambda = h1r - h1; 95 | const float h0lambda = 1 - h1lambda; 96 | 97 | const float w1r = linear_upsampling_compute_source_index(rate_w, w2, input_w, align_corners); 98 | const int w1 = int(w1r); 99 | const int w1p = (w1 < input_w - 1) ? 
1 : 0; 100 | const float w1lambda = w1r - w1; 101 | const float w0lambda = 1 - w1lambda; 102 | 103 | int s_batch_total_1 = input_c * input_h * input_w; 104 | int s_channel_total_1 = input_h * input_w; 105 | 106 | int s_batch_total_2 = input_c * output_h * output_w; 107 | int s_channel_total_2 = output_h * output_w; 108 | 109 | 110 | const int batch_size = input_b; 111 | const int channel_size = input_c; 112 | 113 | for(int b_idx=0; b_idx<<< kernel_num, max_threads, 0, stream>>>(n,input_b,input_c,input_h,input_w, 150 | output_h, output_w, 151 | rate_h, rate_w, align_corners, 152 | static_cast(inputs), 153 | static_cast(outputs)); 154 | return 0; 155 | } 156 | -------------------------------------------------------------------------------- /plugin/UpsamplePlugin.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "UpsmapleKernel.h" 3 | #include "UpsamplePlugin.h" 4 | 5 | #include 6 | #include 7 | 8 | using namespace nvinfer1; 9 | 10 | // Upsample plugin specific constants 11 | namespace { 12 | static const char* UPSAMPLE_PLUGIN_VERSION{"1"}; 13 | static const char* UPSAMPLE_PLUGIN_NAME{"UpsamplePlugin"}; 14 | } 15 | 16 | // Static class fields initialization 17 | PluginFieldCollection UpsamplePluginCreator::mFC{}; 18 | std::vector UpsamplePluginCreator::mPluginAttributes; 19 | 20 | REGISTER_TENSORRT_PLUGIN(UpsamplePluginCreator); 21 | 22 | // Helper function for serializing plugin 23 | template 24 | void writeToBuffer(char*& buffer, const T& val) 25 | { 26 | *reinterpret_cast(buffer) = val; 27 | buffer += sizeof(T); 28 | } 29 | 30 | // Helper function for deserializing plugin 31 | template 32 | T readFromBuffer(const char*& buffer) 33 | { 34 | T val = *reinterpret_cast(buffer); 35 | buffer += sizeof(T); 36 | return val; 37 | } 38 | 39 | UpsamplePlugin::UpsamplePlugin(const std::string name, int scale_factor, bool align_corners) 40 | : mLayerName(name) 41 | , mAlignCorners(align_corners) 42 | , 
mScaleFactor(scale_factor) 43 | { 44 | // printf("UpsamplePlugin::UpsamplePlugin1\n"); 45 | mInputShape.c() = -1; 46 | mInputShape.h() = -1; 47 | mInputShape.w() = -1; 48 | mInputVolume = 0; 49 | } 50 | 51 | UpsamplePlugin::UpsamplePlugin(const std::string name, const void* data, size_t length) 52 | : mLayerName(name) 53 | { 54 | //printf("UpsamplePlugin::UpsamplePlugin2\n"); 55 | // Deserialize in the same order as serialization 56 | const char *d = static_cast(data); 57 | const char *a = d; 58 | 59 | mScaleFactor = readFromBuffer(d); 60 | mAlignCorners = readFromBuffer(d); 61 | 62 | mInputVolume = readFromBuffer(d); 63 | mInputShape.c() = readFromBuffer(d); 64 | mInputShape.h() = readFromBuffer(d); 65 | mInputShape.w() = readFromBuffer(d); 66 | 67 | // writeToBuffer(d, mInputVolume); 68 | // writeToBuffer(d, mInputShape.c()); 69 | // writeToBuffer(d, mInputShape.h()); 70 | // writeToBuffer(d, mInputShape.w()); 71 | 72 | 73 | // mInputShape.c() = -1; 74 | // mInputShape.h() = -1; 75 | // mInputShape.w() = -1; 76 | // mInputVolume = 0; 77 | //printf("length: %d\n", int(length)); 78 | assert(d == (a + length)); 79 | 80 | } 81 | 82 | const char* UpsamplePlugin::getPluginType() const 83 | { 84 | //printf("UpsamplePlugin::getPluginType\n"); 85 | return UPSAMPLE_PLUGIN_NAME; 86 | } 87 | 88 | const char* UpsamplePlugin::getPluginVersion() const 89 | { 90 | //printf("UpsamplePlugin::getPluginVersion\n"); 91 | return UPSAMPLE_PLUGIN_VERSION; 92 | } 93 | 94 | int UpsamplePlugin::getNbOutputs() const 95 | { 96 | //printf("UpsamplePlugin::getNbOutputs\n"); 97 | return 1; 98 | } 99 | 100 | Dims UpsamplePlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) 101 | { 102 | //printf("UpsamplePlugin::getOutputDimensions\n"); 103 | assert(index == 0); 104 | assert(nbInputDims == 1); 105 | assert(inputs[0].nbDims == 4); 106 | 107 | return nvinfer1::DimsNCHW{inputs[0].d[0], inputs[0].d[1], inputs[0].d[2]*mScaleFactor, inputs[0].d[3]*mScaleFactor}; 108 | } 109 | 
110 | int UpsamplePlugin::initialize() 111 | { 112 | //printf("UpsamplePlugin::initialize\n"); 113 | return 0; 114 | } 115 | 116 | 117 | int UpsamplePlugin::enqueue(int batchSize, const void* const* inputs, void** outputs, void*, cudaStream_t stream) 118 | { 119 | //printf("UpsamplePlugin::enqueue\n"); 120 | int status = -1; 121 | 122 | // Our plugin outputs only one tensor 123 | void* output = outputs[0]; 124 | 125 | // Launch CUDA kernel wrapper and save its return value 126 | status = UpsampleInference(stream, mInputVolume, 127 | batchSize, mInputShape.c(), mInputShape.h(), mInputShape.w(), 128 | mScaleFactor, mAlignCorners, 129 | inputs[0], output); 130 | 131 | return status; 132 | } 133 | 134 | size_t UpsamplePlugin::getSerializationSize() const 135 | { 136 | //printf("UpsamplePlugin::getSerializationSize\n"); 137 | return sizeof(mScaleFactor) + sizeof(mAlignCorners) + 138 | sizeof(mInputVolume) + sizeof(mInputShape.c()) + 139 | sizeof(mInputShape.h()) + sizeof(mInputShape.w()); 140 | } 141 | 142 | 143 | void UpsamplePlugin::serialize(void* buffer) const 144 | { 145 | //printf("UpsamplePlugin::serialize\n"); 146 | char *d = static_cast(buffer); 147 | const char *a = d; 148 | 149 | writeToBuffer(d, mScaleFactor); 150 | writeToBuffer(d, mAlignCorners); 151 | writeToBuffer(d, mInputVolume); 152 | writeToBuffer(d, mInputShape.c()); 153 | writeToBuffer(d, mInputShape.h()); 154 | writeToBuffer(d, mInputShape.w()); 155 | 156 | //printf("------getSerializationSize: %d\n",int(getSerializationSize())); 157 | 158 | assert(d == a + getSerializationSize()); 159 | } 160 | 161 | void UpsamplePlugin::configureWithFormat(const Dims* inputs, int nbInputs, const Dims* outputs, int nbOutputs, DataType type, PluginFormat format, int) 162 | { 163 | //printf("-----UpsamplePlugin::configureWithFormat\n"); 164 | // Validate input arguments 165 | assert(nbOutputs == 1); 166 | assert(type == DataType::kFLOAT); 167 | assert(format == PluginFormat::kNCHW); 168 | assert(inputs[0].nbDims == 
4); 169 | // Fetch volume for future enqueue() operations 170 | size_t volume = inputs[0].d[2]*mScaleFactor * inputs[0].d[3]*mScaleFactor; 171 | mInputVolume = volume; 172 | mInputShape.c() = inputs[0].d[1]; 173 | mInputShape.h() = inputs[0].d[2]; 174 | mInputShape.w() = inputs[0].d[3]; 175 | } 176 | 177 | bool UpsamplePlugin::supportsFormat(DataType type, PluginFormat format) const 178 | { 179 | //printf("UpsamplePlugin::supportsFormat\n"); 180 | // This plugin only supports ordinary floats, and NCHW input format 181 | if (type == DataType::kFLOAT && format == PluginFormat::kNCHW) 182 | return true; 183 | else 184 | return false; 185 | } 186 | 187 | void UpsamplePlugin::terminate() {} 188 | 189 | void UpsamplePlugin::destroy() { 190 | // This gets called when the network containing plugin is destroyed 191 | delete this; 192 | } 193 | 194 | IPluginV2* UpsamplePlugin::clone() const 195 | { 196 | return new UpsamplePlugin(mLayerName, mScaleFactor, mAlignCorners); 197 | } 198 | 199 | void UpsamplePlugin::setPluginNamespace(const char* libNamespace) 200 | { 201 | mNamespace = libNamespace; 202 | } 203 | 204 | const char* UpsamplePlugin::getPluginNamespace() const 205 | { 206 | return mNamespace.c_str(); 207 | } 208 | 209 | UpsamplePluginCreator::UpsamplePluginCreator() 210 | { 211 | //printf("UpsamplePluginCreator::UpsamplePluginCreator\n"); 212 | // Describe UpsamplePlugin's required PluginField arguments 213 | mPluginAttributes.emplace_back(PluginField("scaleFactor", nullptr, PluginFieldType::kINT8, 1)); 214 | mPluginAttributes.emplace_back(PluginField("alignCorners", nullptr, PluginFieldType::kINT8, 1)); 215 | 216 | // Fill PluginFieldCollection with PluginField arguments metadata 217 | mFC.nbFields = mPluginAttributes.size(); 218 | mFC.fields = mPluginAttributes.data(); 219 | } 220 | const char* UpsamplePluginCreator::getPluginName() const 221 | { 222 | //printf("UpsamplePluginCreator::getPluginName\n"); 223 | return UPSAMPLE_PLUGIN_NAME; 224 | } 225 | 226 | const 
char* UpsamplePluginCreator::getPluginVersion() const 227 | { 228 | //printf("UpsamplePluginCreator::getPluginVersion\n"); 229 | return UPSAMPLE_PLUGIN_VERSION; 230 | } 231 | 232 | const PluginFieldCollection* UpsamplePluginCreator::getFieldNames() 233 | { 234 | //printf("UpsamplePluginCreator::getFieldNames\n"); 235 | return &mFC; 236 | } 237 | 238 | IPluginV2* UpsamplePluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) 239 | { 240 | //printf("UpsamplePluginCreator::createPlugin\n"); 241 | int scaleFactor = 0; 242 | bool alignCorners = false; 243 | const PluginField* fields = fc->fields; 244 | 245 | // Parse fields from PluginFieldCollection 246 | assert(fc->nbFields == 2); 247 | for (int i = 0; i < fc->nbFields; i++){ 248 | 249 | if (strcmp(fields[i].name, "scaleFactor") == 0) { 250 | assert(fields[i].type == PluginFieldType::kINT8); 251 | scaleFactor = *(static_cast(fields[i].data)); 252 | } 253 | else if (strcmp(fields[i].name, "alignCorners") == 0) { 254 | assert(fields[i].type == PluginFieldType::kINT8); 255 | alignCorners = bool(*(static_cast(fields[i].data))); 256 | 257 | } 258 | } 259 | return new UpsamplePlugin(name, scaleFactor, alignCorners); 260 | } 261 | 262 | IPluginV2* UpsamplePluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) 263 | { 264 | //printf("UpsamplePluginCreator::deserializePlugin\n"); 265 | // This object will be deleted when the network is destroyed, which will 266 | // call UpsamplePlugin::destroy() 267 | return new UpsamplePlugin(name, serialData, serialLength); 268 | } 269 | 270 | void UpsamplePluginCreator::setPluginNamespace(const char* libNamespace) 271 | { 272 | //printf("UpsamplePluginCreator::setPluginNamespace\n"); 273 | mNamespace = libNamespace; 274 | } 275 | 276 | const char* UpsamplePluginCreator::getPluginNamespace() const 277 | { 278 | //printf("UpsamplePluginCreator::getPluginNamespace\n"); 279 | return mNamespace.c_str(); 280 | } 281 | 
-------------------------------------------------------------------------------- /plugin/UpsamplePlugin.h: -------------------------------------------------------------------------------- 1 | #ifndef UPSAMPLE_PLUGIN_H 2 | #define UPSAMPLE_PLUGIN_H 3 | 4 | #include "NvInferPlugin.h" 5 | #include 6 | #include 7 | 8 | 9 | using namespace nvinfer1; 10 | 11 | class UpsamplePlugin : public IPluginV2 12 | { 13 | public: 14 | UpsamplePlugin(const std::string name, int scale_factor=2, bool align_corners=0); 15 | 16 | UpsamplePlugin(const std::string name, const void* data, size_t length); 17 | 18 | // It doesn't make sense to make UpsamplePlugin without arguments, so we delete default constructor. 19 | UpsamplePlugin() = delete; 20 | 21 | int getNbOutputs() const override; 22 | 23 | Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override; 24 | 25 | int initialize() override; 26 | 27 | void terminate() override; 28 | 29 | size_t getWorkspaceSize(int) const override { return 0; }; 30 | 31 | int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream) override; 32 | 33 | size_t getSerializationSize() const override; 34 | 35 | void serialize(void* buffer) const override; 36 | 37 | void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override; 38 | 39 | bool supportsFormat(DataType type, PluginFormat format) const override; 40 | 41 | const char* getPluginType() const override; 42 | 43 | const char* getPluginVersion() const override; 44 | 45 | void destroy() override; 46 | 47 | nvinfer1::IPluginV2* clone() const override; 48 | 49 | void setPluginNamespace(const char* pluginNamespace) override; 50 | 51 | const char* getPluginNamespace() const override; 52 | 53 | private: 54 | const std::string mLayerName; 55 | bool mAlignCorners; 56 | int mScaleFactor; 57 | size_t mInputVolume; 58 | DimsCHW mInputShape; 59 | 
std::string mNamespace; 60 | }; 61 | 62 | class UpsamplePluginCreator : public IPluginCreator 63 | { 64 | public: 65 | UpsamplePluginCreator(); 66 | 67 | const char* getPluginName() const override; 68 | 69 | const char* getPluginVersion() const override; 70 | 71 | const PluginFieldCollection* getFieldNames() override; 72 | 73 | IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) override; 74 | 75 | IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) override; 76 | 77 | void setPluginNamespace(const char* pluginNamespace) override; 78 | 79 | const char* getPluginNamespace() const override; 80 | 81 | private: 82 | static PluginFieldCollection mFC; 83 | static std::vector mPluginAttributes; 84 | std::string mNamespace; 85 | }; 86 | 87 | #endif -------------------------------------------------------------------------------- /plugin/UpsmapleKernel.h: -------------------------------------------------------------------------------- 1 | #ifndef UPSAMPLE_KERNEL_H 2 | #define UPSAMPLE_KERNEL_H 3 | 4 | #include 5 | #include "NvInfer.h" 6 | 7 | int UpsampleInference( 8 | cudaStream_t stream, 9 | int n, 10 | int input_b, 11 | int input_c, 12 | int input_h, 13 | int input_w, 14 | int scale_factor, 15 | bool align_corners, 16 | const void* inputs, 17 | void* outputs); 18 | 19 | 20 | #endif 21 | --------------------------------------------------------------------------------