├── README.md
├── co_detr_add_nms.py
├── co_detr_trt.cpp
├── codetr.cpp
├── codetr.h
├── linux_cc
│   ├── co_detr
│   │   ├── CMakeLists.txt
│   │   ├── src
│   │   │   ├── co_detr_add_nms.py
│   │   │   ├── co_detr_trt.cpp
│   │   │   ├── codetr.cpp
│   │   │   ├── codetr.h
│   │   │   └── logging.h
│   │   └── test_img
│   │       ├── bus.jpg
│   │       └── zidane.jpg
│   └── plugin
│       ├── Makefile
│       ├── Makefile.inc
│       ├── common
│       │   ├── common_cuda_helper.hpp
│       │   ├── trt_plugin_base.hpp
│       │   ├── trt_plugin_helper.hpp
│       │   └── trt_serialize.hpp
│       ├── trt_grid_sampler_kernel.cu
│       └── trt_grid_sampler_kernel.hpp
├── plugin
│   ├── CMakeLists.txt
│   ├── common
│   │   ├── common_cuda_helper.hpp
│   │   ├── trt_plugin_base.hpp
│   │   ├── trt_plugin_helper.hpp
│   │   └── trt_serialize.hpp
│   └── grid_sampler
│       ├── trt_grid_sampler.cpp
│       ├── trt_grid_sampler.hpp
│       ├── trt_grid_sampler_kernel.cu
│       └── trt_grid_sampler_kernel.hpp
└── test_res
    ├── final.jpg
    ├── final_0.jpg
    └── static
        ├── cmd.png
        ├── v1.PNG
        └── v2.PNG

/README.md:
--------------------------------------------------------------------------------

## Co-DETR TensorRT: End-to-End Accelerated Inference in C++

徐静 (Xu Jing)

### 0. Environment Setup

+ On Ubuntu 16.04, install mmdetection and mmdeploy together with their dependencies mmcv and mmengine

```shell
# mmdetection==3.3.0
git clone -b 3.3.0 https://github.com/open-mmlab/mmdetection
cd mmdetection && pip install -v -e .

# mmcv
pip install mmcv==2.0.0

# mmdeploy
# https://github.com/TommyZihao/MMDeploy_Tutorials
git clone -b 1.3.1 https://github.com/open-mmlab/mmdeploy --recursive
# Build and install MMDeploy (takes roughly ten minutes)
python tools/scripts/build_ubuntu_x64_ort.py
```

+ Windows TensorRT environment
  + TensorRT 8.5
  + CUDA 11.0, cuDNN
  + Visual Studio 2017
  + CMake 3.22.1
  + OpenCV



### 1. Converting Co-DETR to ONNX

1. Edit the model config to disable soft-NMS at test time (it is replaced later by the EfficientNMS plugin).

```python
# mmdetection/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_16e_o365tococo.py
# mmdetection/projects/CO-DETR/configs/codino/co_dino_5scale_r50_lsj_8xb2_1x_coco.py

test_cfg=[
    # # Different from DINO, we use NMS.
    dict(
        max_per_img=300,
        # NMS can improve the mAP by 0.2.
        # nms=dict(type='soft_nms', iou_threshold=0.8)),  # soft-NMS disabled during testing
    ),
```

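A quick way to confirm that the edit took effect is to load the config through mmengine and inspect `test_cfg`. This is a minimal sketch added here for illustration, not part of the repo; the exact attribute path may differ between configs:

```python
# Sanity check (assumed config path): confirm no `nms` entry survives in test_cfg.
from mmengine.config import Config

cfg = Config.fromfile(
    "projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_16e_o365tococo.py")
for entry in cfg.model.test_cfg:  # test_cfg is a list for Co-DINO
    assert "nms" not in entry, f"soft-NMS still enabled: {entry}"
print("soft-NMS disabled in test_cfg")
```
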
2. Edit mmdeploy's ONNX export configuration.

```python
# mmdeploy/configs/_base_/onnx_config.py
onnx_config = dict(
    type='onnx',
    export_params=True,
    keep_initializers_as_inputs=False,
    opset_version=11,  # opset version
    save_file='end2end.onnx',  # file name of the exported ONNX model
    input_names=['input'],  # input tensor name
    output_names=['output'],  # output tensor name
    input_shape=None,
    optimize=True)
# mmdeploy/configs/mmdet/_base_/base_static.py

_base_ = ['../../_base_/onnx_config.py']

onnx_config = dict(output_names=['dets', 'labels'], input_shape=[640,640])  # static input size set to 640x640
codebase_config = dict(
    type='mmdet',
    task='ObjectDetection',
    model_type='end2end',
    post_processing=dict(
        score_threshold=0.05,
        confidence_threshold=0.005,  # for YOLOv3
        iou_threshold=0.5,
        max_output_boxes_per_class=200,
        pre_top_k=5000,
        keep_top_k=100,
        background_label_id=-1,
    ))

# Co-DINO is trained with multi-scale inputs; here we fix the test input to 640x640 to reduce computation
```

3. Export to ONNX with mmdeploy.

```shell
python mmdeploy/tools/deploy.py \
    mmdeploy/configs/mmdet/detection/detection_onnxruntime_static.py \
    mmdetection/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_16e_o365tococo.py \
    mmdetection/checkpoints/co_dino_5scale_swin_large_16e_o365tococo-614254c9.pth \
    mmdetection/demo/demo.jpg \
    --work-dir mmdetection/checkpoints \
    --device cpu
# This step produces end2end.onnx, but running it in ONNX Runtime fails:
# the grid_sampler op is supported natively by neither ONNX Runtime nor
# TensorRT. We resolve this below by compiling a TensorRT plugin.
```

4. Apply constant folding and onnx-simplifier to the ONNX model.

```shell
polygraphy surgeon sanitize end2end.onnx --fold-constants -o end2end_folded.onnx
python -m onnxsim end2end_folded.onnx end2end_folded_sim.onnx
```
Note:

```
# library versions used for constant folding and simplification
polygraphy==0.49.0
onnxruntime-gpu==1.19.2
onnx-simplifier==0.4.36
```


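Before moving on, it can help to list which operator types the simplified model still contains, so you know exactly what TensorRT will need plugins for. A small sketch added for illustration (the file name is assumed from step 4 above):

```python
# List operator types in the simplified model; expect a grid_sampler /
# GridSample entry, which is what the TensorRT plugin below provides.
import collections
import onnx

model = onnx.load("end2end_folded_sim.onnx")
ops = collections.Counter(node.op_type for node in model.graph.node)
for op, count in sorted(ops.items()):
    print(f"{op}: {count}")
```
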
### 2. Building Only the Co-DETR TensorRT Plugins from mmdeploy on Windows

```CMakeLists
cmake_minimum_required(VERSION 2.6)

project(mmdeploy_plugins)

add_definitions(-std=c++11)
add_definitions(-DAPI_EXPORTS)
option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_BUILD_TYPE Release)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /O2")
add_compile_definitions(WIN32_LEAN_AND_MEAN NOMINMAX)

find_package(CUDA REQUIRED)

#if(WIN32)
#enable_language(CUDA)
#endif(WIN32)

# cuda
set(cuda_inc "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0/include")
set(cuda_lib "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0/lib/x64")
include_directories(${cuda_inc})
link_directories(${cuda_lib})
#cub
set(CUB_ROOT_DIR "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0/include/cub")
include_directories(${CUB_ROOT_DIR})
# tensorrt
set(tensorrt_inc "D:/trt_install/TensorRT-8.5.1.7/include")
set(tensorrt_lib "D:/trt_install/TensorRT-8.5.1.7/lib")
include_directories(${tensorrt_inc})
link_directories(${tensorrt_lib})
# opencv
#include_directories("${PROJECT_SOURCE_DIR}/third_party/CV460_64/include")
#set(opencv_lib "${PROJECT_SOURCE_DIR}/third_party/CV460_64/lib/opencv_world460.lib")

# common files, taken from mmdeploy
include_directories(common)

file(GLOB grid_sampler_src ${PROJECT_SOURCE_DIR}/grid_sampler/*.cpp ${PROJECT_SOURCE_DIR}/grid_sampler/*.cu)
cuda_add_library(trtgrid_sampler SHARED ${grid_sampler_src})
#cuda_add_library(trtgrid_sampler STATIC ${grid_sampler_src})
target_link_libraries(trtgrid_sampler nvinfer cudart)


file(GLOB topk_src ${PROJECT_SOURCE_DIR}/gather_topk/*.cpp ${PROJECT_SOURCE_DIR}/gather_topk/*.cu)
cuda_add_library(trtgather_topk SHARED ${topk_src})
#cuda_add_library(trtgather_topk STATIC ${topk_src})
target_link_libraries(trtgather_topk nvinfer cudart)


if(UNIX)
add_definitions(-O2 -pthread)
endif(UNIX)
```



1. Open the Visual Studio 2017 prompt `x64 Native Tools Command ...` and cd into the project directory to build.

![](test_res/static/cmd.png)

2. Build the TensorRT plugins on Windows.

```bash
mkdir build && cd build
cmake -G "NMake Makefiles" ..
nmake
```
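Before serializing an engine, it is worth verifying that the freshly built library actually registers its creator with TensorRT. A minimal sketch using the TensorRT Python API, added for illustration (the DLL path is an assumption):

```python
# Load the plugin library, initialize the plugin registry, and look for the
# grid_sampler creator that mmdeploy's plugin registers.
import ctypes
import tensorrt as trt

ctypes.CDLL("./build/trtgrid_sampler.dll")  # use the .so path on Linux
logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")

names = [c.name for c in trt.get_plugin_registry().plugin_creator_list]
print("grid_sampler registered:", any("grid_sampler" in n for n in names))
```
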

This generates `trtgrid_sampler.dll` and `trtgather_topk.dll` under the build folder; below we only need the `trtgrid_sampler.dll` plugin.

### 3. Editing the Co-DETR ONNX Graph

The originally exported graph, without NMS:

![](test_res/static/v1.PNG)

Run the graph-editing script:

```shell
python co_detr_add_nms.py
```

The model structure becomes:

![](test_res/static/v2.PNG)

### 4. Serializing the Co-DETR TensorRT Engine on Windows

```shell
trtexec --onnx=end2end_folded_sim_nms.onnx --saveEngine=test_1.plan --workspace=60000 --verbose --plugins=./trtgrid_sampler.dll
```



### 5. Analyzing mmdetection's Co-DETR Preprocessing and Rewriting It in C++

Preprocessing of Co-DINO in mmdetection 3.3.0:

+ Read the BGR image with OpenCV

+ Resize keeping the aspect ratio, long side to 640, bilinear interpolation

+ Normalize:

  mean=[123.675, 116.28, 103.53], # RGB
  std=[58.395, 57.12, 57.375], # RGB

+ BGR2RGB

+ Pad the short side at the bottom/right so the padded area becomes zero after normalization

The C++ implementation:

```c++
// Preprocessing of mmdetection 3.3.0 Co-DETR
void codetr::preprocess(cv::Mat &img, float data[]) {
    int w, h, x, y;
    float r_w = INPUT_W / (img.cols*1.0);
    float r_h = INPUT_H / (img.rows*1.0);
    if (r_h > r_w) {
        w = INPUT_W;
        h = r_w * img.rows;
    }
    else {
        w = r_h * img.cols;
        h = INPUT_H;
    }
    cv::Mat re(h, w, CV_8UC3);
    cv::resize(img, re, re.size(), 0, 0, cv::INTER_LINEAR);
    cv::Mat out(INPUT_H, INPUT_W, CV_8UC3, cv::Scalar(103, 116, 123)); // pad with the BGR channel means, so the padding normalizes to ~0
    re.copyTo(out(cv::Rect(0, 0, re.cols, re.rows))); // image at top-left, padding at bottom/right

    int i = 0;
    for (int row = 0; row < INPUT_H; ++row) {
        uchar* uc_pixel = out.data + row * out.step;
        for (int col = 0; col < INPUT_W; ++col) {
            data[i] = ((float)uc_pixel[2] - 123.675)/58.395; //R
            data[i + INPUT_H * INPUT_W] = ((float)uc_pixel[1] - 116.28) / 57.12; //G
            data[i + 2 * INPUT_H * INPUT_W] = ((float)uc_pixel[0] - 103.53)/ 57.375; //B

            uc_pixel += 3;
            ++i;
        }
    }
}
```



### 6. Co-DETR TensorRT C++ Implementation and Test

Note how the custom plugin is loaded in C++:

```c++
bool didInitPlugins = initLibNvInferPlugins(nullptr, "");
void* handle_grid_sampler = LoadLibrary(L"trtgrid_sampler.dll");
```

The TensorRT C++ inference demo:

| bus.jpg                   | zidane.jpg              |
| ------------------------- | ----------------------- |
| ![](test_res/final_0.jpg) | ![](test_res/final.jpg) |

### 7. Building on Linux

+ Code for compiling end-to-end Co-DETR inference on Linux is provided under `linux_cc/`: `plugin` builds the grid_sampler plugin, and `co_detr` contains the Co-DETR TensorRT inference code.

> [!NOTE]\
>
> + Implementing Co-DETR in TensorRT involves quite a few pitfalls, and there are almost no online references for it
> + We removed the soft-NMS op and replaced it with the TensorRT EfficientNMS plugin
> + We compiled the grid_sampler TensorRT plugin on Windows

In the end, end-to-end, GPU-accelerated TensorRT inference for Co-DETR works successfully!
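To validate the Section 5 C++ preprocessing, a NumPy reference that mirrors it step for step can be diffed against the engine input buffer. This is a sketch added for illustration, under the repo's 640x640 static-shape assumption:

```python
# NumPy mirror of the C++ preprocess(): resize long side to 640, pad the
# bottom/right with the BGR channel means, BGR->RGB, normalize, emit NCHW.
import cv2
import numpy as np

def preprocess(path, size=640):
    img = cv2.imread(path)  # BGR, HWC, uint8
    r = min(size / img.shape[1], size / img.shape[0])
    w, h = int(r * img.shape[1]), int(r * img.shape[0])
    resized = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)
    out = np.full((size, size, 3), (103, 116, 123), dtype=np.uint8)  # ~BGR means
    out[:h, :w] = resized  # image at top-left, padding at bottom/right
    rgb = out[:, :, ::-1].astype(np.float32)
    mean = np.array([123.675, 116.28, 103.53], dtype=np.float32)
    std = np.array([58.395, 57.12, 57.375], dtype=np.float32)
    return ((rgb - mean) / std).transpose(2, 0, 1)[None]  # (1, 3, 640, 640)

print(preprocess("linux_cc/co_detr/test_img/zidane.jpg").shape)
```
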

--------------------------------------------------------------------------------
/co_detr_add_nms.py:
--------------------------------------------------------------------------------

'''

xujing

2024-10-2

mmdetection=3.3.0

Add the NMS plugin to the Co-DETR ONNX graph

'''

import onnx_graphsurgeon as gs
import numpy as np
import onnx


score_node_name = "Sigmoid_8437"
index_node_name = "Div_8449"
box_node_name = "ScatterND_8577"

num_class = 80


def gather_concat_score(graph,num_class=num_class):
    score_output_node = [node for node in graph.nodes if node.name == score_node_name][0]
    score_output = score_output_node.outputs[0]

    index_output_node = [node for node in graph.nodes if node.name == index_node_name][0]
    index_output = index_output_node.outputs[0]

    # gather
    gather_output = gs.Variable(name="score_input_0",shape=(300,num_class),dtype=np.float32)
    gather_node = gs.Node(op="Gather",inputs=[score_output,index_output],outputs=[gather_output])

    # # Unsqueeze
    # unsqueeze_output = gs.Variable(name="score_input",shape=(1,300,num_class),dtype=np.float32)
    # unsqueeze_node = gs.Node(op="Unsqueeze",inputs=[gather_output],outputs=[unsqueeze_output],attrs={"axes":0})
    # reshape
    shape_score = gs.Constant("shape_score",values=np.array([1,300,num_class],dtype=np.int64))
    scores = gs.Variable(name="score_input",shape=(1,300,num_class),dtype=np.float32)
    scores_node = gs.Node(op="Reshape",inputs=[gather_output,shape_score],outputs=[scores])


    # concat
    box_node = [node for node in graph.nodes if node.name == box_node_name][0]
    # box_output = box_node.outputs[0]

    # Unsqueeze
    # unsqueeze_output_1 = gs.Variable(name="box_input",shape=(1,300,4),dtype=np.float32)
    # unsqueeze_node_1 = gs.Node(op="Unsqueeze",inputs=[box_node.outputs[0]],outputs=[unsqueeze_output_1],attrs={"axes":0})

    # use Reshape instead of Unsqueeze
    shape_box = gs.Constant("shape_box",values=np.array([1,300,4],dtype=np.int64)) #batchnms_trt: [1,300,1,4]
    boxes = gs.Variable(name="box_input",shape=(1,300,4),dtype=np.float32)
    boxes_node = gs.Node(op="Reshape",inputs=[box_node.outputs[0],shape_box],outputs=[boxes])

    # concat_output = gs.Variable(name="concat_box",shape=(300,num_class+4),dtype=np.float32)
    # concat_node = gs.Node(op="Concat",inputs=[box_output,gather_output],outputs=[concat_output],attrs={"axis":1})

    # graph.nodes.extend([gather_node,concat_node,])
    graph.nodes.extend([gather_node,scores_node,boxes_node])

    graph.outputs = [ boxes, scores ]

    graph.cleanup().toposort()
    # onnx.save(gs.export_onnx(graph),"./last_1.onnx")

    return graph



# insert the EfficientNMS plugin op into the graph
def create_and_add_plugin_node(graph, max_output_boxes, nms_type="efficientnms"):

    batch_size = graph.inputs[0].shape[0]
    print("The batch size is: ", batch_size)
    # input_h = graph.inputs[0].shape[2]
    # input_w = graph.inputs[0].shape[3]

    tensors = graph.tensors()
    boxes_tensor = tensors["box_input"]
    confs_tensor = tensors["score_input"]

    print(boxes_tensor)
    print(confs_tensor)

    if nms_type == "batchnms":
        num_detections = gs.Variable(name="num_detections").to_variable(dtype=np.int32, shape=[batch_size])
        nmsed_boxes = gs.Variable(name="nmsed_boxes").to_variable(dtype=np.float32, shape=[batch_size, max_output_boxes, 4])
        nmsed_scores = gs.Variable(name="nmsed_scores").to_variable(dtype=np.float32, shape=[batch_size, max_output_boxes])
        nmsed_classes = gs.Variable(name="nmsed_classes").to_variable(dtype=np.float32, shape=[batch_size, max_output_boxes])


    elif nms_type == "efficientnms":
        num_detections = gs.Variable(name="num_detections").to_variable(dtype=np.int32, shape=[batch_size, 1])
        nmsed_boxes = gs.Variable(name="detection_boxes").to_variable(dtype=np.float32, shape=[batch_size, max_output_boxes, 4])
        nmsed_scores = gs.Variable(name="detection_scores").to_variable(dtype=np.float32, shape=[batch_size, max_output_boxes])
        nmsed_classes = gs.Variable(name="detection_classes").to_variable(dtype=np.int32, shape=[batch_size, max_output_boxes])

    new_outputs = [num_detections, nmsed_boxes, nmsed_scores, nmsed_classes]

    if nms_type == "batchnms":
        nms_node = gs.Node(
            op="BatchedNMS_TRT",
            attrs=create_attrs_batchnms(max_output_boxes),
            inputs=[boxes_tensor, confs_tensor],
            outputs=new_outputs)
    elif nms_type == "efficientnms":
        nms_node = gs.Node(
            op="EfficientNMS_TRT",
            attrs=create_attrs_efficientnms(max_output_boxes),
            inputs=[boxes_tensor, confs_tensor],
            outputs=new_outputs)

    graph.nodes.append(nms_node)
    graph.outputs = new_outputs

    return graph.cleanup().toposort()


def create_attrs_efficientnms(max_output_boxes=100):

    attrs = {}

    attrs["score_threshold"] = 0.70
    attrs["iou_threshold"] = 0.45
    attrs["max_output_boxes"] = max_output_boxes
    attrs["background_class"] = -1
    attrs["score_activation"] = False
    attrs["class_agnostic"] = False
    attrs["box_coding"] = 0
    # 001 is the default plugin version the parser will search for, and therefore can be omitted,
    # but we include it here for illustrative purposes.
    attrs["plugin_version"] = "1"

    return attrs

def create_attrs_batchnms(max_output_boxes=100):

    attrs = {}

    attrs["shareLocation"] = True
    attrs["backgroundLabelId"] = -1
    attrs["numClasses"] = 80
    attrs["topK"] = 1000
    attrs["keepTopK"] = max_output_boxes
    attrs["scoreThreshold"] = 0.25
    attrs["iouThreshold"] = 0.45

    attrs["isNormalized"] = False
    attrs["clipBoxes"] = False
    attrs["scoreBits"] = 16  # only takes effect in FP16 mode
    attrs["caffeSemantics"] = False

    # 001 is the default plugin version the parser will search for, and therefore can be omitted,
    # but we include it here for illustrative purposes.
    attrs["plugin_version"] = "1"

    return attrs

if __name__ == "__main__":
    onnx_path = "./end2end_folded_sim.onnx"
    graph = gs.import_onnx(onnx.load(onnx_path))

    # add ops to build the EfficientNMS plugin inputs
    graph = gather_concat_score(graph)

    # add the EfficientNMS plugin
    graph = create_and_add_plugin_node(graph, 20)

    # save the edited graph
    onnx.save(gs.export_onnx(graph),"./end2end_folded_sim_nms.onnx")

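# NOTE (sketch added for illustration; not in the original script): the node
# names hardcoded at the top (Sigmoid_8437, Div_8449, ScatterND_8577) belong
# to one particular export; re-exporting Co-DETR usually renames them. A
# lookup like the commented sketch below can help find the equivalent nodes
# in a fresh graph before editing:
#
#   graph = gs.import_onnx(onnx.load("./end2end_folded_sim.onnx"))
#   for node in graph.nodes:
#       if node.op in ("Sigmoid", "Div", "ScatterND"):
#           print(node.op, node.name, [str(o.shape) for o in node.outputs])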
159 | attrs["plugin_version"] = "1" 160 | 161 | return attrs 162 | 163 | if __name__ == "__main__": 164 | onnx_path = "./end2end_folded_sim.onnx" 165 | graph = gs.import_onnx(onnx.load(onnx_path)) 166 | 167 | # 添加op得到Efficient NMS plugin的input 168 | graph = gather_concat_score(graph) 169 | 170 | # 添加Efficient NMS plugin 171 | graph = create_and_add_plugin_node(graph, 20) 172 | 173 | # 保存图结构 174 | onnx.save(gs.export_onnx(graph),"./end2end_folded_sim_nms.onnx") 175 | 176 | 177 | 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /co_detr_trt.cpp: -------------------------------------------------------------------------------- 1 | // co_detr_trt.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。 2 | // 3 | 4 | #include 5 | #include 6 | 7 | #include "cuda_runtime_api.h" 8 | #include "NvOnnxParser.h" 9 | #include "NvInfer.h" 10 | #include "NvInferPlugin.h" 11 | 12 | # include "codetr.h" 13 | 14 | 15 | float h_input[INPUT_SIZE * INPUT_SIZE * 3]; 16 | int h_output_0[1]; //1 17 | float h_output_1[1 * 20 * 4]; //1 18 | float h_output_2[1 * 20]; //1 19 | int h_output_3[1 * 20]; //1 20 | 21 | int main() 22 | { 23 | codetr *CoDetr = new codetr; 24 | 25 | IExecutionContext* engine_context = CoDetr->load_engine("./model/test_1.plan"); 26 | 27 | if (engine_context == nullptr) 28 | { 29 | std::cerr << "failed to create tensorrt execution context." << std::endl; 30 | } 31 | 32 | 33 | //cv2读图片 34 | cv::Mat image; 35 | image = cv::imread("./test_img/zidane.jpg", 1); 36 | 37 | CoDetr->preprocess(image, h_input); 38 | 39 | void* buffers[5]; 40 | cudaMalloc(&buffers[0], INPUT_SIZE * INPUT_SIZE * 3 * sizeof(float)); //<- input 41 | cudaMalloc(&buffers[1], 1 * sizeof(int)); //<- num_detections 42 | cudaMalloc(&buffers[2], 1 * 20 * 4 * sizeof(float)); //<- nmsed_boxes 43 | cudaMalloc(&buffers[3], 1 * 20 * sizeof(float)); //<- nmsed_scores 44 | cudaMalloc(&buffers[4], 1 * 20 * sizeof(int)); //<- nmsed_classes 45 | 46 | cudaMemcpy(buffers[0], h_input, INPUT_SIZE * INPUT_SIZE * 3 * sizeof(float), cudaMemcpyHostToDevice); 47 | 48 | // -- do execute --------// 49 | engine_context->executeV2(buffers); 50 | 51 | cudaMemcpy(h_output_0, buffers[1], 1 * sizeof(int), cudaMemcpyDeviceToHost); 52 | cudaMemcpy(h_output_1, buffers[2], 1 * 20 * 4 * sizeof(float), cudaMemcpyDeviceToHost); 53 | cudaMemcpy(h_output_2, buffers[3], 1 * 20 * sizeof(float), cudaMemcpyDeviceToHost); 54 | cudaMemcpy(h_output_3, buffers[4], 1 * 20 * sizeof(int), cudaMemcpyDeviceToHost); 55 | 56 | 57 | //std::vector pred_box; 58 | //for (int i = 0; i < 300; i++) { 59 | // std::cout << "box: " << h_output_0[i * 5] << ", " << h_output_0[i * 5 + 1] << ", " << h_output_0[i * 5 + 2] << ", " << h_output_0[i * 5 + 3] << ", " << h_output_0[i * 5 + 4] << std::endl; 60 | // 61 | 62 | // if (h_output_0[i * 5 + 4] >= 0.80) { 63 | // Bbox box; 64 | // box.x1 = h_output_0[i * 5]; 65 | // box.y1 = h_output_0[i * 5 + 1]; 66 | // box.x2 = h_output_0[i * 5 + 2]; 67 | // box.y2 = h_output_0[i * 5 + 3]; 68 | // box.score = h_output_0[i * 5 + 4]; 69 | // box.classes = h_output_1[i]; 70 | 71 | // std::cout << box.classes << "," << box.score << std::endl; 72 | // std::cout << box.x1 << "," << box.y1 << ", " << box.x2 << ", " << box.y2 << std::endl; 73 | 74 | 75 | // pred_box.push_back(box); 76 | // } 77 | // 78 | // //float max_score = 0.0; 79 | // //int max_id = 0; 80 | // //for (int j = 0; j < 80; j++) { 81 | // // if (max_score <= h_output_0[i * 80 + j]) { 82 | // // max_score = h_output_0[i * 80 + j]; 83 | // // max_id = j; 84 | // // 
    //    //    //std::cout << h_output_0[i * 80 + j] << ", ";
    //    //}
    //    std::cout << "max_score: " << h_output_1[i] << std::endl;
    //}

    std::cout << "num_detections: " << h_output_0[0] << std::endl;
    std::vector<Bbox> pred_box;
    for (int i = 0; i < h_output_0[0]; i++) {
        Bbox box;
        box.x1 = h_output_1[i * 4];
        box.y1 = h_output_1[i * 4 + 1];
        box.x2 = h_output_1[i * 4 + 2];
        box.y2 = h_output_1[i * 4 + 3];
        box.score = h_output_2[i];
        box.classes = h_output_3[i];

        std::cout << box.classes << "," << box.score << std::endl;
        std::cout << box.x1 << "," << box.y1 << ", " << box.x2 << ", " << box.y2 << std::endl;


        pred_box.push_back(box);
    }

    std::vector<Bbox> out = CoDetr->postprocess(pred_box, image.cols, image.rows);
    cv::Mat img = CoDetr->renderBoundingBox(image, out);

    cv::imwrite("final.jpg", img);

    cv::namedWindow("Image", 1); // create a window
    cv::imshow("Image", img);    // show the image

    cv::waitKey(0); // wait for a key press

    cudaFree(buffers[0]);
    cudaFree(buffers[1]);
    cudaFree(buffers[2]);
    cudaFree(buffers[3]);
    cudaFree(buffers[4]);

    delete engine_context;

}


--------------------------------------------------------------------------------
/codetr.cpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/Co-DETR-TensorRT/d79cf4fe508c47cca86594b52cb1088968fe1cd3/codetr.cpp
--------------------------------------------------------------------------------
/codetr.h:
--------------------------------------------------------------------------------
#pragma once
#include <opencv2/opencv.hpp>
#include <string>

#include "cuda_runtime_api.h"
#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "logging.h"


#define BATCH_SIZE 1
#define INPUT_W 640
#define INPUT_H 640
#define INPUT_SIZE 640

using namespace nvinfer1;
using namespace sample;


// box x1,y1,x2,y2
struct Bbox {
    float x1;
    float y1;
    float x2;
    float y2;
    float score;
    int classes;
};
class codetr
{
public:

    codetr();

    IExecutionContext* load_engine(std::string enginePath);

    void preprocess(cv::Mat &img, float data[]);

    std::vector<Bbox> postprocess(std::vector<Bbox> &out, int width, int height);

    cv::Mat renderBoundingBox(cv::Mat image, const std::vector<Bbox> &bboxes);

public:
    //ICudaEngine* engine;
    //IExecutionContext* engine_context;
    cv::Mat image;
    std::vector<std::string> class_names = { "person","bicycle","car","motorcycle","airplane",
    "bus","train","truck","boat","traffic light","fire hydrant","stop sign","parking meter","bench",
    "bird","cat","dog","horse","sheep","cow","elephant","bear","zebra","giraffe","backpack",
    "umbrella","handbag","tie","suitcase","frisbee","skis","snowboard","sports ball","kite","baseball bat",
    "baseball glove","skateboard","surfboard","tennis racket","bottle","wine glass",
    "cup","fork","knife","spoon","bowl","banana","apple","sandwich","orange","broccoli","carrot","hot dog",
    "pizza","donut","cake","chair","couch","potted plant","bed","dining table","toilet","tv",
    "laptop","mouse","remote","keyboard","cell phone","microwave","oven","toaster","sink",
    "refrigerator","book","clock","vase","scissors","teddy bear","hair drier","toothbrush" };


};

--------------------------------------------------------------------------------
/linux_cc/co_detr/CMakeLists.txt:
--------------------------------------------------------------------------------
cmake_minimum_required(VERSION 3.10)

project(co_detr)

# opencv
set(OpenCV_DIR /workspace/opencv-4.5.0/build)
find_package(OpenCV REQUIRED)

# ${var} dereferences a variable
find_package(CUDA REQUIRED)
include_directories(${CUDA_INCLUDE_DIRS})
include_directories(${TensorRT_INCLUDE_DIRS})
# find_library(CUDA)
# find_library(NVINFER NAMES nvinfer)
# find_library(NVPARSERS NAMES nvparsers)
# find_library(NVONNXPARSERS NAMES nvonnxparser)

find_library(NVINFER NAMES nvinfer)
find_library(NVPARSERS NAMES nvparsers)
find_library(NVONNXPARSERS NAMES nvonnxparser)
find_library(NVINFERPLUGIN NAMES nvinfer_plugin)

#target_link_libraries(yolox libnvinfer_plugin.so)


find_library(CUDNN_LIBRARY
    NAMES libcudnn.so${__cudnn_ver_suffix} libcudnn${__cudnn_ver_suffix}.dylib ${__cudnn_lib_win_name}
    PATHS $ENV{LD_LIBRARY_PATH} ${__libpath_cudart} ${CUDNN_ROOT_DIR} ${PC_CUDNN_LIBRARY_DIRS} ${CMAKE_INSTALL_PREFIX}
    PATH_SUFFIXES lib lib64 bin
    DOC "CUDNN library."
)


file(GLOB_RECURSE _HEAD ${CMAKE_CURRENT_LIST_DIR}/src/*.h
    ${CMAKE_CURRENT_LIST_DIR}/src/*.cuh
)

file(GLOB _SRC ${CMAKE_CURRENT_LIST_DIR}/src/*.cpp
    ${CMAKE_CURRENT_LIST_DIR}/src/*.cu
)


add_executable (${PROJECT_NAME} ${_SRC} ${_HEAD})
target_link_libraries(${PROJECT_NAME}
    ${NVINFER}
    ${NVONNXPARSERS}
    ${NVINFERPLUGIN}
    ${CUDA_LIBRARIES}
    ${CUDA_CUBLAS_LIBRARIES}
    ${CUDNN_LIBRARY}
    ${OpenCV_LIBS})

message(STATUS "cmake success!!! Co-DETR by xj")
Co-DETR by xj") -------------------------------------------------------------------------------- /linux_cc/co_detr/src/co_detr_add_nms.py: -------------------------------------------------------------------------------- 1 | 2 | ''' 3 | 4 | xujing 5 | 6 | 2024-10-2 7 | 8 | mmdetection=3.3.0 9 | 10 | co-detr onnx添加nms plugin 11 | 12 | ''' 13 | 14 | import onnx_graphsurgeon as gs 15 | import numpy as np 16 | import onnx 17 | 18 | 19 | score_node_name = "Sigmoid_8437" 20 | index_node_name = "Div_8449" 21 | box_node_name = "ScatterND_8577" 22 | 23 | num_class = 80 24 | 25 | 26 | def gather_concat_score(graph,num_class=num_class): 27 | score_output_node = [node for node in graph.nodes if node.name == score_node_name][0] 28 | score_output = score_output_node.outputs[0] 29 | 30 | index_output_node = [node for node in graph.nodes if node.name == index_node_name][0] 31 | index_output = index_output_node.outputs[0] 32 | 33 | # gather 34 | gather_output = gs.Variable(name="score_input_0",shape=(300,num_class),dtype=np.float32) 35 | gather_node = gs.Node(op="Gather",inputs=[score_output,index_output],outputs=[gather_output]) 36 | 37 | # # Unsqueeze 38 | # unsqueeze_output = gs.Variable(name="score_input",shape=(1,300,num_class),dtype=np.float32) 39 | # unsqueeze_node = gs.Node(op="Unsqueeze",inputs=[gather_output],outputs=[unsqueeze_output],attrs={"axes":0}) 40 | # reshape 41 | shape_score = gs.Constant("shape_score",values=np.array([1,300,num_class],dtype=np.int64)) 42 | scores = gs.Variable(name="score_input",shape=(1,300,num_class),dtype=np.float32) 43 | scores_node = gs.Node(op="Reshape",inputs=[gather_output,shape_score],outputs=[scores]) 44 | 45 | 46 | # concat 47 | box_node = [node for node in graph.nodes if node.name == box_node_name][0] 48 | # box_output = box_node.outputs[0] 49 | 50 | # Unsqueeze 51 | # unsqueeze_output_1 = gs.Variable(name="box_input",shape=(1,300,4),dtype=np.float32) 52 | # unsqueeze_node_1 = gs.Node(op="Unsqueeze",inputs=[box_node.outputs[0]],outputs=[unsqueeze_output_1],attrs={"axes":0}) 53 | 54 | # 替换为reshape,不用unsqueeze 55 | shape_box = gs.Constant("shape_box",values=np.array([1,300,4],dtype=np.int64)) #batchnms_trt: [1,300,1,4] 56 | boxes = gs.Variable(name="box_input",shape=(1,300,4),dtype=np.float32) 57 | boxes_node = gs.Node(op="Reshape",inputs=[box_node.outputs[0],shape_box],outputs=[boxes]) 58 | 59 | # concat_output = gs.Variable(name="concat_box",shape=(300,num_class+4),dtype=np.float32) 60 | # concat_node = gs.Node(op="Concat",inputs=[box_output,gather_output],outputs=[concat_output],attrs={"axis":1}) 61 | 62 | # graph.nodes.extend([gather_node,concat_node,]) 63 | graph.nodes.extend([gather_node,scores_node,boxes_node]) 64 | 65 | graph.outputs = [ boxes, scores ] 66 | 67 | graph.cleanup().toposort() 68 | # onnx.save(gs.export_onnx(graph),"./last_1.onnx") 69 | 70 | return graph 71 | 72 | 73 | 74 | # graph中插入EfficientNMS plugin op 75 | def create_and_add_plugin_node(graph, max_output_boxes, nms_type="efficientnms"): 76 | 77 | batch_size = graph.inputs[0].shape[0] 78 | print("The batch size is: ", batch_size) 79 | # input_h = graph.inputs[0].shape[2] 80 | # input_w = graph.inputs[0].shape[3] 81 | 82 | tensors = graph.tensors() 83 | boxes_tensor = tensors["box_input"] 84 | confs_tensor = tensors["score_input"] 85 | 86 | print(boxes_tensor) 87 | print(confs_tensor) 88 | 89 | if nms_type == "batchnms": 90 | num_detections = gs.Variable(name="num_detections").to_variable(dtype=np.int32, shape=[batch_size]) 91 | nmsed_boxes = 
gs.Variable(name="nmsed_boxes").to_variable(dtype=np.float32, shape=[batch_size, max_output_boxes, 4]) 92 | nmsed_scores = gs.Variable(name="nmsed_scores").to_variable(dtype=np.float32, shape=[batch_size, max_output_boxes]) 93 | nmsed_classes = gs.Variable(name="nmsed_classes").to_variable(dtype=np.float32, shape=[batch_size, max_output_boxes]) 94 | 95 | 96 | elif nms_type == "efficientnms": 97 | num_detections = gs.Variable(name="num_detections").to_variable(dtype=np.int32, shape=[batch_size, 1]) 98 | nmsed_boxes = gs.Variable(name="detection_boxes").to_variable(dtype=np.float32, shape=[batch_size, max_output_boxes, 4]) 99 | nmsed_scores = gs.Variable(name="detection_scores").to_variable(dtype=np.float32, shape=[batch_size, max_output_boxes]) 100 | nmsed_classes = gs.Variable(name="detection_classes").to_variable(dtype=np.int32, shape=[batch_size, max_output_boxes]) 101 | 102 | new_outputs = [num_detections, nmsed_boxes, nmsed_scores, nmsed_classes] 103 | 104 | if nms_type == "batchnms": 105 | mns_node = gs.Node( 106 | op="BatchedNMS_TRT", 107 | attrs=create_attrs_batchnms(max_output_boxes), 108 | inputs=[boxes_tensor, confs_tensor], 109 | outputs=new_outputs) 110 | elif nms_type == "efficientnms": 111 | mns_node = gs.Node( 112 | op="EfficientNMS_TRT", 113 | attrs=create_attrs_efficientnms(max_output_boxes), 114 | inputs=[boxes_tensor, confs_tensor], 115 | outputs=new_outputs) 116 | 117 | graph.nodes.append(mns_node) 118 | graph.outputs = new_outputs 119 | 120 | return graph.cleanup().toposort() 121 | 122 | 123 | def create_attrs_efficientnms(max_output_boxes=100): 124 | 125 | attrs = {} 126 | 127 | attrs["score_threshold"] = 0.70 128 | attrs["iou_threshold"] = 0.45 129 | attrs["max_output_boxes"] = max_output_boxes 130 | attrs["background_class"] = -1 131 | attrs["score_activation"] = False 132 | attrs["class_agnostic"] = False 133 | attrs["box_coding"] = 0 134 | # 001 is the default plugin version the parser will search for, and therefore can be omitted, 135 | # but we include it here for illustrative purposes. 136 | attrs["plugin_version"] = "1" 137 | 138 | return attrs 139 | 140 | def create_attrs_batchnms(max_output_boxes=100): 141 | 142 | attrs = {} 143 | 144 | attrs["shareLocation"] = True 145 | attrs["backgroundLabelId"] = -1 146 | attrs["numClasses"] = 80 147 | attrs["topK"] = 1000 148 | attrs["keepTopK"] = max_output_boxes 149 | attrs["scoreThreshold"] = 0.25 150 | attrs["iouThreshold"] = 0.45 151 | 152 | attrs["isNormalized"] = False 153 | attrs["clipBoxes"] = False 154 | attrs["scoreBits"] = 16 #FP16才起作用 155 | attrs["caffeSemantics"] = False 156 | 157 | # 001 is the default plugin version the parser will search for, and therefore can be omitted, 158 | # but we include it here for illustrative purposes. 
159 | attrs["plugin_version"] = "1" 160 | 161 | return attrs 162 | 163 | if __name__ == "__main__": 164 | onnx_path = "./end2end_folded_sim.onnx" 165 | graph = gs.import_onnx(onnx.load(onnx_path)) 166 | 167 | # 添加op得到Efficient NMS plugin的input 168 | graph = gather_concat_score(graph) 169 | 170 | # 添加Efficient NMS plugin 171 | graph = create_and_add_plugin_node(graph, 20) 172 | 173 | # 保存图结构 174 | onnx.save(gs.export_onnx(graph),"./end2end_folded_sim_nms.onnx") 175 | 176 | 177 | 178 | 179 | 180 | 181 | -------------------------------------------------------------------------------- /linux_cc/co_detr/src/co_detr_trt.cpp: -------------------------------------------------------------------------------- 1 | // co_detr_trt.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。 2 | // 3 | 4 | #include 5 | #include 6 | 7 | #include "cuda_runtime_api.h" 8 | #include "NvOnnxParser.h" 9 | #include "NvInfer.h" 10 | #include "NvInferPlugin.h" 11 | 12 | # include "codetr.h" 13 | 14 | 15 | float h_input[INPUT_SIZE * INPUT_SIZE * 3]; 16 | int h_output_0[1]; //1 17 | float h_output_1[1 * 20 * 4]; //1 18 | float h_output_2[1 * 20]; //1 19 | int h_output_3[1 * 20]; //1 20 | 21 | int main() 22 | { 23 | codetr *CoDetr = new codetr; 24 | 25 | IExecutionContext* engine_context = CoDetr->load_engine("./model/test_1.plan"); 26 | 27 | if (engine_context == nullptr) 28 | { 29 | std::cerr << "failed to create tensorrt execution context." << std::endl; 30 | } 31 | 32 | 33 | //cv2读图片 34 | cv::Mat image; 35 | image = cv::imread("./test_img/zidane.jpg", 1); 36 | 37 | CoDetr->preprocess(image, h_input); 38 | 39 | void* buffers[5]; 40 | cudaMalloc(&buffers[0], INPUT_SIZE * INPUT_SIZE * 3 * sizeof(float)); //<- input 41 | cudaMalloc(&buffers[1], 1 * sizeof(int)); //<- num_detections 42 | cudaMalloc(&buffers[2], 1 * 20 * 4 * sizeof(float)); //<- nmsed_boxes 43 | cudaMalloc(&buffers[3], 1 * 20 * sizeof(float)); //<- nmsed_scores 44 | cudaMalloc(&buffers[4], 1 * 20 * sizeof(int)); //<- nmsed_classes 45 | 46 | cudaMemcpy(buffers[0], h_input, INPUT_SIZE * INPUT_SIZE * 3 * sizeof(float), cudaMemcpyHostToDevice); 47 | 48 | // -- do execute --------// 49 | engine_context->executeV2(buffers); 50 | 51 | cudaMemcpy(h_output_0, buffers[1], 1 * sizeof(int), cudaMemcpyDeviceToHost); 52 | cudaMemcpy(h_output_1, buffers[2], 1 * 20 * 4 * sizeof(float), cudaMemcpyDeviceToHost); 53 | cudaMemcpy(h_output_2, buffers[3], 1 * 20 * sizeof(float), cudaMemcpyDeviceToHost); 54 | cudaMemcpy(h_output_3, buffers[4], 1 * 20 * sizeof(int), cudaMemcpyDeviceToHost); 55 | 56 | 57 | //std::vector pred_box; 58 | //for (int i = 0; i < 300; i++) { 59 | // std::cout << "box: " << h_output_0[i * 5] << ", " << h_output_0[i * 5 + 1] << ", " << h_output_0[i * 5 + 2] << ", " << h_output_0[i * 5 + 3] << ", " << h_output_0[i * 5 + 4] << std::endl; 60 | // 61 | 62 | // if (h_output_0[i * 5 + 4] >= 0.80) { 63 | // Bbox box; 64 | // box.x1 = h_output_0[i * 5]; 65 | // box.y1 = h_output_0[i * 5 + 1]; 66 | // box.x2 = h_output_0[i * 5 + 2]; 67 | // box.y2 = h_output_0[i * 5 + 3]; 68 | // box.score = h_output_0[i * 5 + 4]; 69 | // box.classes = h_output_1[i]; 70 | 71 | // std::cout << box.classes << "," << box.score << std::endl; 72 | // std::cout << box.x1 << "," << box.y1 << ", " << box.x2 << ", " << box.y2 << std::endl; 73 | 74 | 75 | // pred_box.push_back(box); 76 | // } 77 | // 78 | // //float max_score = 0.0; 79 | // //int max_id = 0; 80 | // //for (int j = 0; j < 80; j++) { 81 | // // if (max_score <= h_output_0[i * 80 + j]) { 82 | // // max_score = h_output_0[i * 80 + j]; 83 | // // 
    //    //    }
    //    //    //std::cout << h_output_0[i * 80 + j] << ", ";
    //    //}
    //    std::cout << "max_score: " << h_output_1[i] << std::endl;
    //}

    std::cout << "num_detections: " << h_output_0[0] << std::endl;
    std::vector<Bbox> pred_box;
    for (int i = 0; i < h_output_0[0]; i++) {
        Bbox box;
        box.x1 = h_output_1[i * 4];
        box.y1 = h_output_1[i * 4 + 1];
        box.x2 = h_output_1[i * 4 + 2];
        box.y2 = h_output_1[i * 4 + 3];
        box.score = h_output_2[i];
        box.classes = h_output_3[i];

        std::cout << box.classes << "," << box.score << std::endl;
        std::cout << box.x1 << "," << box.y1 << ", " << box.x2 << ", " << box.y2 << std::endl;


        pred_box.push_back(box);
    }

    std::vector<Bbox> out = CoDetr->postprocess(pred_box, image.cols, image.rows);
    cv::Mat img = CoDetr->renderBoundingBox(image, out);

    cv::imwrite("final.jpg", img);

    // cv::namedWindow("Image", 1); // create a window
    // cv::imshow("Image", img);    // show the image

    // cv::waitKey(0); // wait for a key press

    cudaFree(buffers[0]);
    cudaFree(buffers[1]);
    cudaFree(buffers[2]);
    cudaFree(buffers[3]);
    cudaFree(buffers[4]);

    delete engine_context;

}


--------------------------------------------------------------------------------
/linux_cc/co_detr/src/codetr.cpp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/Co-DETR-TensorRT/d79cf4fe508c47cca86594b52cb1088968fe1cd3/linux_cc/co_detr/src/codetr.cpp
--------------------------------------------------------------------------------
/linux_cc/co_detr/src/codetr.h:
--------------------------------------------------------------------------------
#pragma once
#include <opencv2/opencv.hpp>
#include <string>

#include "cuda_runtime_api.h"
#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "logging.h"


#define BATCH_SIZE 1
#define INPUT_W 640
#define INPUT_H 640
#define INPUT_SIZE 640

using namespace nvinfer1;
using namespace sample;


// box x1,y1,x2,y2
struct Bbox {
    float x1;
    float y1;
    float x2;
    float y2;
    float score;
    int classes;
};
class codetr
{
public:

    codetr();

    IExecutionContext* load_engine(std::string enginePath);

    void preprocess(cv::Mat &img, float data[]);

    std::vector<Bbox> postprocess(std::vector<Bbox> &out, int width, int height);

    cv::Mat renderBoundingBox(cv::Mat image, const std::vector<Bbox> &bboxes);

public:
    //ICudaEngine* engine;
    //IExecutionContext* engine_context;
    cv::Mat image;
    std::vector<std::string> class_names = { "person","bicycle","car","motorcycle","airplane",
    "bus","train","truck","boat","traffic light","fire hydrant","stop sign","parking meter","bench",
    "bird","cat","dog","horse","sheep","cow","elephant","bear","zebra","giraffe","backpack",
    "umbrella","handbag","tie","suitcase","frisbee","skis","snowboard","sports ball","kite","baseball bat",
    "baseball glove","skateboard","surfboard","tennis racket","bottle","wine glass",
    "cup","fork","knife","spoon","bowl","banana","apple","sandwich","orange","broccoli","carrot","hot dog",
    "pizza","donut","cake","chair","couch","potted plant","bed","dining table","toilet","tv",
    "laptop","mouse","remote","keyboard","cell phone","microwave","oven","toaster","sink",
    "refrigerator","book","clock","vase","scissors","teddy bear","hair drier","toothbrush" };


};

bear","hair drier","toothbrush" }; 56 | 57 | 58 | }; 59 | 60 | -------------------------------------------------------------------------------- /linux_cc/co_detr/src/logging.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #pragma warning(disable:4996) 18 | #ifndef TENSORRT_LOGGING_H 19 | #define TENSORRT_LOGGING_H 20 | 21 | #include "NvInferRuntimeCommon.h" 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include 29 | 30 | namespace sample 31 | { 32 | 33 | using Severity = nvinfer1::ILogger::Severity; 34 | 35 | class LogStreamConsumerBuffer : public std::stringbuf 36 | { 37 | public: 38 | LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) 39 | : mOutput(stream) 40 | , mPrefix(prefix) 41 | , mShouldLog(shouldLog) 42 | { 43 | } 44 | 45 | LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) noexcept 46 | : mOutput(other.mOutput) 47 | , mPrefix(other.mPrefix) 48 | , mShouldLog(other.mShouldLog) 49 | { 50 | } 51 | LogStreamConsumerBuffer(const LogStreamConsumerBuffer& other) = delete; 52 | LogStreamConsumerBuffer() = delete; 53 | LogStreamConsumerBuffer& operator=(const LogStreamConsumerBuffer&) = delete; 54 | LogStreamConsumerBuffer& operator=(LogStreamConsumerBuffer&&) = delete; 55 | 56 | ~LogStreamConsumerBuffer() override 57 | { 58 | // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence 59 | // std::streambuf::pptr() gives a pointer to the current position of the output sequence 60 | // if the pointer to the beginning is not equal to the pointer to the current position, 61 | // call putOutput() to log the output to the stream 62 | if (pbase() != pptr()) 63 | { 64 | putOutput(); 65 | } 66 | } 67 | 68 | //! 69 | //! synchronizes the stream buffer and returns 0 on success 70 | //! synchronizing the stream buffer consists of inserting the buffer contents into the stream, 71 | //! resetting the buffer and flushing the stream 72 | //! 
73 |     int32_t sync() override
74 |     {
75 |         putOutput();
76 |         return 0;
77 |     }
78 | 
79 |     void putOutput()
80 |     {
81 |         if (mShouldLog)
82 |         {
83 |             // prepend timestamp
84 |             std::time_t timestamp = std::time(nullptr);
85 |             tm* tm_local = std::localtime(&timestamp);
86 |             mOutput << "[";
87 |             mOutput << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/";
88 |             mOutput << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/";
89 |             mOutput << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-";
90 |             mOutput << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":";
91 |             mOutput << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":";
92 |             mOutput << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] ";
93 |             // std::stringbuf::str() gets the string contents of the buffer
94 |             // insert the buffer contents pre-appended by the appropriate prefix into the stream
95 |             mOutput << mPrefix << str();
96 |         }
97 |         // set the buffer to empty
98 |         str("");
99 |         // flush the stream
100 |         mOutput.flush();
101 |     }
102 | 
103 |     void setShouldLog(bool shouldLog)
104 |     {
105 |         mShouldLog = shouldLog;
106 |     }
107 | 
108 | private:
109 |     std::ostream& mOutput;
110 |     std::string mPrefix;
111 |     bool mShouldLog{};
112 | }; // class LogStreamConsumerBuffer
113 | 
114 | //!
115 | //! \class LogStreamConsumerBase
116 | //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
117 | //!
118 | class LogStreamConsumerBase
119 | {
120 | public:
121 |     LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
122 |         : mBuffer(stream, prefix, shouldLog)
123 |     {
124 |     }
125 | 
126 | protected:
127 |     LogStreamConsumerBuffer mBuffer;
128 | }; // class LogStreamConsumerBase
129 | 
130 | //!
131 | //! \class LogStreamConsumer
132 | //! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
133 | //!  Order of base classes is LogStreamConsumerBase and then std::ostream.
134 | //!  This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
135 | //!  in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
136 | //!  This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
137 | //!  Please do not change the order of the parent classes.
138 | //!
139 | class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream
140 | {
141 | public:
142 |     //!
143 |     //! \brief Creates a LogStreamConsumer which logs messages with level severity.
144 |     //!  Reportable severity determines if the messages are severe enough to be logged.
145 |     //!
146 | LogStreamConsumer(nvinfer1::ILogger::Severity reportableSeverity, nvinfer1::ILogger::Severity severity) 147 | : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) 148 | , std::ostream(&mBuffer) // links the stream buffer with the stream 149 | , mShouldLog(severity <= reportableSeverity) 150 | , mSeverity(severity) 151 | { 152 | } 153 | 154 | LogStreamConsumer(LogStreamConsumer&& other) noexcept 155 | : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) 156 | , std::ostream(&mBuffer) // links the stream buffer with the stream 157 | , mShouldLog(other.mShouldLog) 158 | , mSeverity(other.mSeverity) 159 | { 160 | } 161 | LogStreamConsumer(const LogStreamConsumer& other) = delete; 162 | LogStreamConsumer() = delete; 163 | ~LogStreamConsumer() = default; 164 | LogStreamConsumer& operator=(const LogStreamConsumer&) = delete; 165 | LogStreamConsumer& operator=(LogStreamConsumer&&) = delete; 166 | 167 | void setReportableSeverity(Severity reportableSeverity) 168 | { 169 | mShouldLog = mSeverity <= reportableSeverity; 170 | mBuffer.setShouldLog(mShouldLog); 171 | } 172 | 173 | private: 174 | static std::ostream& severityOstream(Severity severity) 175 | { 176 | return severity >= Severity::kINFO ? std::cout : std::cerr; 177 | } 178 | 179 | static std::string severityPrefix(Severity severity) 180 | { 181 | switch (severity) 182 | { 183 | case Severity::kINTERNAL_ERROR: return "[F] "; 184 | case Severity::kERROR: return "[E] "; 185 | case Severity::kWARNING: return "[W] "; 186 | case Severity::kINFO: return "[I] "; 187 | case Severity::kVERBOSE: return "[V] "; 188 | default: assert(0); return ""; 189 | } 190 | } 191 | 192 | bool mShouldLog; 193 | Severity mSeverity; 194 | }; // class LogStreamConsumer 195 | 196 | //! 197 | //! \class Logger 198 | //! 199 | //! \brief Class which manages logging of TensorRT tools and samples 200 | //! 201 | //! \details This class provides a common interface for TensorRT tools and samples to log information to the console, 202 | //! and supports logging two types of messages: 203 | //! 204 | //! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) 205 | //! - Test pass/fail messages 206 | //! 207 | //! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is 208 | //! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. 209 | //! 210 | //! In the future, this class could be extended to support dumping test results to a file in some standard format 211 | //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). 212 | //! 213 | //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger 214 | //! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT 215 | //! library and messages coming from the sample. 216 | //! 217 | //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the 218 | //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger 219 | //! object. 220 | //! 
221 | class Logger : public nvinfer1::ILogger 222 | { 223 | public: 224 | explicit Logger(Severity severity = Severity::kWARNING) 225 | : mReportableSeverity(severity) 226 | { 227 | } 228 | 229 | //! 230 | //! \enum TestResult 231 | //! \brief Represents the state of a given test 232 | //! 233 | enum class TestResult 234 | { 235 | kRUNNING, //!< The test is running 236 | kPASSED, //!< The test passed 237 | kFAILED, //!< The test failed 238 | kWAIVED //!< The test was waived 239 | }; 240 | 241 | //! 242 | //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger 243 | //! \return The nvinfer1::ILogger associated with this Logger 244 | //! 245 | //! TODO Once all samples are updated to use this method to register the logger with TensorRT, 246 | //! we can eliminate the inheritance of Logger from ILogger 247 | //! 248 | nvinfer1::ILogger& getTRTLogger() noexcept 249 | { 250 | return *this; 251 | } 252 | 253 | //! 254 | //! \brief Implementation of the nvinfer1::ILogger::log() virtual method 255 | //! 256 | //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the 257 | //! inheritance from nvinfer1::ILogger 258 | //! 259 | void log(Severity severity, const char* msg) noexcept override 260 | { 261 | LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; 262 | } 263 | 264 | //! 265 | //! \brief Method for controlling the verbosity of logging output 266 | //! 267 | //! \param severity The logger will only emit messages that have severity of this level or higher. 268 | //! 269 | void setReportableSeverity(Severity severity) noexcept 270 | { 271 | mReportableSeverity = severity; 272 | } 273 | 274 | //! 275 | //! \brief Opaque handle that holds logging information for a particular test 276 | //! 277 | //! This object is an opaque handle to information used by the Logger to print test results. 278 | //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used 279 | //! with Logger::reportTest{Start,End}(). 280 | //! 281 | class TestAtom 282 | { 283 | public: 284 | TestAtom(TestAtom&&) = default; 285 | 286 | private: 287 | friend class Logger; 288 | 289 | TestAtom(bool started, const std::string& name, const std::string& cmdline) 290 | : mStarted(started) 291 | , mName(name) 292 | , mCmdline(cmdline) 293 | { 294 | } 295 | 296 | bool mStarted; 297 | std::string mName; 298 | std::string mCmdline; 299 | }; 300 | 301 | //! 302 | //! \brief Define a test for logging 303 | //! 304 | //! \param[in] name The name of the test. This should be a string starting with 305 | //! "TensorRT" and containing dot-separated strings containing 306 | //! the characters [A-Za-z0-9_]. 307 | //! For example, "TensorRT.sample_googlenet" 308 | //! \param[in] cmdline The command line used to reproduce the test 309 | // 310 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 311 | //! 312 | static TestAtom defineTest(const std::string& name, const std::string& cmdline) 313 | { 314 | return TestAtom(false, name, cmdline); 315 | } 316 | 317 | //! 318 | //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments 319 | //! as input 320 | //! 321 | //! \param[in] name The name of the test 322 | //! \param[in] argc The number of command-line arguments 323 | //! \param[in] argv The array of command-line arguments (given as C strings) 324 | //! 325 | //! 
\return a TestAtom that can be used in Logger::reportTest{Start,End}(). 326 | //! 327 | static TestAtom defineTest(const std::string& name, int32_t argc, char const* const* argv) 328 | { 329 | // Append TensorRT version as info 330 | const std::string vname = name + " [TensorRT v" + std::to_string(NV_TENSORRT_VERSION) + "]"; 331 | auto cmdline = genCmdlineString(argc, argv); 332 | return defineTest(vname, cmdline); 333 | } 334 | 335 | //! 336 | //! \brief Report that a test has started. 337 | //! 338 | //! \pre reportTestStart() has not been called yet for the given testAtom 339 | //! 340 | //! \param[in] testAtom The handle to the test that has started 341 | //! 342 | static void reportTestStart(TestAtom& testAtom) 343 | { 344 | reportTestResult(testAtom, TestResult::kRUNNING); 345 | assert(!testAtom.mStarted); 346 | testAtom.mStarted = true; 347 | } 348 | 349 | //! 350 | //! \brief Report that a test has ended. 351 | //! 352 | //! \pre reportTestStart() has been called for the given testAtom 353 | //! 354 | //! \param[in] testAtom The handle to the test that has ended 355 | //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, 356 | //! TestResult::kFAILED, TestResult::kWAIVED 357 | //! 358 | static void reportTestEnd(TestAtom const& testAtom, TestResult result) 359 | { 360 | assert(result != TestResult::kRUNNING); 361 | assert(testAtom.mStarted); 362 | reportTestResult(testAtom, result); 363 | } 364 | 365 | static int32_t reportPass(TestAtom const& testAtom) 366 | { 367 | reportTestEnd(testAtom, TestResult::kPASSED); 368 | return EXIT_SUCCESS; 369 | } 370 | 371 | static int32_t reportFail(TestAtom const& testAtom) 372 | { 373 | reportTestEnd(testAtom, TestResult::kFAILED); 374 | return EXIT_FAILURE; 375 | } 376 | 377 | static int32_t reportWaive(TestAtom const& testAtom) 378 | { 379 | reportTestEnd(testAtom, TestResult::kWAIVED); 380 | return EXIT_SUCCESS; 381 | } 382 | 383 | static int32_t reportTest(TestAtom const& testAtom, bool pass) 384 | { 385 | return pass ? reportPass(testAtom) : reportFail(testAtom); 386 | } 387 | 388 | Severity getReportableSeverity() const 389 | { 390 | return mReportableSeverity; 391 | } 392 | 393 | private: 394 | //! 395 | //! \brief returns an appropriate string for prefixing a log message with the given severity 396 | //! 397 | static const char* severityPrefix(Severity severity) 398 | { 399 | switch (severity) 400 | { 401 | case Severity::kINTERNAL_ERROR: return "[F] "; 402 | case Severity::kERROR: return "[E] "; 403 | case Severity::kWARNING: return "[W] "; 404 | case Severity::kINFO: return "[I] "; 405 | case Severity::kVERBOSE: return "[V] "; 406 | default: assert(0); return ""; 407 | } 408 | } 409 | 410 | //! 411 | //! \brief returns an appropriate string for prefixing a test result message with the given result 412 | //! 413 | static const char* testResultString(TestResult result) 414 | { 415 | switch (result) 416 | { 417 | case TestResult::kRUNNING: return "RUNNING"; 418 | case TestResult::kPASSED: return "PASSED"; 419 | case TestResult::kFAILED: return "FAILED"; 420 | case TestResult::kWAIVED: return "WAIVED"; 421 | default: assert(0); return ""; 422 | } 423 | } 424 | 425 | //! 426 | //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity 427 | //! 428 | static std::ostream& severityOstream(Severity severity) 429 | { 430 | return severity >= Severity::kINFO ? std::cout : std::cerr; 431 | } 432 | 433 | //! 434 | //! \brief method that implements logging test results 435 | //! 
436 | static void reportTestResult(TestAtom const& testAtom, TestResult result) 437 | { 438 | severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " 439 | << testAtom.mCmdline << std::endl; 440 | } 441 | 442 | //! 443 | //! \brief generate a command line string from the given (argc, argv) values 444 | //! 445 | static std::string genCmdlineString(int32_t argc, char const* const* argv) 446 | { 447 | std::stringstream ss; 448 | for (int32_t i = 0; i < argc; i++) 449 | { 450 | if (i > 0) 451 | { 452 | ss << " "; 453 | } 454 | ss << argv[i]; 455 | } 456 | return ss.str(); 457 | } 458 | 459 | Severity mReportableSeverity; 460 | }; // class Logger 461 | 462 | namespace 463 | { 464 | //! 465 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE 466 | //! 467 | //! Example usage: 468 | //! 469 | //! LOG_VERBOSE(logger) << "hello world" << std::endl; 470 | //! 471 | inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) 472 | { 473 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); 474 | } 475 | 476 | //! 477 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO 478 | //! 479 | //! Example usage: 480 | //! 481 | //! LOG_INFO(logger) << "hello world" << std::endl; 482 | //! 483 | inline LogStreamConsumer LOG_INFO(const Logger& logger) 484 | { 485 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); 486 | } 487 | 488 | //! 489 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING 490 | //! 491 | //! Example usage: 492 | //! 493 | //! LOG_WARN(logger) << "hello world" << std::endl; 494 | //! 495 | inline LogStreamConsumer LOG_WARN(const Logger& logger) 496 | { 497 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); 498 | } 499 | 500 | //! 501 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR 502 | //! 503 | //! Example usage: 504 | //! 505 | //! LOG_ERROR(logger) << "hello world" << std::endl; 506 | //! 507 | inline LogStreamConsumer LOG_ERROR(const Logger& logger) 508 | { 509 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); 510 | } 511 | 512 | //! 513 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR 514 | //! ("fatal" severity) 515 | //! 516 | //! Example usage: 517 | //! 518 | //! LOG_FATAL(logger) << "hello world" << std::endl; 519 | //! 
520 | inline LogStreamConsumer LOG_FATAL(const Logger& logger)
521 | {
522 |     return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
523 | }
524 | } // anonymous namespace
525 | } // namespace sample
526 | #endif // TENSORRT_LOGGING_H
--------------------------------------------------------------------------------
/linux_cc/co_detr/test_img/bus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/Co-DETR-TensorRT/d79cf4fe508c47cca86594b52cb1088968fe1cd3/linux_cc/co_detr/test_img/bus.jpg
--------------------------------------------------------------------------------
/linux_cc/co_detr/test_img/zidane.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/Co-DETR-TensorRT/d79cf4fe508c47cca86594b52cb1088968fe1cd3/linux_cc/co_detr/test_img/zidane.jpg
--------------------------------------------------------------------------------
/linux_cc/plugin/Makefile:
--------------------------------------------------------------------------------
include ./Makefile.inc

SOURCE_CU = $(shell find . -name '*.cu' 2>/dev/null)
SOURCE_PY = $(shell find . -name '*.py' 2>/dev/null)
OBJ       = $(shell find . -name '*.o' 2>/dev/null)
DEP       = $(OBJ:.o=.d)
TARGET_SO = $(SOURCE_CU:.cu=.so)

-include $(DEP)

all: $(TARGET_SO)

%.so: %.o
	$(NVCC) $(SOFLAG) $(LDFLAG) -o $@ $+

%.o: %.cu
	$(NVCC) $(CUFLAG) $(INCLUDE) -M -MT $@ -o $(@:.o=.d) $<
	$(NVCC) $(CUFLAG) $(INCLUDE) -o $@ -c $<

--------------------------------------------------------------------------------
/linux_cc/plugin/Makefile.inc:
--------------------------------------------------------------------------------
CUDA_PATH   = /usr/local/cuda
NVCC        = $(CUDA_PATH)/bin/nvcc
TRT_PATH    = /usr/lib/x86_64-linux-gnu
GENCODE     = -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86
DEBUG_MACRO = -UDEBUG
CUFLAG      = -w -std=c++14 -O3 $(DEBUG_MACRO) -Xcompiler -fPIC $(GENCODE)
CCFLAG      = -w -std=c++14 -O3 $(DEBUG_MACRO) -Xcompiler -fPIC -use_fast_math
SOFLAG      = -shared
INCLUDE     = -I. -I$(CUDA_PATH)/include -I$(TRT_PATH)/include -I./common
LDFLAG      = -L$(CUDA_PATH)/lib64 -lcudart -L$(TRT_PATH)/lib -lnvinfer
--------------------------------------------------------------------------------
/linux_cc/plugin/common/common_cuda_helper.hpp:
--------------------------------------------------------------------------------
// Copyright (c) OpenMMLab. All rights reserved.
2 | #ifndef COMMON_CUDA_HELPER
3 | #define COMMON_CUDA_HELPER
4 | 
5 | #include <cublas_v2.h>
6 | #include <cuda.h>
7 | #include <stdio.h>
8 | 
9 | #include <algorithm>
10 | 
11 | #define CUDA_1D_KERNEL_LOOP(i, n) \
12 |   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
13 | 
14 | #define THREADS_PER_BLOCK 512
15 | 
16 | #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
17 | inline int GET_BLOCKS(const int N) {
18 |   int optimal_block_num = DIVUP(N, THREADS_PER_BLOCK);
19 |   int max_block_num = 4096;
20 |   return std::min(optimal_block_num, max_block_num);
21 | }
22 | 
23 | #define cudaCheckError()                                                                \
24 |   {                                                                                     \
25 |     cudaError_t e = cudaGetLastError();                                                 \
26 |     if (e != cudaSuccess) {                                                             \
27 |       printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e));  \
28 |       exit(0);                                                                          \
29 |     }                                                                                   \
30 |   }
31 | 
32 | /**
33 |  * Returns a view of the original tensor with its dimensions permuted.
34 |  *
35 |  * @param[out] dst pointer to the destination tensor
36 |  * @param[in] src pointer to the source tensor
37 |  * @param[in] src_size shape of the src tensor
38 |  * @param[in] permute The desired ordering of dimensions
39 |  * @param[in] src_dim dim of src tensor
40 |  * @param[in] stream cuda stream handle
41 |  */
42 | template <typename scalar_t>
43 | void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size, int* permute, int src_dim,
44 |                    cudaStream_t stream = 0);
45 | 
46 | template <typename scalar_t>
47 | cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa,
48 |                               cublasOperation_t transb, int m, int n, int k, const scalar_t* alpha,
49 |                               const scalar_t* A, int lda, const scalar_t* B, int ldb,
50 |                               const scalar_t* beta, scalar_t* C, int ldc);
51 | 
52 | template <typename scalar_t>
53 | __device__ __forceinline__ scalar_t bilinear_interpolate(const scalar_t* __restrict__ input,
54 |                                                          const int height, const int width,
55 |                                                          scalar_t y, scalar_t x) {
56 |   // deal with cases that inverse elements are out of feature map boundary
57 |   if (y < -1.0 || y > height || x < -1.0 || x > width) return 0;
58 | 
59 |   y = min(scalar_t(height - 1), max(scalar_t(0), y));
60 |   x = min(scalar_t(width - 1), max(scalar_t(0), x));
61 | 
62 |   const int y_low = floor(y);
63 |   const int x_low = floor(x);
64 |   const int y_high = ceil(y);
65 |   const int x_high = ceil(x);
66 | 
67 |   const scalar_t v1 = input[y_low * width + x_low];
68 |   const scalar_t v2 = input[y_low * width + x_high];
69 |   const scalar_t v3 = input[y_high * width + x_low];
70 |   const scalar_t v4 = input[y_high * width + x_high];
71 | 
72 |   // lerp can be performed by fma
73 |   const scalar_t ly = y - y_low;
74 |   const scalar_t lx = x - x_low;
75 |   const scalar_t v_low = fma(v2 - v1, lx, v1);
76 |   const scalar_t v_high = fma(v4 - v3, lx, v3);
77 |   const scalar_t val = fma(v_high - v_low, ly, v_low);
78 | 
79 |   return val;
80 | }
81 | 
82 | #endif  // COMMON_CUDA_HELPER
83 | 
-------------------------------------------------------------------------------- /linux_cc/plugin/common/trt_plugin_base.hpp: --------------------------------------------------------------------------------
1 | // Copyright (c) OpenMMLab. All rights reserved.
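// Shared scaffolding for the mmdeploy-style TensorRT plugins: TRTPluginBase
// fills in the boilerplate half of the IPluginV2DynamicExt contract
// (namespace bookkeeping, no-op initialize/terminate, zero-byte workspace),
// so a concrete plugin such as TRTGridSampler only has to implement shape
// inference, format checks, (de)serialization and enqueue.
// TRTPluginCreatorBase plays the same role for IPluginCreator.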
2 | #ifndef TRT_PLUGIN_BASE_HPP 3 | #define TRT_PLUGIN_BASE_HPP 4 | #include "NvInferRuntime.h" 5 | #include "NvInferVersion.h" 6 | #include "trt_plugin_helper.hpp" 7 | 8 | namespace mmdeploy { 9 | 10 | #if NV_TENSORRT_MAJOR > 7 11 | #define TRT_NOEXCEPT noexcept 12 | #else 13 | #define TRT_NOEXCEPT 14 | #endif 15 | 16 | class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt { 17 | public: 18 | TRTPluginBase(const std::string &name) : mLayerName(name) {} 19 | // IPluginV2 Methods 20 | const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; } 21 | int initialize() TRT_NOEXCEPT override { return STATUS_SUCCESS; } 22 | void terminate() TRT_NOEXCEPT override {} 23 | void destroy() TRT_NOEXCEPT override { delete this; } 24 | void setPluginNamespace(const char *pluginNamespace) TRT_NOEXCEPT override { 25 | mNamespace = pluginNamespace; 26 | } 27 | const char *getPluginNamespace() const TRT_NOEXCEPT override { return mNamespace.c_str(); } 28 | 29 | virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, 30 | const nvinfer1::DynamicPluginTensorDesc *out, 31 | int nbOutputs) TRT_NOEXCEPT override {} 32 | 33 | virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, 34 | const nvinfer1::PluginTensorDesc *outputs, 35 | int nbOutputs) const TRT_NOEXCEPT override { 36 | return 0; 37 | } 38 | 39 | virtual void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, 40 | nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override {} 41 | 42 | virtual void detachFromContext() TRT_NOEXCEPT override {} 43 | 44 | protected: 45 | const std::string mLayerName; 46 | std::string mNamespace; 47 | 48 | #if NV_TENSORRT_MAJOR < 8 49 | protected: 50 | // To prevent compiler warnings. 51 | using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch; 52 | using nvinfer1::IPluginV2DynamicExt::enqueue; 53 | using nvinfer1::IPluginV2DynamicExt::getOutputDimensions; 54 | using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch; 55 | using nvinfer1::IPluginV2DynamicExt::supportsFormat; 56 | #endif 57 | }; 58 | 59 | class TRTPluginCreatorBase : public nvinfer1::IPluginCreator { 60 | public: 61 | const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; }; 62 | 63 | const nvinfer1::PluginFieldCollection *getFieldNames() TRT_NOEXCEPT override { return &mFC; } 64 | 65 | void setPluginNamespace(const char *pluginNamespace) TRT_NOEXCEPT override { 66 | mNamespace = pluginNamespace; 67 | } 68 | 69 | const char *getPluginNamespace() const TRT_NOEXCEPT override { return mNamespace.c_str(); } 70 | 71 | protected: 72 | nvinfer1::PluginFieldCollection mFC; 73 | std::vector mPluginAttributes; 74 | std::string mNamespace; 75 | }; 76 | } // namespace mmdeploy 77 | #endif 78 | -------------------------------------------------------------------------------- /linux_cc/plugin/common/trt_plugin_helper.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) OpenMMLab. All rights reserved. 
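// Status codes and error-handling macros shared by the plugins, plus two
// small size utilities: getElementSize maps a TensorRT DataType to its width
// in bytes, and getAlignedSize rounds a byte count up to a multiple of
// aligned_number (default 16), e.g. getAlignedSize(100) == 112 and
// getAlignedSize(96) == 96.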
2 | #ifndef TRT_PLUGIN_HELPER_HPP 3 | #define TRT_PLUGIN_HELPER_HPP 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "NvInferRuntime.h" 10 | 11 | cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype, cudnnDataType_t* cudnn_dtype); 12 | 13 | // Enumerator for status 14 | typedef enum { 15 | STATUS_SUCCESS = 0, 16 | STATUS_FAILURE = 1, 17 | STATUS_BAD_PARAM = 2, 18 | STATUS_NOT_SUPPORTED = 3, 19 | STATUS_NOT_INITIALIZED = 4 20 | } pluginStatus_t; 21 | 22 | #define ASSERT(assertion) \ 23 | { \ 24 | if (!(assertion)) { \ 25 | std::cerr << "#assertion" << __FILE__ << "," << __LINE__ << std::endl; \ 26 | abort(); \ 27 | } \ 28 | } 29 | 30 | #define CUASSERT(status_) \ 31 | { \ 32 | auto s_ = status_; \ 33 | if (s_ != cudaSuccess) { \ 34 | std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << ", " << cudaGetErrorString(s_) \ 35 | << std::endl; \ 36 | } \ 37 | } 38 | #define CUBLASASSERT(status_) \ 39 | { \ 40 | auto s_ = status_; \ 41 | if (s_ != CUBLAS_STATUS_SUCCESS) { \ 42 | std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \ 43 | } \ 44 | } 45 | #define CUERRORMSG(status_) \ 46 | { \ 47 | auto s_ = status_; \ 48 | if (s_ != 0) std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \ 49 | } 50 | 51 | #ifndef DEBUG 52 | 53 | #define CHECK(status) \ 54 | do { \ 55 | if (status != 0) abort(); \ 56 | } while (0) 57 | 58 | #define ASSERT_PARAM(exp) \ 59 | do { \ 60 | if (!(exp)) return STATUS_BAD_PARAM; \ 61 | } while (0) 62 | 63 | #define ASSERT_FAILURE(exp) \ 64 | do { \ 65 | if (!(exp)) return STATUS_FAILURE; \ 66 | } while (0) 67 | 68 | #define CSC(call, err) \ 69 | do { \ 70 | cudaError_t cudaStatus = call; \ 71 | if (cudaStatus != cudaSuccess) { \ 72 | return err; \ 73 | } \ 74 | } while (0) 75 | 76 | #define DEBUG_PRINTF(...) \ 77 | do { \ 78 | } while (0) 79 | 80 | #else 81 | 82 | #define ASSERT_PARAM(exp) \ 83 | do { \ 84 | if (!(exp)) { \ 85 | fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \ 86 | return STATUS_BAD_PARAM; \ 87 | } \ 88 | } while (0) 89 | 90 | #define ASSERT_FAILURE(exp) \ 91 | do { \ 92 | if (!(exp)) { \ 93 | fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \ 94 | return STATUS_FAILURE; \ 95 | } \ 96 | } while (0) 97 | 98 | #define CSC(call, err) \ 99 | do { \ 100 | cudaError_t cudaStatus = call; \ 101 | if (cudaStatus != cudaSuccess) { \ 102 | printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \ 103 | return err; \ 104 | } \ 105 | } while (0) 106 | 107 | #define CHECK(status) \ 108 | { \ 109 | if (status != 0) { \ 110 | DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \ 111 | abort(); \ 112 | } \ 113 | } 114 | 115 | #define DEBUG_PRINTF(...) 
\
116 |   do {                   \
117 |     printf(__VA_ARGS__);   \
118 |   } while (0)
119 | 
120 | #endif
121 | 
122 | namespace mmdeploy {
123 | 
124 | const int MAXTENSORDIMS = 10;
125 | 
126 | struct TensorDesc {
127 |   int shape[MAXTENSORDIMS];
128 |   int stride[MAXTENSORDIMS];
129 |   int dim;
130 | };
131 | 
132 | inline unsigned int getElementSize(nvinfer1::DataType t) {
133 |   switch (t) {
134 |     case nvinfer1::DataType::kINT32:
135 |       return 4;
136 |     case nvinfer1::DataType::kFLOAT:
137 |       return 4;
138 |     case nvinfer1::DataType::kHALF:
139 |       return 2;
140 |     // case nvinfer1::DataType::kBOOL:
141 |     case nvinfer1::DataType::kINT8:
142 |       return 1;
143 |     default:
144 |       throw std::runtime_error("Invalid DataType.");
145 |   }
146 |   throw std::runtime_error("Invalid DataType.");
147 |   return 0;
148 | }
149 | 
150 | inline size_t getAlignedSize(size_t origin_size, size_t aligned_number = 16) {
151 |   return size_t((origin_size + aligned_number - 1) / aligned_number) * aligned_number;
152 | }
153 | 
154 | }  // namespace mmdeploy
155 | #endif  // TRT_PLUGIN_HELPER_HPP
156 | 
-------------------------------------------------------------------------------- /linux_cc/plugin/common/trt_serialize.hpp: --------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | // Modified from:
3 | // https://github.com/NVIDIA/TensorRT/blob/master/plugin/common/serialize.hpp
4 | 
5 | #ifndef TRT_SERIALIZE_HPP
6 | #define TRT_SERIALIZE_HPP
7 | #include <cassert>
8 | #include <cstring>
9 | #include <type_traits>
10 | #include <vector>
11 | 
12 | template <typename T>
13 | inline void serialize_value(void** buffer, T const& value);
14 | 
15 | template <typename T>
16 | inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value);
17 | 
18 | namespace {
19 | 
20 | template <typename T, class Enable = void>
21 | struct Serializer {};
22 | 
23 | template <typename T>
24 | struct Serializer<T,
25 |                   typename std::enable_if<std::is_arithmetic<T>::value || std::is_enum<T>::value ||
26 |                                           std::is_pod<T>::value>::type> {
27 |   static size_t serialized_size(T const& value) { return sizeof(T); }
28 |   static void serialize(void** buffer, T const& value) {
29 |     ::memcpy(*buffer, &value, sizeof(T));
30 |     reinterpret_cast<char*&>(*buffer) += sizeof(T);
31 |   }
32 |   static void deserialize(void const** buffer, size_t* buffer_size, T* value) {
33 |     assert(*buffer_size >= sizeof(T));
34 |     ::memcpy(value, *buffer, sizeof(T));
35 |     reinterpret_cast<char const*&>(*buffer) += sizeof(T);
36 |     *buffer_size -= sizeof(T);
37 |   }
38 | };
39 | 
40 | template <>
41 | struct Serializer<const char*> {
42 |   static size_t serialized_size(const char* value) { return strlen(value) + 1; }
43 |   static void serialize(void** buffer, const char* value) {
44 |     ::strcpy(static_cast<char*>(*buffer), value);
45 |     reinterpret_cast<char*&>(*buffer) += strlen(value) + 1;
46 |   }
47 |   static void deserialize(void const** buffer, size_t* buffer_size, const char** value) {
48 |     *value = static_cast<char const*>(*buffer);
49 |     size_t data_size = strnlen(*value, *buffer_size) + 1;
50 |     assert(*buffer_size >= data_size);
51 |     reinterpret_cast<char const*&>(*buffer) += data_size;
52 |     *buffer_size -= data_size;
53 |   }
54 | };
55 | 
56 | template <typename T>
57 | struct Serializer<std::vector<T>,
58 |                   typename std::enable_if<std::is_arithmetic<T>::value || std::is_enum<T>::value ||
59 |                                           std::is_pod<T>::value>::type> {
60 |   static size_t serialized_size(std::vector<T> const& value) {
61 |     return sizeof(value.size()) + value.size() * sizeof(T);
62 |   }
63 |   static void serialize(void** buffer, std::vector<T> const& value) {
64 |     serialize_value(buffer, value.size());
65 |     size_t nbyte = value.size() * sizeof(T);
66 |     ::memcpy(*buffer, value.data(), nbyte);
67 |     reinterpret_cast<char*&>(*buffer) += nbyte;
68 |   }
69 |   static void deserialize(void const** buffer, size_t* buffer_size, std::vector<T>* value) {
70 |     size_t size;
71 |     deserialize_value(buffer, buffer_size, &size);
72 |     value->resize(size);
73 |     size_t nbyte = value->size() * sizeof(T);
74 |     assert(*buffer_size >= nbyte);
75 |     ::memcpy(value->data(), *buffer, nbyte);
76 |     reinterpret_cast<char const*&>(*buffer) += nbyte;
77 |     *buffer_size -= nbyte;
78 |   }
79 | };
80 | 
81 | }  // namespace
82 | 
83 | template <typename T>
84 | inline size_t serialized_size(T const& value) {
85 |   return Serializer<T>::serialized_size(value);
86 | }
87 | 
88 | template <typename T>
89 | inline void serialize_value(void** buffer, T const& value) {
90 |   return Serializer<T>::serialize(buffer, value);
91 | }
92 | 
93 | template <typename T>
94 | inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value) {
95 |   return Serializer<T>::deserialize(buffer, buffer_size, value);
96 | }
97 | #endif  // TRT_SERIALIZE_HPP
98 | 
-------------------------------------------------------------------------------- /linux_cc/plugin/trt_grid_sampler_kernel.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | // modified from
3 | // https://github.com/pytorch/pytorch/blob/ec683299ebabf297a3504c76248d37be830e4342/aten/src/ATen/native/cuda/GridSampler.cuh
4 | // and
5 | // https://github.com/pytorch/pytorch/blob/ec683299ebabf297a3504c76248d37be830e4342/aten/src/ATen/native/cuda/GridSampler.cu
6 | 
7 | 
8 | // Copyright (c) OpenMMLab. All rights reserved.
9 | //#include "trt_grid_sampler.hpp"
10 | 
11 | #include <assert.h>
12 | 
13 | #include <chrono>
14 | 
15 | //#include "trt_grid_sampler_kernel.hpp"
16 | //#include "trt_plugin_helper.hpp"
17 | #include "trt_serialize.hpp"
18 | 
19 | #include <cuda_fp16.h>
20 | #include <cmath>
21 | 
22 | #include <memory>
23 | #include <string>
24 | #include <vector>
25 | 
26 | #include "common_cuda_helper.hpp"
27 | #include "trt_grid_sampler_kernel.hpp"
28 | #include "trt_plugin_helper.hpp"
29 | 
30 | using mmdeploy::TensorDesc;
31 | 
32 | // Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
33 | // where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
34 | // if align_corners: -1 and +1 get sent to the centers of the corner pixels
35 | //     -1 --> 0
36 | //     +1 --> (size - 1)
37 | //     scale_factor = (size - 1) / 2
38 | // if not align_corners: -1 and +1 get sent to the image edges
39 | //     -1 --> -0.5
40 | //     +1 --> (size - 1) + 0.5 == size - 0.5
41 | //     scale_factor = size / 2
42 | template <typename scalar_t>
43 | static __forceinline__ __device__ scalar_t grid_sampler_unnormalize(scalar_t coord, int size,
44 |                                                                     bool align_corners) {
45 |   if (align_corners) {
46 |     // unnormalize coord from [-1, 1] to [0, size - 1]
47 |     return ((coord + 1.f) / 2) * (size - 1);
48 |   } else {
49 |     // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
50 |     return ((coord + 1.f) * size - 1) / 2;
51 |   }
52 | }
53 | 
54 | // Clips coordinates to between 0 and clip_limit - 1
55 | template <typename scalar_t>
56 | static __forceinline__ __device__ scalar_t clip_coordinates(scalar_t in, int clip_limit) {
57 |   return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
58 | }
59 | 
60 | // Reflects coordinates until they fall between low and high (inclusive).
61 | // The bounds are passed as twice their value so that half-integer values
62 | // can be represented as ints.
63 | template 64 | static __forceinline__ __device__ scalar_t reflect_coordinates(scalar_t in, int twice_low, 65 | int twice_high) { 66 | if (twice_low == twice_high) { 67 | return static_cast(0); 68 | } 69 | scalar_t min = static_cast(twice_low) / 2; 70 | scalar_t span = static_cast(twice_high - twice_low) / 2; 71 | in = ::fabs(in - min); 72 | // `fmod` returns same sign as `in`, which is positive after the `fabs` above. 73 | scalar_t extra = ::fmod(in, span); 74 | int flips = static_cast(::floor(in / span)); 75 | if (flips % 2 == 0) { 76 | return extra + min; 77 | } else { 78 | return span - extra + min; 79 | } 80 | } 81 | 82 | template 83 | static __forceinline__ __device__ scalar_t safe_downgrade_to_int_range(scalar_t x) { 84 | // -100.0 does not have special meaning. This is just to make sure 85 | // it's not within_bounds_2d or within_bounds_3d, and does not cause 86 | // undefined behavior. See #35506. 87 | if (x > INT_MAX - 1 || x < INT_MIN || !::isfinite(static_cast(x))) 88 | return static_cast(-100.0); 89 | return x; 90 | } 91 | 92 | // Computes the pixel source index value for a grid coordinate 93 | template 94 | static __forceinline__ __device__ scalar_t grid_sampler_compute_source_index( 95 | scalar_t coord, int size, GridSamplerPadding padding_mode, bool align_corners) { 96 | coord = grid_sampler_unnormalize(coord, size, align_corners); 97 | if (padding_mode == GridSamplerPadding::Border) { 98 | // clip coordinates to image borders 99 | coord = clip_coordinates(coord, size); 100 | } else if (padding_mode == GridSamplerPadding::Reflection) { 101 | // reflect coordinates by image borders 102 | if (align_corners) { 103 | coord = reflect_coordinates(coord, 0, 2 * (size - 1)); 104 | } else { 105 | coord = reflect_coordinates(coord, -1, 2 * size - 1); 106 | } 107 | // clip coordinates to image borders 108 | coord = clip_coordinates(coord, size); 109 | } 110 | 111 | coord = safe_downgrade_to_int_range(coord); 112 | return coord; 113 | } 114 | 115 | static __forceinline__ __device__ bool within_bounds_2d(int h, int w, int H, int W) { 116 | return h >= 0 && h < H && w >= 0 && w < W; 117 | } 118 | 119 | static __forceinline__ __device__ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { 120 | return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; 121 | } 122 | 123 | template 124 | __global__ void grid_sampler_2d_kernel(const int nthreads, const scalar_t *input, 125 | const scalar_t *grid, scalar_t *output, 126 | TensorDesc input_desc, TensorDesc grid_desc, 127 | TensorDesc output_desc, 128 | const GridSamplerInterpolation interpolation_mode, 129 | const GridSamplerPadding padding_mode, bool align_corners) { 130 | int C = input_desc.shape[1]; 131 | int inp_H = input_desc.shape[2]; 132 | int inp_W = input_desc.shape[3]; 133 | int out_H = grid_desc.shape[1]; 134 | int out_W = grid_desc.shape[2]; 135 | int inp_sN = input_desc.stride[0]; 136 | int inp_sC = input_desc.stride[1]; 137 | int inp_sH = input_desc.stride[2]; 138 | int inp_sW = input_desc.stride[3]; 139 | int grid_sN = grid_desc.stride[0]; 140 | int grid_sH = grid_desc.stride[1]; 141 | int grid_sW = grid_desc.stride[2]; 142 | int grid_sCoor = grid_desc.stride[3]; 143 | int out_sN = output_desc.stride[0]; 144 | int out_sC = output_desc.stride[1]; 145 | int out_sH = output_desc.stride[2]; 146 | int out_sW = output_desc.stride[3]; 147 | 148 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 149 | const int w = index % out_W; 150 | const int h = (index / out_W) % out_H; 151 | const int n = index / (out_H * out_W); 152 | 
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; 153 | 154 | // get the corresponding input x, y coordinates from grid 155 | scalar_t ix = grid[grid_offset]; 156 | scalar_t iy = grid[grid_offset + grid_sCoor]; 157 | 158 | ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); 159 | iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); 160 | 161 | if (interpolation_mode == GridSamplerInterpolation::Bilinear) { 162 | // get NE, NW, SE, SW pixel values from (x, y) 163 | int ix_nw = static_cast(::floor(ix)); 164 | int iy_nw = static_cast(::floor(iy)); 165 | int ix_ne = ix_nw + 1; 166 | int iy_ne = iy_nw; 167 | int ix_sw = ix_nw; 168 | int iy_sw = iy_nw + 1; 169 | int ix_se = ix_nw + 1; 170 | int iy_se = iy_nw + 1; 171 | 172 | // get surfaces to each neighbor: 173 | scalar_t nw = (ix_se - ix) * (iy_se - iy); 174 | scalar_t ne = (ix - ix_sw) * (iy_sw - iy); 175 | scalar_t sw = (ix_ne - ix) * (iy - iy_ne); 176 | scalar_t se = (ix - ix_nw) * (iy - iy_nw); 177 | 178 | // calculate bilinear weighted pixel value and set output pixel 179 | auto inp_ptr_NC = input + n * inp_sN; 180 | auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; 181 | for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { 182 | *out_ptr_NCHW = static_cast(0); 183 | if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { 184 | *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw; 185 | } 186 | if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { 187 | *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne; 188 | } 189 | if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { 190 | *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw; 191 | } 192 | if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { 193 | *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se; 194 | } 195 | } 196 | } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { 197 | int ix_nearest = static_cast(::round(ix)); 198 | int iy_nearest = static_cast(::round(iy)); 199 | 200 | // assign nearest neighbor pixel value to output pixel 201 | auto inp_ptr_NC = input + n * inp_sN; 202 | auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; 203 | for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { 204 | if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) { 205 | *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW]; 206 | } else { 207 | *out_ptr_NCHW = static_cast(0); 208 | } 209 | } 210 | } 211 | } 212 | } 213 | 214 | template 215 | __global__ void grid_sampler_3d_kernel(const int nthreads, const scalar_t *input, 216 | const scalar_t *grid, scalar_t *output, 217 | TensorDesc input_desc, TensorDesc grid_desc, 218 | TensorDesc output_desc, 219 | const GridSamplerInterpolation interpolation_mode, 220 | const GridSamplerPadding padding_mode, bool align_corners) { 221 | int C = input_desc.shape[1]; 222 | int inp_D = input_desc.shape[2]; 223 | int inp_H = input_desc.shape[3]; 224 | int inp_W = input_desc.shape[4]; 225 | int out_D = grid_desc.shape[1]; 226 | int out_H = grid_desc.shape[2]; 227 | int out_W = grid_desc.shape[3]; 228 | int inp_sN = input_desc.stride[0]; 229 | int inp_sC = input_desc.stride[1]; 230 | int inp_sD = input_desc.stride[2]; 231 | int inp_sH = input_desc.stride[3]; 232 | int inp_sW = input_desc.stride[4]; 233 | int grid_sN = grid_desc.stride[0]; 234 | int grid_sD = grid_desc.stride[1]; 235 | int grid_sH = grid_desc.stride[2]; 236 | int 
grid_sW = grid_desc.stride[3]; 237 | int grid_sCoor = grid_desc.stride[4]; 238 | int out_sN = output_desc.stride[0]; 239 | int out_sC = output_desc.stride[1]; 240 | int out_sD = output_desc.stride[2]; 241 | int out_sH = output_desc.stride[3]; 242 | int out_sW = output_desc.stride[4]; 243 | 244 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 245 | const int w = index % out_W; 246 | const int h = (index / out_W) % out_H; 247 | const int d = (index / (out_H * out_W)) % out_D; 248 | const int n = index / (out_D * out_H * out_W); 249 | const int grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; 250 | 251 | // get the corresponding input x, y, z coordinates from grid 252 | scalar_t ix = grid[grid_offset]; 253 | scalar_t iy = grid[grid_offset + grid_sCoor]; 254 | scalar_t iz = grid[grid_offset + 2 * grid_sCoor]; 255 | 256 | ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); 257 | iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); 258 | iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners); 259 | 260 | if (interpolation_mode == GridSamplerInterpolation::Bilinear) { 261 | // get corner pixel values from (x, y, z) 262 | // for 4d, we used north-east-south-west 263 | // for 5d, we add top-bottom 264 | int ix_tnw = static_cast(::floor(ix)); 265 | int iy_tnw = static_cast(::floor(iy)); 266 | int iz_tnw = static_cast(::floor(iz)); 267 | 268 | int ix_tne = ix_tnw + 1; 269 | int iy_tne = iy_tnw; 270 | int iz_tne = iz_tnw; 271 | 272 | int ix_tsw = ix_tnw; 273 | int iy_tsw = iy_tnw + 1; 274 | int iz_tsw = iz_tnw; 275 | 276 | int ix_tse = ix_tnw + 1; 277 | int iy_tse = iy_tnw + 1; 278 | int iz_tse = iz_tnw; 279 | 280 | int ix_bnw = ix_tnw; 281 | int iy_bnw = iy_tnw; 282 | int iz_bnw = iz_tnw + 1; 283 | 284 | int ix_bne = ix_tnw + 1; 285 | int iy_bne = iy_tnw; 286 | int iz_bne = iz_tnw + 1; 287 | 288 | int ix_bsw = ix_tnw; 289 | int iy_bsw = iy_tnw + 1; 290 | int iz_bsw = iz_tnw + 1; 291 | 292 | int ix_bse = ix_tnw + 1; 293 | int iy_bse = iy_tnw + 1; 294 | int iz_bse = iz_tnw + 1; 295 | 296 | // get surfaces to each neighbor: 297 | scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); 298 | scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); 299 | scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); 300 | scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); 301 | scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); 302 | scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); 303 | scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); 304 | scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); 305 | 306 | auto inp_ptr_NC = input + n * inp_sN; 307 | auto out_ptr_NCDHW = output + n * out_sN + d * out_sD + h * out_sH + w * out_sW; 308 | for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { 309 | // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * 310 | // tne 311 | // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * 312 | // tse 313 | // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * 314 | // bne 315 | // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * 316 | // bse 317 | *out_ptr_NCDHW = static_cast(0); 318 | if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { 319 | *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; 320 | } 321 | if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { 322 
| *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; 323 | } 324 | if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { 325 | *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; 326 | } 327 | if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { 328 | *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; 329 | } 330 | if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { 331 | *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; 332 | } 333 | if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { 334 | *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; 335 | } 336 | if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { 337 | *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; 338 | } 339 | if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { 340 | *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; 341 | } 342 | } 343 | } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { 344 | int ix_nearest = static_cast(::round(ix)); 345 | int iy_nearest = static_cast(::round(iy)); 346 | int iz_nearest = static_cast(::round(iz)); 347 | 348 | // assign nearest neighbor pixel value to output pixel 349 | auto inp_ptr_NC = input + n * inp_sN; 350 | auto out_ptr_NCDHW = output + n * out_sN + d * out_sD + h * out_sH + w * out_sW; 351 | for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { 352 | if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) { 353 | *out_ptr_NCDHW = 354 | inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; 355 | } else { 356 | *out_ptr_NCDHW = static_cast(0); 357 | } 358 | } 359 | } 360 | } 361 | } 362 | 363 | void create_desc(const int *dims, int nb_dims, TensorDesc &desc) { 364 | memcpy(&desc.shape[0], dims, sizeof(int) * nb_dims); 365 | desc.stride[nb_dims - 1] = 1; 366 | for (int i = nb_dims - 2; i >= 0; --i) { 367 | desc.stride[i] = desc.stride[i + 1] * desc.shape[i + 1]; 368 | } 369 | } 370 | 371 | template 372 | void grid_sample(T *output, const T *input, const T *grid, int *output_dims, int *input_dims, 373 | int *grid_dims, int nb_dims, GridSamplerInterpolation interp, 374 | GridSamplerPadding padding, bool align_corners, cudaStream_t stream) { 375 | TensorDesc input_desc; 376 | create_desc(input_dims, nb_dims, input_desc); 377 | 378 | TensorDesc output_desc; 379 | create_desc(output_dims, nb_dims, output_desc); 380 | 381 | TensorDesc grid_desc; 382 | create_desc(grid_dims, nb_dims, grid_desc); 383 | 384 | int count = 1; 385 | for (int i = 0; i < nb_dims; ++i) { 386 | if (i == 1) { 387 | continue; 388 | } 389 | count *= output_desc.shape[i]; 390 | } 391 | 392 | if (nb_dims == 4) { 393 | grid_sampler_2d_kernel<<>>( 394 | count, input, grid, output, input_desc, grid_desc, output_desc, interp, padding, 395 | align_corners); 396 | } else if (nb_dims == 5) { 397 | grid_sampler_3d_kernel<<>>( 398 | count, input, grid, output, input_desc, grid_desc, output_desc, interp, padding, 399 | align_corners); 400 | } else { 401 | printf("input and grid dims should be 4 or 5\n"); 402 | } 403 | } 404 | 405 | template void grid_sample(float *output, const float *input, const float *grid, 406 | int *output_dims, int *input_dims, int *grid_dims, int nb_dims, 407 | 
GridSamplerInterpolation interp, GridSamplerPadding padding, 408 | bool align_corners, cudaStream_t stream); 409 | 410 | 411 | //------------------plugin in 412 | 413 | namespace mmdeploy { 414 | namespace { 415 | static const char *PLUGIN_VERSION{"1"}; 416 | static const char *PLUGIN_NAME{"grid_sampler"}; 417 | } // namespace 418 | 419 | TRTGridSampler::TRTGridSampler(const std::string &name, int mode, int paddingMode, 420 | bool alignCorners) 421 | : TRTPluginBase(name), mMode(mode), mPaddingMode(paddingMode), mAlignCorners(alignCorners) {} 422 | 423 | TRTGridSampler::TRTGridSampler(const std::string name, const void *data, size_t length) 424 | : TRTPluginBase(name) { 425 | deserialize_value(&data, &length, &mMode); 426 | deserialize_value(&data, &length, &mPaddingMode); 427 | deserialize_value(&data, &length, &mAlignCorners); 428 | } 429 | 430 | nvinfer1::IPluginV2DynamicExt *TRTGridSampler::clone() const TRT_NOEXCEPT { 431 | TRTGridSampler *plugin = new TRTGridSampler(mLayerName, mMode, mPaddingMode, mAlignCorners); 432 | plugin->setPluginNamespace(getPluginNamespace()); 433 | 434 | return plugin; 435 | } 436 | 437 | nvinfer1::DimsExprs TRTGridSampler::getOutputDimensions( 438 | int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, 439 | nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { 440 | nvinfer1::DimsExprs ret; 441 | ret.nbDims = inputs[0].nbDims; 442 | ret.d[0] = inputs[0].d[0]; 443 | ret.d[1] = inputs[0].d[1]; 444 | for (int i = 2; i < ret.nbDims; ++i) { 445 | ret.d[i] = inputs[1].d[i - 1]; 446 | } 447 | return ret; 448 | } 449 | 450 | bool TRTGridSampler::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, 451 | int nbInputs, int nbOutputs) TRT_NOEXCEPT { 452 | if (pos == 0) { 453 | return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && 454 | ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); 455 | } else { 456 | return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; 457 | } 458 | } 459 | 460 | void TRTGridSampler::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, 461 | const nvinfer1::DynamicPluginTensorDesc *outputs, 462 | int nbOutputs) TRT_NOEXCEPT { 463 | // Validate input arguments 464 | } 465 | 466 | size_t TRTGridSampler::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, 467 | const nvinfer1::PluginTensorDesc *outputs, 468 | int nbOutputs) const TRT_NOEXCEPT { 469 | return 0; 470 | } 471 | 472 | int TRTGridSampler::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, 473 | const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, 474 | void *const *outputs, void *workSpace, 475 | cudaStream_t stream) TRT_NOEXCEPT { 476 | nvinfer1::Dims input_dims = inputDesc[0].dims; 477 | nvinfer1::Dims grid_dims = inputDesc[1].dims; 478 | nvinfer1::Dims output_dims = outputDesc[0].dims; 479 | 480 | GridSamplerInterpolation interp_mode = GridSamplerInterpolation::Bilinear; 481 | switch (mMode) { 482 | case 0: 483 | interp_mode = GridSamplerInterpolation::Bilinear; 484 | break; 485 | case 1: 486 | interp_mode = GridSamplerInterpolation::Nearest; 487 | break; 488 | default: 489 | break; 490 | } 491 | 492 | GridSamplerPadding padding_mode = GridSamplerPadding::Zeros; 493 | switch (mPaddingMode) { 494 | case 0: 495 | padding_mode = GridSamplerPadding::Zeros; 496 | break; 497 | 498 | case 1: 499 | padding_mode = GridSamplerPadding::Border; 500 | break; 501 | 502 | case 2: 503 | padding_mode = GridSamplerPadding::Reflection; 504 | break; 505 | 
default: 506 | break; 507 | } 508 | 509 | auto data_type = inputDesc[0].type; 510 | 511 | switch (data_type) { 512 | case nvinfer1::DataType::kFLOAT: 513 | grid_sample((float *)outputs[0], (float *)inputs[0], (float *)inputs[1], 514 | &(output_dims.d[0]), &(input_dims.d[0]), &(grid_dims.d[0]), 515 | input_dims.nbDims, interp_mode, padding_mode, mAlignCorners, stream); 516 | break; 517 | default: 518 | return 1; 519 | break; 520 | } 521 | 522 | return 0; 523 | } 524 | 525 | nvinfer1::DataType TRTGridSampler::getOutputDataType(int index, 526 | const nvinfer1::DataType *inputTypes, 527 | int nbInputs) const TRT_NOEXCEPT { 528 | return inputTypes[0]; 529 | } 530 | 531 | // IPluginV2 Methods 532 | const char *TRTGridSampler::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } 533 | 534 | const char *TRTGridSampler::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } 535 | 536 | int TRTGridSampler::getNbOutputs() const TRT_NOEXCEPT { return 1; } 537 | 538 | size_t TRTGridSampler::getSerializationSize() const TRT_NOEXCEPT { 539 | return serialized_size(mMode) + serialized_size(mPaddingMode) + serialized_size(mAlignCorners); 540 | } 541 | 542 | void TRTGridSampler::serialize(void *buffer) const TRT_NOEXCEPT { 543 | serialize_value(&buffer, mMode); 544 | serialize_value(&buffer, mPaddingMode); 545 | serialize_value(&buffer, mAlignCorners); 546 | } 547 | 548 | ////////////////////// creator ///////////////////////////// 549 | 550 | TRTGridSamplerCreator::TRTGridSamplerCreator() { 551 | mPluginAttributes = std::vector( 552 | {nvinfer1::PluginField("interpolation_mode"), nvinfer1::PluginField("padding_mode"), 553 | nvinfer1::PluginField("align_corners")}); 554 | mFC.nbFields = mPluginAttributes.size(); 555 | mFC.fields = mPluginAttributes.data(); 556 | } 557 | 558 | const char *TRTGridSamplerCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } 559 | 560 | const char *TRTGridSamplerCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } 561 | 562 | nvinfer1::IPluginV2 *TRTGridSamplerCreator::createPlugin( 563 | const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { 564 | int mode = 0; 565 | int paddingMode = 0; 566 | bool alignCorners = false; 567 | 568 | for (int i = 0; i < fc->nbFields; i++) { 569 | if (fc->fields[i].data == nullptr) { 570 | continue; 571 | } 572 | std::string field_name(fc->fields[i].name); 573 | 574 | if (field_name.compare("interpolation_mode") == 0) { 575 | mode = static_cast(fc->fields[i].data)[0]; 576 | } 577 | 578 | if (field_name.compare("padding_mode") == 0) { 579 | paddingMode = static_cast(fc->fields[i].data)[0]; 580 | } 581 | 582 | if (field_name.compare("align_corners") == 0) { 583 | alignCorners = (bool)(static_cast(fc->fields[i].data)[0]); 584 | } 585 | } 586 | 587 | TRTGridSampler *plugin = new TRTGridSampler(name, mode, paddingMode, alignCorners); 588 | plugin->setPluginNamespace(getPluginNamespace()); 589 | return plugin; 590 | } 591 | 592 | nvinfer1::IPluginV2 *TRTGridSamplerCreator::deserializePlugin(const char *name, 593 | const void *serialData, 594 | size_t serialLength) TRT_NOEXCEPT { 595 | // This object will be deleted when the network is destroyed, which will 596 | // call FCPluginDynamic::destroy() 597 | auto plugin = new TRTGridSampler(name, serialData, serialLength); 598 | plugin->setPluginNamespace(getPluginNamespace()); 599 | return plugin; 600 | } 601 | 602 | REGISTER_TENSORRT_PLUGIN(TRTGridSamplerCreator); 603 | } // namespace mmdeploy 604 | 
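The `REGISTER_TENSORRT_PLUGIN(TRTGridSamplerCreator);` line above registers the creator through a static initializer, so the only integration step an application needs is to load the compiled `.so` into the process before deserializing the engine. A minimal sketch of that step (the library path is an assumption here; point it at whatever the Makefile above actually produced):

```cpp
#include <dlfcn.h>

#include <iostream>

// Hypothetical path; the Makefile builds one .so per .cu next to the source.
static const char* kGridSamplerSo = "./trt_grid_sampler_kernel.so";

// Loading the library runs its static initializers, which is exactly when
// REGISTER_TENSORRT_PLUGIN adds "grid_sampler" to the global TensorRT plugin
// registry. This must happen before IRuntime::deserializeCudaEngine is called.
bool loadGridSamplerPlugin() {
  void* handle = dlopen(kGridSamplerSo, RTLD_NOW | RTLD_GLOBAL);
  if (handle == nullptr) {
    std::cerr << "failed to load plugin: " << dlerror() << std::endl;
    return false;
  }
  return true;
}
```

When benchmarking with trtexec, the `--plugins` option gives the same effect without writing any code.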
-------------------------------------------------------------------------------- /linux_cc/plugin/trt_grid_sampler_kernel.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) OpenMMLab. All rights reserved. 2 | #ifndef TRT_GRID_SAMPLER_KERNEL_HPP 3 | #define TRT_GRID_SAMPLER_KERNEL_HPP 4 | #include 5 | 6 | #include 7 | 8 | #include 9 | #include 10 | #include 11 | 12 | #include "trt_plugin_base.hpp" 13 | 14 | enum class GridSamplerInterpolation { Bilinear, Nearest }; 15 | enum class GridSamplerPadding { Zeros, Border, Reflection }; 16 | 17 | template 18 | void grid_sample(T *output, const T *input, const T *grid, int *output_dims, int *input_dims, 19 | int *grid_dims, int nb_dims, GridSamplerInterpolation interp, 20 | GridSamplerPadding padding, bool align_corners, cudaStream_t stream); 21 | 22 | 23 | namespace mmdeploy { 24 | 25 | class TRTGridSampler : public TRTPluginBase { 26 | public: 27 | TRTGridSampler(const std::string &name, int mode, int paddingMode, bool alignCorners); 28 | 29 | TRTGridSampler(const std::string name, const void *data, size_t length); 30 | 31 | TRTGridSampler() = delete; 32 | 33 | ~TRTGridSampler() TRT_NOEXCEPT override = default; 34 | 35 | // IPluginV2DynamicExt Methods 36 | nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; 37 | 38 | nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, 39 | int nbInputs, nvinfer1::IExprBuilder &exprBuilder) 40 | TRT_NOEXCEPT override; 41 | 42 | bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, 43 | int nbOutputs) TRT_NOEXCEPT override; 44 | 45 | void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, 46 | const nvinfer1::DynamicPluginTensorDesc *out, 47 | int nbOutputs) TRT_NOEXCEPT override; 48 | 49 | size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, 50 | const nvinfer1::PluginTensorDesc *outputs, 51 | int nbOutputs) const TRT_NOEXCEPT override; 52 | 53 | int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, 54 | const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, 55 | void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; 56 | 57 | // IPluginV2Ext Methods 58 | nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, 59 | int nbInputs) const TRT_NOEXCEPT override; 60 | 61 | // IPluginV2 Methods 62 | const char *getPluginType() const TRT_NOEXCEPT override; 63 | 64 | const char *getPluginVersion() const TRT_NOEXCEPT override; 65 | 66 | int getNbOutputs() const TRT_NOEXCEPT override; 67 | 68 | size_t getSerializationSize() const TRT_NOEXCEPT override; 69 | 70 | void serialize(void *buffer) const TRT_NOEXCEPT override; 71 | 72 | private: 73 | int mMode; 74 | int mPaddingMode; 75 | bool mAlignCorners; 76 | }; 77 | 78 | class TRTGridSamplerCreator : public TRTPluginCreatorBase { 79 | public: 80 | TRTGridSamplerCreator(); 81 | 82 | ~TRTGridSamplerCreator() TRT_NOEXCEPT override = default; 83 | 84 | const char *getPluginName() const TRT_NOEXCEPT override; 85 | 86 | const char *getPluginVersion() const TRT_NOEXCEPT override; 87 | 88 | nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) 89 | TRT_NOEXCEPT override; 90 | 91 | nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, 92 | size_t serialLength) TRT_NOEXCEPT override; 93 | }; 94 | } // namespace mmdeploy 95 | 96 | 97 | #endif 
// TRT_GRID_SAMPLER_KERNEL_HPP
98 | 
-------------------------------------------------------------------------------- /plugin/CMakeLists.txt: --------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.12) # add_compile_definitions below needs CMake >= 3.12
2 | 
3 | project(mmdeploy_plugins)
4 | 
5 | add_definitions(-std=c++11)
6 | add_definitions(-DAPI_EXPORTS)
7 | option(CUDA_USE_STATIC_CUDA_RUNTIME "Use the static CUDA runtime" OFF)
8 | set(CMAKE_CXX_STANDARD 11)
9 | set(CMAKE_BUILD_TYPE Release)
10 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /O2")
11 | add_compile_definitions(WIN32_LEAN_AND_MEAN NOMINMAX)
12 | 
13 | find_package(CUDA REQUIRED)
14 | 
15 | #if(WIN32)
16 | #enable_language(CUDA)
17 | #endif(WIN32)
18 | 
19 | # cuda
20 | set(cuda_inc "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0/include")
21 | set(cuda_lib "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0/lib/x64")
22 | include_directories(${cuda_inc})
23 | link_directories(${cuda_lib})
24 | #cub
25 | set(CUB_ROOT_DIR "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0/include/cub")
26 | include_directories(${CUB_ROOT_DIR})
27 | # tensorrt
28 | set(tensorrt_inc "D:/trt_install/TensorRT-8.5.1.7/include")
29 | set(tensorrt_lib "D:/trt_install/TensorRT-8.5.1.7/lib")
30 | include_directories(${tensorrt_inc})
31 | link_directories(${tensorrt_lib})
32 | # opencv
33 | #include_directories("${PROJECT_SOURCE_DIR}/third_party/CV460_64/include")
34 | #set(opencv_lib "${PROJECT_SOURCE_DIR}/third_party/CV460_64/lib/opencv_world460.lib")
35 | 
36 | # common files
37 | include_directories(common)
38 | 
39 | file(GLOB grid_sampler_src ${PROJECT_SOURCE_DIR}/grid_sampler/*.cpp ${PROJECT_SOURCE_DIR}/grid_sampler/*.cu)
40 | cuda_add_library(trtgrid_sampler SHARED ${grid_sampler_src})
41 | #cuda_add_library(trtgrid_sampler STATIC ${grid_sampler_src})
42 | target_link_libraries(trtgrid_sampler nvinfer cudart)
43 | 
44 | 
45 | #file(GLOB topk_src ${PROJECT_SOURCE_DIR}/gather_topk/*.cpp ${PROJECT_SOURCE_DIR}/gather_topk/*.cu)
46 | #cuda_add_library(trtgather_topk SHARED ${topk_src})
47 | ##cuda_add_library(trtgather_topk STATIC ${topk_src})
48 | #target_link_libraries(trtgather_topk nvinfer cudart)
49 | 
50 | 
51 | if(UNIX)
52 | add_definitions(-O2 -pthread)
53 | endif(UNIX)
-------------------------------------------------------------------------------- /plugin/common/common_cuda_helper.hpp: --------------------------------------------------------------------------------
1 | // Copyright (c) OpenMMLab. All rights reserved.
2 | #ifndef COMMON_CUDA_HELPER
3 | #define COMMON_CUDA_HELPER
4 | 
5 | #include <cublas_v2.h>
6 | #include <cuda.h>
7 | #include <stdio.h>
8 | 
9 | #include <algorithm>
10 | 
11 | #define CUDA_1D_KERNEL_LOOP(i, n) \
12 |   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
13 | 
14 | #define THREADS_PER_BLOCK 512
15 | 
16 | #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
17 | inline int GET_BLOCKS(const int N) {
18 |   int optimal_block_num = DIVUP(N, THREADS_PER_BLOCK);
19 |   int max_block_num = 4096;
20 |   return std::min(optimal_block_num, max_block_num);
21 | }
22 | 
23 | #define cudaCheckError()                                                                \
24 |   {                                                                                     \
25 |     cudaError_t e = cudaGetLastError();                                                 \
26 |     if (e != cudaSuccess) {                                                             \
27 |       printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e));  \
28 |       exit(0);                                                                          \
29 |     }                                                                                   \
30 |   }
31 | 
32 | /**
33 |  * Returns a view of the original tensor with its dimensions permuted.
34 |  *
35 |  * @param[out] dst pointer to the destination tensor
36 |  * @param[in] src pointer to the source tensor
37 |  * @param[in] src_size shape of the src tensor
38 |  * @param[in] permute The desired ordering of dimensions
39 |  * @param[in] src_dim dim of src tensor
40 |  * @param[in] stream cuda stream handle
41 |  */
42 | template <typename scalar_t>
43 | void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size, int* permute, int src_dim,
44 |                    cudaStream_t stream = 0);
45 | 
46 | template <typename scalar_t>
47 | cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa,
48 |                               cublasOperation_t transb, int m, int n, int k, const scalar_t* alpha,
49 |                               const scalar_t* A, int lda, const scalar_t* B, int ldb,
50 |                               const scalar_t* beta, scalar_t* C, int ldc);
51 | 
52 | template <typename scalar_t>
53 | __device__ __forceinline__ scalar_t bilinear_interpolate(const scalar_t* __restrict__ input,
54 |                                                          const int height, const int width,
55 |                                                          scalar_t y, scalar_t x) {
56 |   // deal with cases that inverse elements are out of feature map boundary
57 |   if (y < -1.0 || y > height || x < -1.0 || x > width) return 0;
58 | 
59 |   y = min(scalar_t(height - 1), max(scalar_t(0), y));
60 |   x = min(scalar_t(width - 1), max(scalar_t(0), x));
61 | 
62 |   const int y_low = floor(y);
63 |   const int x_low = floor(x);
64 |   const int y_high = ceil(y);
65 |   const int x_high = ceil(x);
66 | 
67 |   const scalar_t v1 = input[y_low * width + x_low];
68 |   const scalar_t v2 = input[y_low * width + x_high];
69 |   const scalar_t v3 = input[y_high * width + x_low];
70 |   const scalar_t v4 = input[y_high * width + x_high];
71 | 
72 |   // lerp can be performed by fma
73 |   const scalar_t ly = y - y_low;
74 |   const scalar_t lx = x - x_low;
75 |   const scalar_t v_low = fma(v2 - v1, lx, v1);
76 |   const scalar_t v_high = fma(v4 - v3, lx, v3);
77 |   const scalar_t val = fma(v_high - v_low, ly, v_low);
78 | 
79 |   return val;
80 | }
81 | 
82 | #endif  // COMMON_CUDA_HELPER
83 | 
-------------------------------------------------------------------------------- /plugin/common/trt_plugin_base.hpp: --------------------------------------------------------------------------------
1 | // Copyright (c) OpenMMLab. All rights reserved.
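// Same plugin base classes as the linux_cc copy, kept here for the
// Windows/CMake build. The TRT_NOEXCEPT macro below is what keeps one source
// tree compatible with both API generations: TensorRT 8 added noexcept to
// every virtual in the plugin interfaces, so overrides must be noexcept on
// TensorRT 8+ but must not be on TensorRT 7, and the macro expands
// accordingly based on NV_TENSORRT_MAJOR.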
2 | #ifndef TRT_PLUGIN_BASE_HPP 3 | #define TRT_PLUGIN_BASE_HPP 4 | #include "NvInferRuntime.h" 5 | #include "NvInferVersion.h" 6 | #include "trt_plugin_helper.hpp" 7 | 8 | namespace mmdeploy { 9 | 10 | #if NV_TENSORRT_MAJOR > 7 11 | #define TRT_NOEXCEPT noexcept 12 | #else 13 | #define TRT_NOEXCEPT 14 | #endif 15 | 16 | class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt { 17 | public: 18 | TRTPluginBase(const std::string &name) : mLayerName(name) {} 19 | // IPluginV2 Methods 20 | const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; } 21 | int initialize() TRT_NOEXCEPT override { return STATUS_SUCCESS; } 22 | void terminate() TRT_NOEXCEPT override {} 23 | void destroy() TRT_NOEXCEPT override { delete this; } 24 | void setPluginNamespace(const char *pluginNamespace) TRT_NOEXCEPT override { 25 | mNamespace = pluginNamespace; 26 | } 27 | const char *getPluginNamespace() const TRT_NOEXCEPT override { return mNamespace.c_str(); } 28 | 29 | virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, 30 | const nvinfer1::DynamicPluginTensorDesc *out, 31 | int nbOutputs) TRT_NOEXCEPT override {} 32 | 33 | virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, 34 | const nvinfer1::PluginTensorDesc *outputs, 35 | int nbOutputs) const TRT_NOEXCEPT override { 36 | return 0; 37 | } 38 | 39 | virtual void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, 40 | nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override {} 41 | 42 | virtual void detachFromContext() TRT_NOEXCEPT override {} 43 | 44 | protected: 45 | const std::string mLayerName; 46 | std::string mNamespace; 47 | 48 | #if NV_TENSORRT_MAJOR < 8 49 | protected: 50 | // To prevent compiler warnings. 51 | using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch; 52 | using nvinfer1::IPluginV2DynamicExt::enqueue; 53 | using nvinfer1::IPluginV2DynamicExt::getOutputDimensions; 54 | using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch; 55 | using nvinfer1::IPluginV2DynamicExt::supportsFormat; 56 | #endif 57 | }; 58 | 59 | class TRTPluginCreatorBase : public nvinfer1::IPluginCreator { 60 | public: 61 | const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; }; 62 | 63 | const nvinfer1::PluginFieldCollection *getFieldNames() TRT_NOEXCEPT override { return &mFC; } 64 | 65 | void setPluginNamespace(const char *pluginNamespace) TRT_NOEXCEPT override { 66 | mNamespace = pluginNamespace; 67 | } 68 | 69 | const char *getPluginNamespace() const TRT_NOEXCEPT override { return mNamespace.c_str(); } 70 | 71 | protected: 72 | nvinfer1::PluginFieldCollection mFC; 73 | std::vector mPluginAttributes; 74 | std::string mNamespace; 75 | }; 76 | } // namespace mmdeploy 77 | #endif 78 | -------------------------------------------------------------------------------- /plugin/common/trt_plugin_helper.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) OpenMMLab. All rights reserved. 
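// The CHECK/CSC/ASSERT_* macros below come in two flavors selected by the
// DEBUG preprocessor symbol: the release branch (#ifndef DEBUG) fails fast or
// returns an error code silently, while the debug branch also prints
// file/line and the CUDA error string. The plugin Makefile.inc pins the
// release flavor via DEBUG_MACRO = -UDEBUG; switch it to -DDEBUG to get the
// verbose variants.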
2 | #ifndef TRT_PLUGIN_HELPER_HPP 3 | #define TRT_PLUGIN_HELPER_HPP 4 | #include 5 | 6 | #include 7 | #include 8 | 9 | #include "NvInferRuntime.h" 10 | 11 | cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype, cudnnDataType_t* cudnn_dtype); 12 | 13 | // Enumerator for status 14 | typedef enum { 15 | STATUS_SUCCESS = 0, 16 | STATUS_FAILURE = 1, 17 | STATUS_BAD_PARAM = 2, 18 | STATUS_NOT_SUPPORTED = 3, 19 | STATUS_NOT_INITIALIZED = 4 20 | } pluginStatus_t; 21 | 22 | #define ASSERT(assertion) \ 23 | { \ 24 | if (!(assertion)) { \ 25 | std::cerr << "#assertion" << __FILE__ << "," << __LINE__ << std::endl; \ 26 | abort(); \ 27 | } \ 28 | } 29 | 30 | #define CUASSERT(status_) \ 31 | { \ 32 | auto s_ = status_; \ 33 | if (s_ != cudaSuccess) { \ 34 | std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << ", " << cudaGetErrorString(s_) \ 35 | << std::endl; \ 36 | } \ 37 | } 38 | #define CUBLASASSERT(status_) \ 39 | { \ 40 | auto s_ = status_; \ 41 | if (s_ != CUBLAS_STATUS_SUCCESS) { \ 42 | std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \ 43 | } \ 44 | } 45 | #define CUERRORMSG(status_) \ 46 | { \ 47 | auto s_ = status_; \ 48 | if (s_ != 0) std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \ 49 | } 50 | 51 | #ifndef DEBUG 52 | 53 | #define CHECK(status) \ 54 | do { \ 55 | if (status != 0) abort(); \ 56 | } while (0) 57 | 58 | #define ASSERT_PARAM(exp) \ 59 | do { \ 60 | if (!(exp)) return STATUS_BAD_PARAM; \ 61 | } while (0) 62 | 63 | #define ASSERT_FAILURE(exp) \ 64 | do { \ 65 | if (!(exp)) return STATUS_FAILURE; \ 66 | } while (0) 67 | 68 | #define CSC(call, err) \ 69 | do { \ 70 | cudaError_t cudaStatus = call; \ 71 | if (cudaStatus != cudaSuccess) { \ 72 | return err; \ 73 | } \ 74 | } while (0) 75 | 76 | #define DEBUG_PRINTF(...) \ 77 | do { \ 78 | } while (0) 79 | 80 | #else 81 | 82 | #define ASSERT_PARAM(exp) \ 83 | do { \ 84 | if (!(exp)) { \ 85 | fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \ 86 | return STATUS_BAD_PARAM; \ 87 | } \ 88 | } while (0) 89 | 90 | #define ASSERT_FAILURE(exp) \ 91 | do { \ 92 | if (!(exp)) { \ 93 | fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \ 94 | return STATUS_FAILURE; \ 95 | } \ 96 | } while (0) 97 | 98 | #define CSC(call, err) \ 99 | do { \ 100 | cudaError_t cudaStatus = call; \ 101 | if (cudaStatus != cudaSuccess) { \ 102 | printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \ 103 | return err; \ 104 | } \ 105 | } while (0) 106 | 107 | #define CHECK(status) \ 108 | { \ 109 | if (status != 0) { \ 110 | DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \ 111 | abort(); \ 112 | } \ 113 | } 114 | 115 | #define DEBUG_PRINTF(...) 
\ 116 | do { \ 117 | printf(__VA_ARGS__); \ 118 | } while (0) 119 | 120 | #endif 121 | 122 | namespace mmdeploy { 123 | 124 | const int MAXTENSORDIMS = 10; 125 | 126 | struct TensorDesc { 127 | int shape[MAXTENSORDIMS]; 128 | int stride[MAXTENSORDIMS]; 129 | int dim; 130 | }; 131 | 132 | inline unsigned int getElementSize(nvinfer1::DataType t) { 133 | switch (t) { 134 | case nvinfer1::DataType::kINT32: 135 | return 4; 136 | case nvinfer1::DataType::kFLOAT: 137 | return 4; 138 | case nvinfer1::DataType::kHALF: 139 | return 2; 140 | // case nvinfer1::DataType::kBOOL: 141 | case nvinfer1::DataType::kINT8: 142 | return 1; 143 | default: 144 | throw std::runtime_error("Invalid DataType."); 145 | } 146 | throw std::runtime_error("Invalid DataType."); 147 | return 0; 148 | } 149 | 150 | inline size_t getAlignedSize(size_t origin_size, size_t aligned_number = 16) { 151 | return size_t((origin_size + aligned_number - 1) / aligned_number) * aligned_number; 152 | } 153 | 154 | } // namespace mmdeploy 155 | #endif // TRT_PLUGIN_HELPER_HPP 156 | -------------------------------------------------------------------------------- /plugin/common/trt_serialize.hpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // Modified from: 3 | // https://github.com/NVIDIA/TensorRT/blob/master/plugin/common/serialize.hpp 4 | 5 | #ifndef TRT_SERIALIZE_HPP 6 | #define TRT_SERIALIZE_HPP 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | template 13 | inline void serialize_value(void** buffer, T const& value); 14 | 15 | template 16 | inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value); 17 | 18 | namespace { 19 | 20 | template 21 | struct Serializer {}; 22 | 23 | template 24 | struct Serializer::value || std::is_enum::value || 26 | std::is_pod::value>::type> { 27 | static size_t serialized_size(T const& value) { return sizeof(T); } 28 | static void serialize(void** buffer, T const& value) { 29 | ::memcpy(*buffer, &value, sizeof(T)); 30 | reinterpret_cast(*buffer) += sizeof(T); 31 | } 32 | static void deserialize(void const** buffer, size_t* buffer_size, T* value) { 33 | assert(*buffer_size >= sizeof(T)); 34 | ::memcpy(value, *buffer, sizeof(T)); 35 | reinterpret_cast(*buffer) += sizeof(T); 36 | *buffer_size -= sizeof(T); 37 | } 38 | }; 39 | 40 | template <> 41 | struct Serializer { 42 | static size_t serialized_size(const char* value) { return strlen(value) + 1; } 43 | static void serialize(void** buffer, const char* value) { 44 | ::strcpy(static_cast(*buffer), value); 45 | reinterpret_cast(*buffer) += strlen(value) + 1; 46 | } 47 | static void deserialize(void const** buffer, size_t* buffer_size, const char** value) { 48 | *value = static_cast(*buffer); 49 | size_t data_size = strnlen(*value, *buffer_size) + 1; 50 | assert(*buffer_size >= data_size); 51 | reinterpret_cast(*buffer) += data_size; 52 | *buffer_size -= data_size; 53 | } 54 | }; 55 | 56 | template 57 | struct Serializer, 58 | typename std::enable_if::value || std::is_enum::value || 59 | std::is_pod::value>::type> { 60 | static size_t serialized_size(std::vector const& value) { 61 | return sizeof(value.size()) + value.size() * sizeof(T); 62 | } 63 | static void serialize(void** buffer, std::vector const& value) { 64 | serialize_value(buffer, value.size()); 65 | size_t nbyte = value.size() * sizeof(T); 66 | ::memcpy(*buffer, value.data(), nbyte); 67 | reinterpret_cast(*buffer) += nbyte; 68 | } 69 | static void 
deserialize(void const** buffer, size_t* buffer_size, std::vector* value) { 70 | size_t size; 71 | deserialize_value(buffer, buffer_size, &size); 72 | value->resize(size); 73 | size_t nbyte = value->size() * sizeof(T); 74 | assert(*buffer_size >= nbyte); 75 | ::memcpy(value->data(), *buffer, nbyte); 76 | reinterpret_cast(*buffer) += nbyte; 77 | *buffer_size -= nbyte; 78 | } 79 | }; 80 | 81 | } // namespace 82 | 83 | template 84 | inline size_t serialized_size(T const& value) { 85 | return Serializer::serialized_size(value); 86 | } 87 | 88 | template 89 | inline void serialize_value(void** buffer, T const& value) { 90 | return Serializer::serialize(buffer, value); 91 | } 92 | 93 | template 94 | inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value) { 95 | return Serializer::deserialize(buffer, buffer_size, value); 96 | } 97 | #endif // TRT_SERIALIZE_HPP 98 | -------------------------------------------------------------------------------- /plugin/grid_sampler/trt_grid_sampler.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) OpenMMLab. All rights reserved. 2 | #include "trt_grid_sampler.hpp" 3 | 4 | #include 5 | 6 | #include 7 | 8 | #include "trt_grid_sampler_kernel.hpp" 9 | #include "trt_plugin_helper.hpp" 10 | #include "trt_serialize.hpp" 11 | 12 | namespace mmdeploy { 13 | namespace { 14 | static const char *PLUGIN_VERSION{"1"}; 15 | static const char *PLUGIN_NAME{"grid_sampler"}; 16 | } // namespace 17 | 18 | TRTGridSampler::TRTGridSampler(const std::string &name, int mode, int paddingMode, 19 | bool alignCorners) 20 | : TRTPluginBase(name), mMode(mode), mPaddingMode(paddingMode), mAlignCorners(alignCorners) {} 21 | 22 | TRTGridSampler::TRTGridSampler(const std::string name, const void *data, size_t length) 23 | : TRTPluginBase(name) { 24 | deserialize_value(&data, &length, &mMode); 25 | deserialize_value(&data, &length, &mPaddingMode); 26 | deserialize_value(&data, &length, &mAlignCorners); 27 | } 28 | 29 | nvinfer1::IPluginV2DynamicExt *TRTGridSampler::clone() const TRT_NOEXCEPT { 30 | TRTGridSampler *plugin = new TRTGridSampler(mLayerName, mMode, mPaddingMode, mAlignCorners); 31 | plugin->setPluginNamespace(getPluginNamespace()); 32 | 33 | return plugin; 34 | } 35 | 36 | nvinfer1::DimsExprs TRTGridSampler::getOutputDimensions( 37 | int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, 38 | nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { 39 | nvinfer1::DimsExprs ret; 40 | ret.nbDims = inputs[0].nbDims; 41 | ret.d[0] = inputs[0].d[0]; 42 | ret.d[1] = inputs[0].d[1]; 43 | for (int i = 2; i < ret.nbDims; ++i) { 44 | ret.d[i] = inputs[1].d[i - 1]; 45 | } 46 | return ret; 47 | } 48 | 49 | bool TRTGridSampler::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, 50 | int nbInputs, int nbOutputs) TRT_NOEXCEPT { 51 | if (pos == 0) { 52 | return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && 53 | ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); 54 | } else { 55 | return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; 56 | } 57 | } 58 | 59 | void TRTGridSampler::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, 60 | const nvinfer1::DynamicPluginTensorDesc *outputs, 61 | int nbOutputs) TRT_NOEXCEPT { 62 | // Validate input arguments 63 | } 64 | 65 | size_t TRTGridSampler::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, 66 | const nvinfer1::PluginTensorDesc *outputs, 67 | int 
bool TRTGridSampler::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc,
                                               int nbInputs, int nbOutputs) TRT_NOEXCEPT {
  if (pos == 0) {
    return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
            ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
  } else {
    return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
  }
}

void TRTGridSampler::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
                                     const nvinfer1::DynamicPluginTensorDesc *outputs,
                                     int nbOutputs) TRT_NOEXCEPT {
  // Validate input arguments
}

size_t TRTGridSampler::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                                        const nvinfer1::PluginTensorDesc *outputs,
                                        int nbOutputs) const TRT_NOEXCEPT {
  return 0;
}

int TRTGridSampler::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
                            const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
                            void *const *outputs, void *workSpace,
                            cudaStream_t stream) TRT_NOEXCEPT {
  nvinfer1::Dims input_dims = inputDesc[0].dims;
  nvinfer1::Dims grid_dims = inputDesc[1].dims;
  nvinfer1::Dims output_dims = outputDesc[0].dims;

  // map the serialized int attributes onto the kernel enums
  GridSamplerInterpolation interp_mode = GridSamplerInterpolation::Bilinear;
  switch (mMode) {
    case 0:
      interp_mode = GridSamplerInterpolation::Bilinear;
      break;
    case 1:
      interp_mode = GridSamplerInterpolation::Nearest;
      break;
    default:
      break;
  }

  GridSamplerPadding padding_mode = GridSamplerPadding::Zeros;
  switch (mPaddingMode) {
    case 0:
      padding_mode = GridSamplerPadding::Zeros;
      break;

    case 1:
      padding_mode = GridSamplerPadding::Border;
      break;

    case 2:
      padding_mode = GridSamplerPadding::Reflection;
      break;
    default:
      break;
  }

  auto data_type = inputDesc[0].type;

  switch (data_type) {
    case nvinfer1::DataType::kFLOAT:
      grid_sample<float>((float *)outputs[0], (float *)inputs[0], (float *)inputs[1],
                         &(output_dims.d[0]), &(input_dims.d[0]), &(grid_dims.d[0]),
                         input_dims.nbDims, interp_mode, padding_mode, mAlignCorners, stream);
      break;
    default:
      return 1;  // unsupported dtype: report failure to TensorRT
  }

  return 0;
}

nvinfer1::DataType TRTGridSampler::getOutputDataType(int index,
                                                     const nvinfer1::DataType *inputTypes,
                                                     int nbInputs) const TRT_NOEXCEPT {
  return inputTypes[0];
}

// IPluginV2 Methods
const char *TRTGridSampler::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *TRTGridSampler::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

int TRTGridSampler::getNbOutputs() const TRT_NOEXCEPT { return 1; }

size_t TRTGridSampler::getSerializationSize() const TRT_NOEXCEPT {
  return serialized_size(mMode) + serialized_size(mPaddingMode) + serialized_size(mAlignCorners);
}

void TRTGridSampler::serialize(void *buffer) const TRT_NOEXCEPT {
  serialize_value(&buffer, mMode);
  serialize_value(&buffer, mPaddingMode);
  serialize_value(&buffer, mAlignCorners);
}
////////////////////// creator /////////////////////////////

TRTGridSamplerCreator::TRTGridSamplerCreator() {
  mPluginAttributes = std::vector<nvinfer1::PluginField>(
      {nvinfer1::PluginField("interpolation_mode"), nvinfer1::PluginField("padding_mode"),
       nvinfer1::PluginField("align_corners")});
  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();
}

const char *TRTGridSamplerCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }

const char *TRTGridSamplerCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }

nvinfer1::IPluginV2 *TRTGridSamplerCreator::createPlugin(
    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
  int mode = 0;
  int paddingMode = 0;
  bool alignCorners = false;

  for (int i = 0; i < fc->nbFields; i++) {
    if (fc->fields[i].data == nullptr) {
      continue;
    }
    std::string field_name(fc->fields[i].name);

    if (field_name.compare("interpolation_mode") == 0) {
      mode = static_cast<const int *>(fc->fields[i].data)[0];
    }

    if (field_name.compare("padding_mode") == 0) {
      paddingMode = static_cast<const int *>(fc->fields[i].data)[0];
    }

    if (field_name.compare("align_corners") == 0) {
      alignCorners = (bool)(static_cast<const int *>(fc->fields[i].data)[0]);
    }
  }

  TRTGridSampler *plugin = new TRTGridSampler(name, mode, paddingMode, alignCorners);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

nvinfer1::IPluginV2 *TRTGridSamplerCreator::deserializePlugin(const char *name,
                                                              const void *serialData,
                                                              size_t serialLength) TRT_NOEXCEPT {
  // This object will be deleted when the network is destroyed, which will
  // call FCPluginDynamic::destroy()
  auto plugin = new TRTGridSampler(name, serialData, serialLength);
  plugin->setPluginNamespace(getPluginNamespace());
  return plugin;
}

REGISTER_TENSORRT_PLUGIN(TRTGridSamplerCreator);
}  // namespace mmdeploy
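Thanks to `REGISTER_TENSORRT_PLUGIN`, the creator becomes visible through TensorRT's global plugin registry as soon as the plugin library is loaded. A hedged sketch of instantiating the layer by hand from the registry, matching the `("grid_sampler", "1")` name/version above — field values are illustrative, error handling is elided, and `make_grid_sampler_plugin` is not part of this repo:

```cpp
#include <NvInfer.h>

#include <vector>

nvinfer1::IPluginV2 *make_grid_sampler_plugin() {
  // look up the creator registered as ("grid_sampler", version "1")
  auto *creator = getPluginRegistry()->getPluginCreator("grid_sampler", "1");
  if (creator == nullptr) return nullptr;

  int interpolation_mode = 0;  // 0 = bilinear, 1 = nearest
  int padding_mode = 0;        // 0 = zeros, 1 = border, 2 = reflection
  int align_corners = 0;       // read as int, interpreted as bool

  std::vector<nvinfer1::PluginField> fields{
      {"interpolation_mode", &interpolation_mode, nvinfer1::PluginFieldType::kINT32, 1},
      {"padding_mode", &padding_mode, nvinfer1::PluginFieldType::kINT32, 1},
      {"align_corners", &align_corners, nvinfer1::PluginFieldType::kINT32, 1}};
  nvinfer1::PluginFieldCollection fc{static_cast<int>(fields.size()), fields.data()};

  return creator->createPlugin("grid_sampler_0", &fc);
}
```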
--------------------------------------------------------------------------------
/plugin/grid_sampler/trt_grid_sampler.hpp:
--------------------------------------------------------------------------------
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_GRID_SAMPLER_HPP
#define TRT_GRID_SAMPLER_HPP
#include <cublas_v2.h>

#include <memory>
#include <string>
#include <vector>

#include "trt_plugin_base.hpp"

namespace mmdeploy {

class TRTGridSampler : public TRTPluginBase {
 public:
  TRTGridSampler(const std::string &name, int mode, int paddingMode, bool alignCorners);

  TRTGridSampler(const std::string name, const void *data, size_t length);

  TRTGridSampler() = delete;

  ~TRTGridSampler() TRT_NOEXCEPT override = default;

  // IPluginV2DynamicExt Methods
  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;

  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
      TRT_NOEXCEPT override;

  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
                                 int nbOutputs) TRT_NOEXCEPT override;

  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc *out,
                       int nbOutputs) TRT_NOEXCEPT override;

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
                          const nvinfer1::PluginTensorDesc *outputs,
                          int nbOutputs) const TRT_NOEXCEPT override;

  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;

  // IPluginV2Ext Methods
  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
                                       int nbInputs) const TRT_NOEXCEPT override;

  // IPluginV2 Methods
  const char *getPluginType() const TRT_NOEXCEPT override;

  const char *getPluginVersion() const TRT_NOEXCEPT override;

  int getNbOutputs() const TRT_NOEXCEPT override;

  size_t getSerializationSize() const TRT_NOEXCEPT override;

  void serialize(void *buffer) const TRT_NOEXCEPT override;

 private:
  int mMode;
  int mPaddingMode;
  bool mAlignCorners;
};

class TRTGridSamplerCreator : public TRTPluginCreatorBase {
 public:
  TRTGridSamplerCreator();

  ~TRTGridSamplerCreator() TRT_NOEXCEPT override = default;

  const char *getPluginName() const TRT_NOEXCEPT override;

  const char *getPluginVersion() const TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
      TRT_NOEXCEPT override;

  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
                                         size_t serialLength) TRT_NOEXCEPT override;
};
}  // namespace mmdeploy
#endif  // TRT_GRID_SAMPLER_HPP
--------------------------------------------------------------------------------
/plugin/grid_sampler/trt_grid_sampler_kernel.cu:
--------------------------------------------------------------------------------
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/pytorch/pytorch/blob/ec683299ebabf297a3504c76248d37be830e4342/aten/src/ATen/native/cuda/GridSampler.cuh
// and
// https://github.com/pytorch/pytorch/blob/ec683299ebabf297a3504c76248d37be830e4342/aten/src/ATen/native/cuda/GridSampler.cu

#include <cuda_fp16.h>
#include <stdio.h>

#include <algorithm>
#include <cmath>
#include <vector>

#include "common_cuda_helper.hpp"
#include "trt_grid_sampler_kernel.hpp"
#include "trt_plugin_helper.hpp"

using mmdeploy::TensorDesc;

// Unnormalizes a coordinate from the -1 to +1 scale to its pixel index value,
// where we view each pixel as an area between (idx - 0.5) and (idx + 0.5).
// if align_corners: -1 and +1 get sent to the centers of the corner pixels
//     -1 --> 0
//     +1 --> (size - 1)
//     scale_factor = (size - 1) / 2
// if not align_corners: -1 and +1 get sent to the image edges
//     -1 --> -0.5
//     +1 --> (size - 1) + 0.5 == size - 0.5
//     scale_factor = size / 2
template <typename scalar_t>
static __forceinline__ __device__ scalar_t grid_sampler_unnormalize(scalar_t coord, int size,
                                                                    bool align_corners) {
  if (align_corners) {
    // unnormalize coord from [-1, 1] to [0, size - 1]
    return ((coord + 1.f) / 2) * (size - 1);
  } else {
    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
    return ((coord + 1.f) * size - 1) / 2;
  }
}

// Clips coordinates to between 0 and clip_limit - 1
template <typename scalar_t>
static __forceinline__ __device__ scalar_t clip_coordinates(scalar_t in, int clip_limit) {
  return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0)));
}
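Plugging numbers into `grid_sampler_unnormalize` makes the two conventions concrete: for `size = 4`, `align_corners=true` maps the extremes onto the corner pixel centers 0 and 3, while `align_corners=false` maps them onto the half-pixel image edges -0.5 and 3.5. A host-side restatement of the same formula (`unnormalize` is a hypothetical helper for illustration only):

```cpp
#include <cassert>

// host-side copy of the device formula above, for illustration
float unnormalize(float coord, int size, bool align_corners) {
  return align_corners ? ((coord + 1.f) / 2) * (size - 1)
                       : ((coord + 1.f) * size - 1) / 2;
}

int main() {
  // align_corners=true: extremes hit the corner pixel centers
  assert(unnormalize(-1.f, 4, true) == 0.f && unnormalize(1.f, 4, true) == 3.f);
  // align_corners=false: extremes hit the half-pixel image edges
  assert(unnormalize(-1.f, 4, false) == -0.5f && unnormalize(1.f, 4, false) == 3.5f);
  return 0;
}
```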
// Reflects coordinates until they fall between low and high (inclusive).
// The bounds are passed as twice their value so that half-integer values
// can be represented as ints.
template <typename scalar_t>
static __forceinline__ __device__ scalar_t reflect_coordinates(scalar_t in, int twice_low,
                                                               int twice_high) {
  if (twice_low == twice_high) {
    return static_cast<scalar_t>(0);
  }
  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  in = ::fabs(in - min);
  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
  scalar_t extra = ::fmod(in, span);
  int flips = static_cast<int>(::floor(in / span));
  if (flips % 2 == 0) {
    return extra + min;
  } else {
    return span - extra + min;
  }
}

template <typename scalar_t>
static __forceinline__ __device__ scalar_t safe_downgrade_to_int_range(scalar_t x) {
  // -100.0 does not have special meaning. This is just to make sure
  // it's not within_bounds_2d or within_bounds_3d, and does not cause
  // undefined behavior. See #35506.
  if (x > INT_MAX - 1 || x < INT_MIN || !::isfinite(static_cast<double>(x)))
    return static_cast<scalar_t>(-100.0);
  return x;
}

// Computes the pixel source index value for a grid coordinate
template <typename scalar_t>
static __forceinline__ __device__ scalar_t grid_sampler_compute_source_index(
    scalar_t coord, int size, GridSamplerPadding padding_mode, bool align_corners) {
  coord = grid_sampler_unnormalize(coord, size, align_corners);
  if (padding_mode == GridSamplerPadding::Border) {
    // clip coordinates to image borders
    coord = clip_coordinates(coord, size);
  } else if (padding_mode == GridSamplerPadding::Reflection) {
    // reflect coordinates by image borders
    if (align_corners) {
      coord = reflect_coordinates(coord, 0, 2 * (size - 1));
    } else {
      coord = reflect_coordinates(coord, -1, 2 * size - 1);
    }
    // clip coordinates to image borders
    coord = clip_coordinates(coord, size);
  }

  coord = safe_downgrade_to_int_range(coord);
  return coord;
}

static __forceinline__ __device__ bool within_bounds_2d(int h, int w, int H, int W) {
  return h >= 0 && h < H && w >= 0 && w < W;
}

static __forceinline__ __device__ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
  return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}

template <typename scalar_t>
__global__ void grid_sampler_2d_kernel(const int nthreads, const scalar_t *input,
                                       const scalar_t *grid, scalar_t *output,
                                       TensorDesc input_desc, TensorDesc grid_desc,
                                       TensorDesc output_desc,
                                       const GridSamplerInterpolation interpolation_mode,
                                       const GridSamplerPadding padding_mode, bool align_corners) {
  int C = input_desc.shape[1];
  int inp_H = input_desc.shape[2];
  int inp_W = input_desc.shape[3];
  int out_H = grid_desc.shape[1];
  int out_W = grid_desc.shape[2];
  int inp_sN = input_desc.stride[0];
  int inp_sC = input_desc.stride[1];
  int inp_sH = input_desc.stride[2];
  int inp_sW = input_desc.stride[3];
  int grid_sN = grid_desc.stride[0];
  int grid_sH = grid_desc.stride[1];
  int grid_sW = grid_desc.stride[2];
  int grid_sCoor = grid_desc.stride[3];
  int out_sN = output_desc.stride[0];
  int out_sC = output_desc.stride[1];
  int out_sH = output_desc.stride[2];
  int out_sW = output_desc.stride[3];

  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    const int w = index % out_W;
    const int h = (index / out_W) % out_H;
    const int n = index / (out_H * out_W);
    const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;

    // get the corresponding input x, y coordinates from grid
    scalar_t ix = grid[grid_offset];
    scalar_t iy = grid[grid_offset + grid_sCoor];

    ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
    iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
    if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
      // get NE, NW, SE, SW pixel values from (x, y)
      int ix_nw = static_cast<int>(::floor(ix));
      int iy_nw = static_cast<int>(::floor(iy));
      int ix_ne = ix_nw + 1;
      int iy_ne = iy_nw;
      int ix_sw = ix_nw;
      int iy_sw = iy_nw + 1;
      int ix_se = ix_nw + 1;
      int iy_se = iy_nw + 1;

      // get surfaces to each neighbor:
      scalar_t nw = (ix_se - ix) * (iy_se - iy);
      scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
      scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
      scalar_t se = (ix - ix_nw) * (iy - iy_nw);

      // calculate bilinear weighted pixel value and set output pixel
      auto inp_ptr_NC = input + n * inp_sN;
      auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
      for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
        *out_ptr_NCHW = static_cast<scalar_t>(0);
        if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
          *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw;
        }
        if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
          *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne;
        }
        if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
          *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw;
        }
        if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
          *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se;
        }
      }
    } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
      int ix_nearest = static_cast<int>(::round(ix));
      int iy_nearest = static_cast<int>(::round(iy));

      // assign nearest neighbor pixel value to output pixel
      auto inp_ptr_NC = input + n * inp_sN;
      auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
      for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
        if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) {
          *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW];
        } else {
          *out_ptr_NCHW = static_cast<scalar_t>(0);
        }
      }
    }
  }
}
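The four `nw/ne/sw/se` surface weights are the areas of the rectangles opposite each corner, so for any in-cell sample point they sum to 1 — which is what makes the accumulation above a proper bilinear interpolation. A quick host-side check of that identity (the sample point is arbitrary; this restates the kernel's algebra using the `se` corner throughout):

```cpp
#include <cassert>
#include <cmath>

int main() {
  float ix = 2.3f, iy = 5.8f;  // arbitrary sample point
  int ix_nw = static_cast<int>(std::floor(ix)), iy_nw = static_cast<int>(std::floor(iy));
  int ix_se = ix_nw + 1, iy_se = iy_nw + 1;
  float nw = (ix_se - ix) * (iy_se - iy);  // area opposite the north-west corner
  float ne = (ix - ix_nw) * (iy_se - iy);
  float sw = (ix_se - ix) * (iy - iy_nw);
  float se = (ix - ix_nw) * (iy - iy_nw);
  assert(std::fabs(nw + ne + sw + se - 1.f) < 1e-6f);
  return 0;
}
```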
template <typename scalar_t>
__global__ void grid_sampler_3d_kernel(const int nthreads, const scalar_t *input,
                                       const scalar_t *grid, scalar_t *output,
                                       TensorDesc input_desc, TensorDesc grid_desc,
                                       TensorDesc output_desc,
                                       const GridSamplerInterpolation interpolation_mode,
                                       const GridSamplerPadding padding_mode, bool align_corners) {
  int C = input_desc.shape[1];
  int inp_D = input_desc.shape[2];
  int inp_H = input_desc.shape[3];
  int inp_W = input_desc.shape[4];
  int out_D = grid_desc.shape[1];
  int out_H = grid_desc.shape[2];
  int out_W = grid_desc.shape[3];
  int inp_sN = input_desc.stride[0];
  int inp_sC = input_desc.stride[1];
  int inp_sD = input_desc.stride[2];
  int inp_sH = input_desc.stride[3];
  int inp_sW = input_desc.stride[4];
  int grid_sN = grid_desc.stride[0];
  int grid_sD = grid_desc.stride[1];
  int grid_sH = grid_desc.stride[2];
  int grid_sW = grid_desc.stride[3];
  int grid_sCoor = grid_desc.stride[4];
  int out_sN = output_desc.stride[0];
  int out_sC = output_desc.stride[1];
  int out_sD = output_desc.stride[2];
  int out_sH = output_desc.stride[3];
  int out_sW = output_desc.stride[4];

  CUDA_1D_KERNEL_LOOP(index, nthreads) {
    const int w = index % out_W;
    const int h = (index / out_W) % out_H;
    const int d = (index / (out_H * out_W)) % out_D;
    const int n = index / (out_D * out_H * out_W);
    const int grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;

    // get the corresponding input x, y, z coordinates from grid
    scalar_t ix = grid[grid_offset];
    scalar_t iy = grid[grid_offset + grid_sCoor];
    scalar_t iz = grid[grid_offset + 2 * grid_sCoor];

    ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners);
    iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners);
    iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners);

    if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
      // get corner pixel values from (x, y, z)
      // for 4d, we used north-east-south-west
      // for 5d, we add top-bottom
      int ix_tnw = static_cast<int>(::floor(ix));
      int iy_tnw = static_cast<int>(::floor(iy));
      int iz_tnw = static_cast<int>(::floor(iz));

      int ix_tne = ix_tnw + 1;
      int iy_tne = iy_tnw;
      int iz_tne = iz_tnw;

      int ix_tsw = ix_tnw;
      int iy_tsw = iy_tnw + 1;
      int iz_tsw = iz_tnw;

      int ix_tse = ix_tnw + 1;
      int iy_tse = iy_tnw + 1;
      int iz_tse = iz_tnw;

      int ix_bnw = ix_tnw;
      int iy_bnw = iy_tnw;
      int iz_bnw = iz_tnw + 1;

      int ix_bne = ix_tnw + 1;
      int iy_bne = iy_tnw;
      int iz_bne = iz_tnw + 1;

      int ix_bsw = ix_tnw;
      int iy_bsw = iy_tnw + 1;
      int iz_bsw = iz_tnw + 1;

      int ix_bse = ix_tnw + 1;
      int iy_bse = iy_tnw + 1;
      int iz_bse = iz_tnw + 1;

      // get surfaces to each neighbor:
      scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
      scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
      scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
      scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
      scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
      scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
      scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
      scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);

      auto inp_ptr_NC = input + n * inp_sN;
      auto out_ptr_NCDHW = output + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
      for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
        // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
        // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
        // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
        // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
        *out_ptr_NCDHW = static_cast<scalar_t>(0);
        if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
        }
        if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
        }
        if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
        }
        if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
        }
        if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
        }
        if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
        }
        if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
        }
        if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
        }
      }
    } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
      int ix_nearest = static_cast<int>(::round(ix));
      int iy_nearest = static_cast<int>(::round(iy));
      int iz_nearest = static_cast<int>(::round(iz));

      // assign nearest neighbor pixel value to output pixel
      auto inp_ptr_NC = input + n * inp_sN;
      auto out_ptr_NCDHW = output + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
      for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
        if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) {
          *out_ptr_NCDHW =
              inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW];
        } else {
          *out_ptr_NCDHW = static_cast<scalar_t>(0);
        }
      }
    }
  }
}
void create_desc(const int *dims, int nb_dims, TensorDesc &desc) {
  // build contiguous (row-major) strides from the shape
  memcpy(&desc.shape[0], dims, sizeof(int) * nb_dims);
  desc.stride[nb_dims - 1] = 1;
  for (int i = nb_dims - 2; i >= 0; --i) {
    desc.stride[i] = desc.stride[i + 1] * desc.shape[i + 1];
  }
}

template <typename T>
void grid_sample(T *output, const T *input, const T *grid, int *output_dims, int *input_dims,
                 int *grid_dims, int nb_dims, GridSamplerInterpolation interp,
                 GridSamplerPadding padding, bool align_corners, cudaStream_t stream) {
  TensorDesc input_desc;
  create_desc(input_dims, nb_dims, input_desc);

  TensorDesc output_desc;
  create_desc(output_dims, nb_dims, output_desc);

  TensorDesc grid_desc;
  create_desc(grid_dims, nb_dims, grid_desc);

  // one thread per output location, over every dim except the channel dim
  int count = 1;
  for (int i = 0; i < nb_dims; ++i) {
    if (i == 1) {
      continue;
    }
    count *= output_desc.shape[i];
  }

  if (nb_dims == 4) {
    grid_sampler_2d_kernel<T><<<GET_BLOCKS(count), THREADS_PER_BLOCK, 0, stream>>>(
        count, input, grid, output, input_desc, grid_desc, output_desc, interp, padding,
        align_corners);
  } else if (nb_dims == 5) {
    grid_sampler_3d_kernel<T><<<GET_BLOCKS(count), THREADS_PER_BLOCK, 0, stream>>>(
        count, input, grid, output, input_desc, grid_desc, output_desc, interp, padding,
        align_corners);
  } else {
    printf("input and grid dims should be 4 or 5\n");
  }
}

template void grid_sample<float>(float *output, const float *input, const float *grid,
                                 int *output_dims, int *input_dims, int *grid_dims, int nb_dims,
                                 GridSamplerInterpolation interp, GridSamplerPadding padding,
                                 bool align_corners, cudaStream_t stream);
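The launcher can also be exercised outside TensorRT by calling `grid_sample<float>` directly on device buffers. A hedged sketch, assuming `trt_grid_sampler_kernel.hpp` (listed next) is on the include path; shapes are illustrative, buffers are left uninitialized, and only the launch status is checked:

```cpp
#include <cuda_runtime.h>

#include <cstdio>

#include "trt_grid_sampler_kernel.hpp"  // declares grid_sample and the enums

int main() {
  int input_dims[4] = {1, 8, 32, 32};   // N, C, H_in, W_in
  int grid_dims[4] = {1, 16, 16, 2};    // N, H_out, W_out, 2
  int output_dims[4] = {1, 8, 16, 16};  // N, C, H_out, W_out

  float *input, *grid, *output;
  cudaMalloc(&input, sizeof(float) * 1 * 8 * 32 * 32);
  cudaMalloc(&grid, sizeof(float) * 1 * 16 * 16 * 2);
  cudaMalloc(&output, sizeof(float) * 1 * 8 * 16 * 16);

  grid_sample<float>(output, input, grid, output_dims, input_dims, grid_dims, /*nb_dims=*/4,
                     GridSamplerInterpolation::Bilinear, GridSamplerPadding::Zeros,
                     /*align_corners=*/false, /*stream=*/0);
  cudaDeviceSynchronize();
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));

  cudaFree(input);
  cudaFree(grid);
  cudaFree(output);
  return 0;
}
```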
--------------------------------------------------------------------------------
/plugin/grid_sampler/trt_grid_sampler_kernel.hpp:
--------------------------------------------------------------------------------
// Copyright (c) OpenMMLab. All rights reserved.
#ifndef TRT_GRID_SAMPLER_KERNEL_HPP
#define TRT_GRID_SAMPLER_KERNEL_HPP
#include <cuda_runtime.h>

enum class GridSamplerInterpolation { Bilinear, Nearest };
enum class GridSamplerPadding { Zeros, Border, Reflection };

template <typename T>
void grid_sample(T *output, const T *input, const T *grid, int *output_dims, int *input_dims,
                 int *grid_dims, int nb_dims, GridSamplerInterpolation interp,
                 GridSamplerPadding padding, bool align_corners, cudaStream_t stream);
#endif  // TRT_GRID_SAMPLER_KERNEL_HPP
--------------------------------------------------------------------------------
/test_res/final.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/Co-DETR-TensorRT/d79cf4fe508c47cca86594b52cb1088968fe1cd3/test_res/final.jpg
--------------------------------------------------------------------------------
/test_res/final_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/Co-DETR-TensorRT/d79cf4fe508c47cca86594b52cb1088968fe1cd3/test_res/final_0.jpg
--------------------------------------------------------------------------------
/test_res/static/cmd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/Co-DETR-TensorRT/d79cf4fe508c47cca86594b52cb1088968fe1cd3/test_res/static/cmd.png
--------------------------------------------------------------------------------
/test_res/static/v1.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/Co-DETR-TensorRT/d79cf4fe508c47cca86594b52cb1088968fe1cd3/test_res/static/v1.PNG
--------------------------------------------------------------------------------
/test_res/static/v2.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DataXujing/Co-DETR-TensorRT/d79cf4fe508c47cca86594b52cb1088968fe1cd3/test_res/static/v2.PNG
--------------------------------------------------------------------------------