├── .gitignore ├── .vscode ├── c_cpp_properties.json ├── launch.json ├── settings.json └── tasks.json ├── CMakeLists.txt ├── Makefile ├── README.md ├── onnx ├── make_pb.sh ├── onnx-ml.proto └── onnx-operators-ml.proto ├── onnx_parser ├── change.log.md ├── onnx_parser_7.x │ ├── ImporterContext.hpp │ ├── LoopHelpers.cpp │ ├── LoopHelpers.hpp │ ├── ModelImporter.cpp │ ├── ModelImporter.hpp │ ├── NvOnnxParser.cpp │ ├── NvOnnxParser.h │ ├── OnnxAttrs.cpp │ ├── OnnxAttrs.hpp │ ├── RNNHelpers.cpp │ ├── RNNHelpers.hpp │ ├── ShapeTensor.cpp │ ├── ShapeTensor.hpp │ ├── ShapedWeights.cpp │ ├── ShapedWeights.hpp │ ├── Status.hpp │ ├── TensorOrWeights.hpp │ ├── builtin_op_importers.cpp │ ├── builtin_op_importers.hpp │ ├── common.hpp │ ├── onnx2trt.hpp │ ├── onnx2trt_common.hpp │ ├── onnx2trt_runtime.hpp │ ├── onnx2trt_utils.cpp │ ├── onnx2trt_utils.hpp │ ├── onnx_utils.hpp │ ├── toposort.hpp │ ├── trt_utils.hpp │ └── utils.hpp ├── onnx_parser_8.x │ ├── ImporterContext.hpp │ ├── LoopHelpers.cpp │ ├── LoopHelpers.hpp │ ├── ModelImporter.cpp │ ├── ModelImporter.hpp │ ├── NvOnnxParser.cpp │ ├── NvOnnxParser.h │ ├── OnnxAttrs.cpp │ ├── OnnxAttrs.hpp │ ├── RNNHelpers.cpp │ ├── RNNHelpers.hpp │ ├── ShapeTensor.cpp │ ├── ShapeTensor.hpp │ ├── ShapedWeights.cpp │ ├── ShapedWeights.hpp │ ├── Status.hpp │ ├── TensorOrWeights.hpp │ ├── builtin_op_importers.cpp │ ├── builtin_op_importers.hpp │ ├── onnx2trt.hpp │ ├── onnx2trt_common.hpp │ ├── onnx2trt_runtime.hpp │ ├── onnx2trt_utils.cpp │ ├── onnx2trt_utils.hpp │ ├── onnxErrorRecorder.cpp │ ├── onnxErrorRecorder.hpp │ ├── onnx_utils.hpp │ ├── readme.md │ ├── toposort.hpp │ ├── trt_utils.hpp │ └── utils.hpp ├── readme.md ├── use_tensorrt_7.x.sh └── use_tensorrt_8.x.sh ├── src ├── application │ ├── app_demuxer.cpp │ ├── app_hard_decode.cpp │ ├── app_yolo.cpp │ ├── app_yolo │ │ ├── yolo.cpp │ │ ├── yolo.hpp │ │ └── yolo_decode.cu │ ├── common │ │ └── object_detector.hpp │ └── tools │ │ ├── auto_download.cpp │ │ ├── zmq_remote_show.cpp │ │ ├── zmq_remote_show.hpp │ │ ├── zmq_u.cpp │ │ └── zmq_u.hpp ├── ffhdd │ ├── cuvid_decoder.cpp │ ├── cuvid_decoder.hpp │ ├── ffmpeg_demuxer.cpp │ ├── ffmpeg_demuxer.hpp │ └── nalu.hpp ├── main.cpp └── tensorRT │ ├── builder │ ├── trt_builder.cpp │ └── trt_builder.hpp │ ├── common │ ├── cuda_tools.cpp │ ├── cuda_tools.hpp │ ├── ilogger.cpp │ ├── ilogger.hpp │ ├── infer_controller.hpp │ ├── json.cpp │ ├── json.hpp │ ├── monopoly_allocator.hpp │ ├── preprocess_kernel.cu │ ├── preprocess_kernel.cuh │ ├── trt_tensor.cpp │ └── trt_tensor.hpp │ ├── import_lib.cpp │ ├── infer │ ├── trt_infer.cpp │ └── trt_infer.hpp │ ├── onnx │ ├── onnx-ml.pb.cpp │ ├── onnx-ml.pb.h │ ├── onnx-operators-ml.pb.cpp │ ├── onnx-operators-ml.pb.h │ ├── onnx_pb.h │ ├── onnxifi.h │ └── readme.md │ ├── onnx_parser │ ├── ImporterContext.hpp │ ├── LoopHelpers.cpp │ ├── LoopHelpers.hpp │ ├── ModelImporter.cpp │ ├── ModelImporter.hpp │ ├── NvOnnxParser.cpp │ ├── NvOnnxParser.h │ ├── OnnxAttrs.cpp │ ├── OnnxAttrs.hpp │ ├── RNNHelpers.cpp │ ├── RNNHelpers.hpp │ ├── ShapeTensor.cpp │ ├── ShapeTensor.hpp │ ├── ShapedWeights.cpp │ ├── ShapedWeights.hpp │ ├── Status.hpp │ ├── TensorOrWeights.hpp │ ├── builtin_op_importers.cpp │ ├── builtin_op_importers.hpp │ ├── onnx2trt.hpp │ ├── onnx2trt_common.hpp │ ├── onnx2trt_runtime.hpp │ ├── onnx2trt_utils.cpp │ ├── onnx2trt_utils.hpp │ ├── onnxErrorRecorder.cpp │ ├── onnxErrorRecorder.hpp │ ├── onnx_utils.hpp │ ├── readme.md │ ├── toposort.hpp │ ├── trt_utils.hpp │ └── utils.hpp │ └── onnxplugin │ ├── onnxplugin.cpp │ ├── 
onnxplugin.hpp │ ├── plugin_binary_io.cpp │ ├── plugin_binary_io.hpp │ └── plugins │ ├── DCNv2.cu │ ├── HSigmoid.cu │ └── HSwish.cu └── workspace └── exp ├── fall_video.mp4 └── number100.mp4 /.gitignore: -------------------------------------------------------------------------------- 1 | # 2 | tutorial/2.0CenterNet_from_torch_trt/0_to_1_python_to_cuda/cpp_cuda_centernet/src/tensorRT 3 | tutorial/2.0CenterNet_from_torch_trt/0_to_1_python_to_cuda/cpp_cuda_centernet/objs 4 | tutorial/2.0CenterNet_from_torch_trt/0_to_1_python_to_cuda/cpp_cuda_centernet/workspace/pro 5 | 6 | # compressed files 7 | *.tar.gz 8 | *.zip 9 | 10 | # temp tensor and data 11 | *.tensor 12 | *.data 13 | 14 | 15 | # Prerequisites 16 | *.d 17 | 18 | # Compiled Object files 19 | *.slo 20 | *.lo 21 | *.o 22 | *.obj 23 | 24 | # Precompiled Headers 25 | *.gch 26 | *.pch 27 | 28 | # Compiled Dynamic libraries 29 | *.so 30 | *.dylib 31 | *.dll 32 | 33 | # Fortran module files 34 | *.mod 35 | *.smod 36 | 37 | # Compiled Static libraries 38 | *.lai 39 | *.la 40 | *.a 41 | *.lib 42 | 43 | # Executables 44 | *.exe 45 | *.out 46 | *.app 47 | 48 | /objs 49 | 50 | *.trtmodel 51 | *.onnx 52 | /workspace/pro 53 | /build 54 | /workspace/*.avi 55 | /workspace/.ipynb_checkpoints 56 | /workspace/*_result 57 | /workspace/face/library_draw 58 | /workspace/face/result 59 | /workspace/face/library/laq.jpg 60 | __pycache__ 61 | /tools/process_so.sh 62 | /tools/proc2.sh 63 | /python/trtpy.egg-info 64 | /python/dist 65 | /python/build 66 | /workspace/formtest.ipynb 67 | /workspace/meta.json 68 | /.vs 69 | *.pyd 70 | *.zip 71 | *.pdb 72 | *.ilk 73 | *.lib 74 | *.exp 75 | 76 | /lean/cuda10.1 77 | /lean/cudnn8.2.2.26 78 | /lean/opencv3.4.6 79 | /lean/protobuf3.11.4 80 | /lean/TensorRT-8.0.1.6 81 | 82 | __pycache__ 83 | 84 | !/workspace/wget.exe 85 | /workspace/*.mp4 86 | /workspace/single_inference 87 | /workspace/exp/tracker.final.mp4 88 | /workspace/perf.result.log 89 | /simple_yolo/workspace/pro 90 | /simple_yolo/objs 91 | /workspace/*.json 92 | /workspace/*.png 93 | /workspace/imgs 94 | /workspace/hard 95 | /workspace/soft 96 | -------------------------------------------------------------------------------- /.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Linux", 5 | "includePath": [ 6 | "${workspaceFolder}/src/**", 7 | "/data/sxai/lean/protobuf3.11.4/include/**", 8 | "/data/sxai/lean/opencv4.2.0/include/opencv4/**", 9 | "/data/sxai/lean/cuda-10.2/include/**", 10 | "/data/sxai/lean/TensorRT-8.0.1.6-cuda10.2-cudnn8.2/include/**", 11 | "/data/sxai/lean/cudnn7.6.5.32-cuda10.2/include/**", 12 | "/data/datav/newbb/lean/anaconda3/envs/torch1.8/include/python3.9/**", 13 | "/data/sxai/lean/ffmpeg4.2/include/**", 14 | "/data/sxai/lean/Video_Codec_SDK_10.0.26/Interface/**" 15 | ], 16 | "defines": ["__CUDACC__", "HAS_PYTHON"], 17 | "compilerPath": "/usr/bin/gcc", 18 | "cStandard": "gnu11", 19 | "cppStandard": "gnu++11", 20 | "intelliSenseMode": "linux-gcc-x64", 21 | "configurationProvider": "ms-vscode.makefile-tools", 22 | "browse": { 23 | "path": [ 24 | "${workspaceFolder}/src/**", 25 | "/data/sxai/lean/protobuf3.11.4/include/**", 26 | "/data/sxai/lean/opencv4.2.0/include/opencv4/**", 27 | "/data/sxai/lean/cuda-10.2/include/**", 28 | "/data/sxai/lean/TensorRT-8.0.1.6-cuda10.2-cudnn8.2/include/**", 29 | "/data/sxai/lean/cudnn7.6.5.32-cuda10.2/include/**", 30 | "/data/datav/newbb/lean/anaconda3/envs/torch1.8/include/python3.9/**", 31 | 
"/data/sxai/lean/ffmpeg4.2/include/**", 32 | "/data/sxai/lean/Video_Codec_SDK_10.0.26/Interface/**" 33 | ], 34 | "limitSymbolsToIncludedHeaders": false, 35 | "databaseFilename": "" 36 | } 37 | } 38 | ], 39 | "version": 4 40 | } -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // 使用 IntelliSense 了解相关属性。 3 | // 悬停以查看现有属性的描述。 4 | // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "(gdb) 启动", 9 | "type": "cppdbg", 10 | "request": "launch", 11 | "program": "${workspaceFolder}/workspace/pro", 12 | "args": ["demuxer"], 13 | "stopAtEntry": false, 14 | "cwd": "${workspaceFolder}/workspace", 15 | "environment": [], 16 | "externalConsole": false, 17 | "MIMode": "gdb", 18 | "setupCommands": [ 19 | { 20 | "description": "为 gdb 启用整齐打印", 21 | "text": "-enable-pretty-printing", 22 | "ignoreFailures": true 23 | } 24 | ], 25 | "preLaunchTask": "build" 26 | } 27 | ] 28 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.snippetSuggestions": "bottom", 3 | "editor.suggest.snippetsPreventQuickSuggestions": false, 4 | "python.languageServer": "Jedi", 5 | "files.associations": { 6 | "*.cpp": "cpp", 7 | "*.cu": "cpp", 8 | "*.cuh": "cpp", 9 | "unordered_map": "cpp", 10 | "chrono": "cpp", 11 | "thread": "cpp", 12 | "future": "cpp", 13 | "type_traits": "cpp", 14 | "tuple": "cpp", 15 | "__functional_03": "cpp", 16 | "functional": "cpp", 17 | "__nullptr": "cpp", 18 | "cstddef": "cpp", 19 | "exception": "cpp", 20 | "initializer_list": "cpp", 21 | "new": "cpp", 22 | "optional": "cpp", 23 | "typeinfo": "cpp", 24 | "array": "cpp", 25 | "atomic": "cpp", 26 | "*.tcc": "cpp", 27 | "bitset": "cpp", 28 | "cctype": "cpp", 29 | "clocale": "cpp", 30 | "cmath": "cpp", 31 | "codecvt": "cpp", 32 | "complex": "cpp", 33 | "condition_variable": "cpp", 34 | "cstdarg": "cpp", 35 | "cstdint": "cpp", 36 | "cstdio": "cpp", 37 | "cstdlib": "cpp", 38 | "cstring": "cpp", 39 | "ctime": "cpp", 40 | "cwchar": "cpp", 41 | "cwctype": "cpp", 42 | "deque": "cpp", 43 | "list": "cpp", 44 | "unordered_set": "cpp", 45 | "vector": "cpp", 46 | "algorithm": "cpp", 47 | "filesystem": "cpp", 48 | "ratio": "cpp", 49 | "string_view": "cpp", 50 | "system_error": "cpp", 51 | "fstream": "cpp", 52 | "iomanip": "cpp", 53 | "iosfwd": "cpp", 54 | "iostream": "cpp", 55 | "istream": "cpp", 56 | "limits": "cpp", 57 | "memory": "cpp", 58 | "mutex": "cpp", 59 | "ostream": "cpp", 60 | "numeric": "cpp", 61 | "sstream": "cpp", 62 | "stdexcept": "cpp", 63 | "streambuf": "cpp", 64 | "cinttypes": "cpp", 65 | "utility": "cpp", 66 | "__config": "cpp", 67 | "variant": "cpp", 68 | "forward_list": "cpp", 69 | "typeindex": "cpp", 70 | "valarray": "cpp", 71 | "__bit_reference": "cpp", 72 | "__hash_table": "cpp", 73 | "__split_buffer": "cpp", 74 | "__tree": "cpp", 75 | "iterator": "cpp", 76 | "locale": "cpp", 77 | "map": "cpp", 78 | "set": "cpp", 79 | "string": "cpp", 80 | "__locale": "cpp", 81 | "ios": "cpp", 82 | "queue": "cpp", 83 | "random": "cpp", 84 | "stack": "cpp", 85 | "__atomic": "cpp", 86 | "__debug": "cpp", 87 | "__node_handle": "cpp", 88 | "__mutex_base": "cpp", 89 | "__functional_base": "cpp", 90 | "__memory": "cpp", 91 | "__atomic_generated": "cpp", 92 | "__functional_base_03": "cpp", 93 | "__tuple": 
"cpp", 94 | "*.inc": "cpp" 95 | } 96 | } -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | // See https://go.microsoft.com/fwlink/?LinkId=733558 3 | // for the documentation about the tasks.json format 4 | "version": "2.0.0", 5 | "tasks": [ 6 | { 7 | "label": "build", 8 | "type": "shell", 9 | "command": "make pro -j64" 10 | 11 | // for cmake 12 | //"command": "cd build && make pro -j25" 13 | } 14 | ] 15 | } -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.6) 2 | project(pro) 3 | 4 | option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) 5 | set(CMAKE_CXX_STANDARD 11) 6 | set(CMAKE_BUILD_TYPE Debug) 7 | set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/workspace) 8 | 9 | # 如果你是不同显卡,请设置为显卡对应的号码参考这里:https://developer.nvidia.com/zh-cn/cuda-gpus#compute 10 | set(CUDA_GEN_CODE "-gencode=arch=compute_75,code=sm_75") 11 | 12 | # 如果你的opencv找不到,可以自己指定目录 13 | set(OpenCV_DIR "/data/datav/expstation/lean/opencv4.2.0/lib/cmake/opencv4/") 14 | 15 | set(CUDA_DIR "/data/sxai/lean/cuda-10.2") 16 | set(CUDNN_DIR "/data/sxai/lean/cudnn8.2.2.26") 17 | set(TENSORRT_DIR "/data/sxai/lean/TensorRT-8.0.1.6-cuda10.2-cudnn8.2") 18 | set(NVDEC_DIR "/data/sxai/lean/Video_Codec_SDK_10.0.26") 19 | set(FFMPEG_DIR "/data/sxai/lean/ffmpeg4.2") 20 | 21 | # set(CUDA_DIR "/data/sxai/lean/cuda-10.2") 22 | # set(CUDNN_DIR "/data/sxai/lean/cudnn7.6.5.32-cuda10.2") 23 | # set(TENSORRT_DIR "/data/sxai/lean/TensorRT-7.0.0.11") 24 | 25 | # set(CUDA_DIR "/data/sxai/lean/cuda-11.1") 26 | # set(CUDNN_DIR "/data/sxai/lean/cudnn8.2.2.26") 27 | # set(TENSORRT_DIR "/data/sxai/lean/TensorRT-7.2.1.6") 28 | 29 | # 因为protobuf,需要用特定版本,所以这里指定路径 30 | set(PROTOBUF_DIR "/data/sxai/lean/protobuf3.11.4") 31 | 32 | 33 | find_package(CUDA REQUIRED) 34 | find_package(OpenCV) 35 | 36 | include_directories( 37 | ${PROJECT_SOURCE_DIR}/src 38 | ${PROJECT_SOURCE_DIR}/src/application 39 | ${PROJECT_SOURCE_DIR}/src/tensorRT 40 | ${PROJECT_SOURCE_DIR}/src/tensorRT/common 41 | ${OpenCV_INCLUDE_DIRS} 42 | ${CUDA_DIR}/include 43 | ${PROTOBUF_DIR}/include 44 | ${TENSORRT_DIR}/include 45 | ${CUDNN_DIR}/include 46 | ${NVDEC_DIR}/Interface 47 | ${FFMPEG_DIR}/include 48 | ) 49 | 50 | # 切记,protobuf的lib目录一定要比tensorRT目录前面,因为tensorRTlib下带有protobuf的so文件 51 | # 这可能带来错误 52 | link_directories( 53 | ${PROTOBUF_DIR}/lib 54 | ${TENSORRT_DIR}/lib 55 | ${CUDA_DIR}/lib64 56 | ${CUDNN_DIR}/lib 57 | ${NVDEC_DIR}/Lib/linux/stubs/x86_64 58 | ${FFMPEG_DIR}/lib 59 | ) 60 | 61 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -O0 -Wfatal-errors -pthread -w -g") 62 | set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++11 -O0 -Xcompiler -fPIC -g -w ${CUDA_GEN_CODE}") 63 | file(GLOB_RECURSE cpp_srcs ${PROJECT_SOURCE_DIR}/src/*.cpp) 64 | file(GLOB_RECURSE cuda_srcs ${PROJECT_SOURCE_DIR}/src/*.cu) 65 | cuda_add_library(plugin_list SHARED ${cuda_srcs}) 66 | 67 | add_executable(pro ${cpp_srcs}) 68 | 69 | # 如果提示插件找不到,请使用dlopen(xxx.so, NOW)的方式手动加载可以解决插件找不到问题 70 | target_link_libraries(pro nvinfer nvinfer_plugin) 71 | target_link_libraries(pro cuda cublas cudart cudnn) 72 | target_link_libraries(pro nvcuvid nvidia-encode) 73 | target_link_libraries(pro protobuf pthread plugin_list) 74 | target_link_libraries(pro avcodec avformat avresample swscale avutil) 75 | target_link_libraries(pro ${OpenCV_LIBS}) 76 | 
77 | add_custom_target(
78 | yolo
79 | DEPENDS pro
80 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/workspace
81 | COMMAND ./pro yolo
82 | )
83 |
84 | add_custom_target(
85 | demuxer
86 | DEPENDS pro
87 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/workspace
88 | COMMAND ./pro demuxer
89 | )
90 |
91 | add_custom_target(
92 | hard_decode
93 | DEPENDS pro
94 | WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}/workspace
95 | COMMAND ./pro hard_decode
96 | )
97 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 |
2 | cpp_srcs := $(shell find src -name "*.cpp")
3 | cpp_objs := $(cpp_srcs:.cpp=.o)
4 | cpp_objs := $(cpp_objs:src/%=objs/%)
5 | cpp_mk   := $(cpp_objs:.o=.mk)
6 |
7 | cu_srcs := $(shell find src -name "*.cu")
8 | cu_objs := $(cu_srcs:.cu=.cuo)
9 | cu_objs := $(cu_objs:src/%=objs/%)
10 | cu_mk   := $(cu_objs:.cuo=.cumk)
11 |
12 | lean_protobuf  := /data/sxai/lean/protobuf3.11.4
13 | lean_tensor_rt := /data/sxai/lean/TensorRT-8.0.1.6-cuda10.2-cudnn8.2
14 | lean_cudnn     := /data/sxai/lean/cudnn8.2.2.26
15 | lean_opencv    := /data/sxai/lean/opencv4.2.0
16 | lean_cuda      := /data/sxai/lean/cuda-10.2
17 | lean_ffmpeg    := /data/sxai/lean/ffmpeg4.2
18 | lean_nvdec     := /data/sxai/lean/Video_Codec_SDK_10.0.26
19 |
20 | include_paths := src \
21 | src/application \
22 | src/tensorRT \
23 | src/tensorRT/common \
24 | $(lean_protobuf)/include \
25 | $(lean_opencv)/include/opencv4 \
26 | $(lean_tensor_rt)/include \
27 | $(lean_cuda)/include \
28 | $(lean_cudnn)/include \
29 | $(lean_ffmpeg)/include \
30 | $(lean_nvdec)/Interface
31 |
32 | library_paths := $(lean_protobuf)/lib \
33 | $(lean_opencv)/lib \
34 | $(lean_tensor_rt)/lib \
35 | $(lean_cuda)/lib64 \
36 | $(lean_cudnn)/lib \
37 | $(lean_ffmpeg)/lib \
38 | $(lean_nvdec)/Lib/linux/stubs/x86_64
39 |
40 | link_librarys := opencv_core opencv_imgproc opencv_videoio opencv_imgcodecs \
41 | nvinfer nvinfer_plugin \
42 | cuda cublas cudart cudnn \
43 | nvcuvid nvidia-encode \
44 | avcodec avformat avresample swscale avutil \
45 | stdc++ protobuf dl
46 |
47 | paths         := $(foreach item,$(library_paths),-Wl,-rpath=$(item))
48 | include_paths := $(foreach item,$(include_paths),-I$(item))
49 | library_paths := $(foreach item,$(library_paths),-L$(item))
50 | link_librarys := $(foreach item,$(link_librarys),-l$(item))
51 |
52 | # For a different GPU, change -gencode=arch=compute_75,code=sm_75 to match your card's compute capability
53 | # See https://developer.nvidia.com/zh-cn/cuda-gpus#compute for the number matching your card
54 | # On Jetson Nano, if -m64 is reported as an unknown option, simply delete the -m64 option. It does not affect the result
55 | cpp_compile_flags := -std=c++11 -fPIC -m64 -g -fopenmp -w -O0
56 | cu_compile_flags  := -std=c++11 -m64 -Xcompiler -fPIC -g -w -gencode=arch=compute_75,code=sm_75 -O0
57 | link_flags        := -pthread -fopenmp -Wl,-rpath='$$ORIGIN'
58 |
59 | cpp_compile_flags += $(include_paths)
60 | cu_compile_flags  += $(include_paths)
61 | link_flags        += $(library_paths) $(link_librarys) $(paths)
62 |
63 | ifneq ($(MAKECMDGOALS), clean)
64 | -include $(cpp_mk) $(cu_mk)
65 | endif
66 |
67 | pro    : workspace/pro
68 | trtpyc : python/trtpy/libtrtpyc.so
69 |
70 | workspace/pro : $(cpp_objs) $(cu_objs)
71 | 	@echo Link $@
72 | 	@mkdir -p $(dir $@)
73 | 	@g++ $^ -o $@ $(link_flags)
74 |
75 | python/trtpy/libtrtpyc.so : $(cpp_objs) $(cu_objs)
76 | 	@echo Link $@
77 | 	@mkdir -p $(dir $@)
78 | 	@g++ -shared $^ -o $@ $(link_flags)
79 |
80 | objs/%.o : src/%.cpp
81 | 	@echo Compile CXX $<
82 | 	@mkdir -p $(dir $@)
83 | 	@g++ -c $< -o $@ $(cpp_compile_flags)
84 |
85 | objs/%.cuo : src/%.cu
86 | 	@echo
Compile CUDA $<
87 | 	@mkdir -p $(dir $@)
88 | 	@nvcc -c $< -o $@ $(cu_compile_flags)
89 |
90 | objs/%.mk : src/%.cpp
91 | 	@echo Compile depends CXX $<
92 | 	@mkdir -p $(dir $@)
93 | 	@g++ -M $< -MF $@ -MT $(@:.mk=.o) $(cpp_compile_flags)
94 |
95 | objs/%.cumk : src/%.cu
96 | 	@echo Compile depends CUDA $<
97 | 	@mkdir -p $(dir $@)
98 | 	@nvcc -M $< -MF $@ -MT $(@:.cumk=.o) $(cu_compile_flags)
99 |
100 | demuxer : workspace/pro
101 | 	@cd workspace && ./pro demuxer
102 |
103 | hard_decode : workspace/pro
104 | 	@cd workspace && ./pro hard_decode
105 |
106 | yolo : workspace/pro
107 | 	@cd workspace && ./pro yolo
108 |
109 | debug :
110 | 	@echo $(include_paths)
111 |
112 | clean :
113 | 	@rm -rf objs workspace/pro python/trtpy/libtrtpyc.so python/build python/dist python/trtpy.egg-info python/trtpy/__pycache__
114 | 	@rm -rf workspace/single_inference
115 | 	@rm -rf workspace/scrfd_result workspace/retinaface_result
116 | 	@rm -rf workspace/YoloV5_result workspace/YoloX_result
117 | 	@rm -rf workspace/face/library_draw workspace/face/result
118 | 	@rm -rf build
119 | 	@rm -rf python/trtpy/libplugin_list.so
120 |
121 | .PHONY : clean yolo demuxer hard_decode debug
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Hardware Decoding with TensorRT
2 | - Set up the same environment as for tensorRT
3 | - Add the NVDEC and ffmpeg configuration
4 | - `make yolo -j64`
5 |     - YOLO wired directly to hardware decoding
6 | - `make demuxer -j64`
7 |     - Only demuxes to obtain h264 packets and analyzes what kind of frame each one is
8 | - `make hard_decode -j64`
9 |     - Hardware decoding test
10 |     - Software and hardware decoding consume CPU and GPU resources respectively; the difference is pronounced with many streams at high resolutions
11 |     - Hardware decoding and inference are allowed to run across different GPUs
12 |     - It only delivers its full benefit when it is understood and used well
--------------------------------------------------------------------------------
/onnx/make_pb.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Change protoc to the protoc of the version you want to use
4 | protoc=/data/sxai/lean/protobuf3.11.4/bin/protoc
5 | #protoc=/data/sxai/temp/protobuf-build3.18.x/bin/protoc
6 |
7 | echo Create directory "pbout"
8 | rm -rf pbout
9 | mkdir -p pbout
10 |
11 | $protoc onnx-ml.proto --cpp_out=pbout
12 | $protoc onnx-operators-ml.proto --cpp_out=pbout
13 |
14 | echo Copy pbout/onnx-ml.pb.cc to ../src/tensorRT/onnx/onnx-ml.pb.cpp
15 | cp pbout/onnx-ml.pb.cc ../src/tensorRT/onnx/onnx-ml.pb.cpp
16 |
17 | echo Copy pbout/onnx-operators-ml.pb.cc to ../src/tensorRT/onnx/onnx-operators-ml.pb.cpp
18 | cp pbout/onnx-operators-ml.pb.cc ../src/tensorRT/onnx/onnx-operators-ml.pb.cpp
19 |
20 | echo Copy pbout/onnx-ml.pb.h to ../src/tensorRT/onnx/onnx-ml.pb.h
21 | cp pbout/onnx-ml.pb.h ../src/tensorRT/onnx/onnx-ml.pb.h
22 |
23 | echo Copy pbout/onnx-operators-ml.pb.h to ../src/tensorRT/onnx/onnx-operators-ml.pb.h
24 | cp pbout/onnx-operators-ml.pb.h ../src/tensorRT/onnx/onnx-operators-ml.pb.h
25 |
26 | echo Remove directory "pbout"
27 | rm -rf pbout
--------------------------------------------------------------------------------
/onnx/onnx-operators-ml.proto:
--------------------------------------------------------------------------------
1 | //
2 | // WARNING: This file is automatically generated!  Please edit onnx.in.proto.
3 | //
4 |
5 |
6 | // Copyright (c) ONNX Project Contributors.
7 | // Licensed under the MIT license.
8 |
9 | syntax = "proto2";
10 |
11 | package onnx;
12 | import "onnx-ml.proto";
13 |
14 | //
15 | // This file contains the proto definitions for OperatorSetProto and
16 | // OperatorProto. OperatorSetProtos are used to describe a versioned
17 | // set of operators that can be used by a ModelProto.
18 | //
19 | // Like ModelProto, OperatorSetProto is defined as a top-level file/wire
20 | // format, however their usage is different.
21 | //
22 | // ModelProto files are used to describe executable graphs that can be
23 | // executed directly by a framework, runtime, or engine.
24 | //
25 | // OperatorSetProto files are used to describe a set of operators that are
26 | // available in a given environment. The file TBD.TBD is the OperatorSetProto
27 | // that describes the ONNX standard operators.
28 | //
29 |
30 | // An OperatorProto represents the immutable specification of the signature
31 | // and semantics of an operator.
32 | //
33 | // Operators are declared as part of an OperatorSet, which also defines the
34 | // domain name for the set.
35 | //
36 | // Operators are uniquely identified by a three part identifier
37 | //   (domain, op_type, since_version)
38 | // where
39 | //   *domain* is the domain of an operator set that
40 | //   contains this operator specification.
41 | //
42 | //   *op_type* is the name of the operator as referenced by a
43 | //   NodeProto.op_type
44 | //
45 | //   *since_version* is the version of the operator set that
46 | //   this operator was initially declared in.
47 | //
48 | message OperatorProto {
49 |   // The name of the operator within a domain.
50 |   // This field MUST be present in this version of the IR.
51 |   optional string op_type = 1;
52 |
53 |   // The version of the operator set that first introduced this
54 |   // operator. This value MUST be the same value as the
55 |   // opset_version of the operator set that first published this operator.
56 |   // Subsequent versions of the operator set MUST NOT alter the signature
57 |   // or semantics of the operator once published as STABLE.
58 |   // This field MUST be present in this version of the IR.
59 |   optional int64 since_version = 2;
60 |
61 |   // This field indicates whether the syntax, semantics, or presence
62 |   // of this operator is in an experimental or stable stage. Once an
63 |   // operator is published as STABLE, its syntax and semantics MUST NOT
64 |   // change in subsequent versions of the operator set.
65 |   // When an operator is published as EXPERIMENTAL, the syntax and semantics
66 |   // of the operator MAY change across operator set versions.
67 |   // Operators "become" stable by deprecating the experimental version and
68 |   // introducing a new stable operator with the same op_type.
69 |   optional OperatorStatus status = 3;
70 |
71 |   // Eventually we will declare the signature of the operator here
72 |
73 |   // A human-readable documentation for this operator. Markdown is allowed.
74 |   optional string doc_string = 10;
75 | }
76 |
77 | // An OperatorSetProto represents an immutable set of immutable operator
78 | // specifications.
79 | //
80 | // The domain of the set (OperatorSetProto.domain) is a reverse-DNS name
81 | // that disambiguates operator sets defined by independent entities.
82 | //
83 | // The version of the set (opset_version) is a monotonically increasing
84 | // integer that indicates changes to the membership of the operator set.
85 | //
86 | //
87 | // Operator sets are uniquely identified by a two part identifier (domain, opset_version)
88 | //
89 | // Like ModelProto, OperatorSetProto is intended as a top-level file/wire format,
90 | // and thus has the standard format headers in addition to the operator set information.
91 | //
92 | message OperatorSetProto {
93 |   // All OperatorSetProtos start with a distinguished byte sequence to disambiguate
94 |   // protobuf files containing OperatorSets from other content.
95 |   // This field MUST be "ONNXOPSET"
96 |   // This field MUST be present in this version of the IR
97 |   optional string magic = 1;
98 |
99 |   // All OperatorSetProtos indicate the version of the IR syntax and semantics
100 |   // they adhere to. It is always IR_VERSION.
101 |   // This field MUST be present in this version of the IR
102 |   optional int64 ir_version = 2;
103 |
104 |   // The prerelease component of the SemVer of the IR.
105 |   // This field MAY be absent in this version of the IR
106 |   optional string ir_version_prerelease = 3;
107 |
108 |   // The build metadata component of the SemVer of the IR.
109 |   // This field MAY be absent in this version of the IR
110 |   optional string ir_build_metadata = 7;
111 |
112 |   // Domain name of the operator set, in reverse DNS form (e.g., com.acme.dnnops).
113 |   optional string domain = 4;
114 |
115 |   // The version of the set of operators. This is a simple int value
116 |   // that is monotonically increasing as new versions of the operator set
117 |   // are published. All operators in this set MUST have since_version
118 |   // <= opset_version.
119 |   optional int64 opset_version = 5;
120 |
121 |   // A human-readable documentation for this set of operators. Markdown is allowed.
122 |   optional string doc_string = 6;
123 |
124 |   // The operators specified by this operator set.
125 |   // The (name, version) MUST be unique across all OperatorProtos in operator
126 |   repeated OperatorProto operator = 8;
127 |
128 |   // The functions specified by this operator set.
129 |   // The (name, version) MUST be unique across all OperatorProtos/FunctionProtos in operator/functions
130 |   repeated FunctionProto functions = 9;
131 | }
132 |
133 |
134 | // For using protobuf-lite
135 | // option optimize_for = LITE_RUNTIME;
136 |
--------------------------------------------------------------------------------
/onnx_parser/change.log.md:
--------------------------------------------------------------------------------
1 | # Change log of modifications to the OnnxParser
2 | 1. builtin_op_importers.cpp:28
3 |     - Added register_layerhook_reshape, a hook function for Reshape layers
4 | 2. builtin_op_importers.cpp:3543
5 |     - Added the call to g_layerhook_func_reshape inside the Reshape node importer, so that the hook takes effect
6 | 3. builtin_op_importers.cpp:168
7 |     - Added support for Plugin nodes, forwarding them to user-defined plugins to implement a custom plugin registration mechanism
8 | 4. builtin_op_importers.cpp:4480
9 |     - Disabled code in Upsample and allowed the case where scales has 3 values
10 | 5. ModelImporter.cpp:243
11 |     - Added support for redefining input dimensions and implemented dynamic batch by fixing the batch dimension to -1
12 | 6. ModelImporter.cpp:750
13 |     - Added support for parsing ONNX data directly from memory
14 | 7. ModelImporter.hpp:29
15 |     - Added support for defining dimensions
16 | 8. ModelImporter.hpp:72
17 |     - Added support for parsing ONNX file data
18 | 9. NvOnnxParser.h:207
19 |     - Added support for the reshape hook, the interface API
20 | 10. NvOnnxParser.h:228
21 |     - Added the input_dims parameter, supporting redefinition of the input dimensions
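
A minimal usage sketch of the two extensions above. The hook callback signature is an assumption for illustration; the authoritative prototypes are at NvOnnxParser.h:207 and NvOnnxParser.h:228:

```cpp
// Hypothetical sketch only -- consult NvOnnxParser.h for the real API.
// Intercept every Reshape while parsing, e.g. to rewrite a static batch size.
register_layerhook_reshape([](const std::string& name, const std::vector<int64_t>& shape) {
    auto new_shape = shape;   // inspect or rewrite the Reshape target shape here
    return new_shape;
});

// Redefine the network input, fixing the batch dimension to -1 for dynamic batch.
std::vector<nvinfer1::Dims> input_dims{nvinfer1::Dims4{-1, 3, 640, 640}};
```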
--------------------------------------------------------------------------------
/onnx_parser/onnx_parser_7.x/LoopHelpers.cpp:
--------------------------------------------------------------------------------
1 | #include "LoopHelpers.hpp"
2 | #include "onnx2trt_utils.hpp"
3 |
4 | namespace onnx2trt
5 | {
6 |
7 | nvinfer1::ITensor* addLoopCounter(IImporterContext* ctx, nvinfer1::ILoop* loop, int32_t initial)
8 | {
9 |     nvinfer1::ITensor* initialTensor = addConstantScalar(ctx, initial, ::onnx::TensorProto::INT32, nvinfer1::Dims{1, 1})->getOutput(0);
10 |     nvinfer1::ITensor* one = addConstantScalar(ctx, 1, ::onnx::TensorProto::INT32, nvinfer1::Dims{1, 1})->getOutput(0);
11 |
12 |     auto counter = loop->addRecurrence(*initialTensor);
13 |     nvinfer1::ITensor* addOne = ctx->network()->addElementWise(*counter->getOutput(0), *one, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
14 |     counter->setInput(1, *addOne);
15 |     return counter->getOutput(0);
16 | }
17 |
18 | } // namespace onnx2trt
19 |
--------------------------------------------------------------------------------
/onnx_parser/onnx_parser_7.x/LoopHelpers.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <NvInfer.h>
4 |
5 | #include "ImporterContext.hpp"
6 |
7 | namespace onnx2trt
8 | {
9 |
10 | nvinfer1::ITensor* addLoopCounter(IImporterContext* ctx, nvinfer1::ILoop* loop, int32_t initial = 0);
11 |
12 | } // namespace onnx2trt
13 |
--------------------------------------------------------------------------------
/onnx_parser/onnx_parser_7.x/ModelImporter.hpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 |  *
4 |  * Permission is hereby granted, free of charge, to any person obtaining a
5 |  * copy of this software and associated documentation files (the "Software"),
6 |  * to deal in the Software without restriction, including without limitation
7 |  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 |  * and/or sell copies of the Software, and to permit persons to whom the
9 |  * Software is furnished to do so, subject to the following conditions:
10 |  *
11 |  * The above copyright notice and this permission notice shall be included in
12 |  * all copies or substantial portions of the Software.
13 |  *
14 |  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 |  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 |  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17 |  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 |  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19 |  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20 |  * DEALINGS IN THE SOFTWARE.
21 |  */
22 |
23 | #pragma once
24 |
25 | #include "ImporterContext.hpp"
26 | #include "NvInferPlugin.h"
27 | #include "NvOnnxParser.h"
28 | #include "builtin_op_importers.hpp"
29 | #include "onnx_utils.hpp"
30 | #include "utils.hpp"
31 |
32 | namespace onnx2trt
33 | {
34 |
35 | Status parseGraph(IImporterContext* ctx, const ::onnx::GraphProto& graph, bool deserializingINetwork = false, int* currentNode = nullptr);
36 |
37 | class ModelImporter : public nvonnxparser::IParser
38 | {
39 | protected:
40 |     string_map<NodeImporter> _op_importers;
41 |     virtual Status importModel(::onnx::ModelProto const& model, uint32_t weight_count,
42 |         onnxTensorDescriptorV1 const* weight_descriptors);
43 |
44 | private:
45 |     ImporterContext _importer_ctx;
46 |     RefitMap_t mRefitMap;
47 |     std::list<::onnx::ModelProto> _onnx_models; // Needed for ownership of weights
48 |     int _current_node;
49 |     std::vector<Status> _errors;
50 |     std::vector<nvinfer1::Dims> _input_dims;
51 |
52 | public:
53 |     ModelImporter(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger, const std::vector<nvinfer1::Dims>& input_dims)
54 |         : _op_importers(getBuiltinOpImporterMap())
55 |         , _importer_ctx(network, logger, &mRefitMap)
56 |         , _input_dims(input_dims)
57 |     {
58 |     }
59 |     bool parseWithWeightDescriptors(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
60 |         uint32_t weight_count, onnxTensorDescriptorV1 const* weight_descriptors) override;
61 |     bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size) override;
62 |     bool supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size,
63 |         SubGraphCollection_t& sub_graph_collection) override;
64 |
65 |     bool supportsOperator(const char* op_name) const override;
66 |     void destroy() override
67 |     {
68 |         delete this;
69 |     }
70 |     // virtual void registerOpImporter(std::string op,
71 |     //                                NodeImporter const &node_importer) override {
72 |     //  // Note: This allows existing importers to be replaced
73 |     //  _op_importers[op] = node_importer;
74 |     //}
75 |     // virtual Status const &setInput(const char *name,
76 |     //                               nvinfer1::ITensor *input) override;
77 |     // virtual Status const& setOutput(const char* name, nvinfer1::ITensor** output) override;
78 |     int getNbErrors() const override
79 |     {
80 |         return _errors.size();
81 |     }
82 |     nvonnxparser::IParserError const* getError(int index) const override
83 |     {
84 |         assert(0 <= index && index < (int) _errors.size());
85 |         return &_errors[index];
86 |     }
87 |     void clearErrors() override
88 |     {
89 |         _errors.clear();
90 |     }
91 |     virtual int getRefitMap(const char** weightNames, const char** layerNames, nvinfer1::WeightsRole* roles) override
92 |     {
93 |         int count = 0;
94 |         for (const auto& entry: mRefitMap)
95 |         {
96 |             if (weightNames != nullptr)
97 |             {
98 |                 weightNames[count] = entry.first.c_str();
99 |             }
100 |             if (layerNames != nullptr)
101 |             {
102 |                 layerNames[count] = entry.second.first.c_str();
103 |             }
104 |             if (roles != nullptr)
105 |             {
106 |                 roles[count] = entry.second.second;
107 |             }
108 |             ++count;
109 |         }
110 |         return mRefitMap.size();
111 |     }
112 |     //...LG: Move the implementation to .cpp
113 |     bool parseFromFile(const char* onnxModelFile, int verbosity) override;
114 |     bool parseFromData(const void* onnx_data, size_t size, int verbosity) override;
115 | };
116 |
117 | } // namespace onnx2trt
118 |
--------------------------------------------------------------------------------
/onnx_parser/onnx_parser_7.x/NvOnnxParser.cpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 |  *
4 |  * Permission is hereby granted, free of charge, to any person obtaining a
5 |  * copy of this software and associated documentation files (the "Software"),
6 |  * to deal in the Software without restriction, including without limitation
7 |  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 |  * and/or sell copies of the Software, and to permit persons to whom the
9 |  * Software is furnished to do so, subject to the following conditions:
10 |  *
11 |  * The above copyright notice and this permission notice shall be included in
12 |  * all copies or substantial portions of the Software.
13 |  *
14 |  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 |  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 |  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17 |  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 |  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19 |  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20 |  * DEALINGS IN THE SOFTWARE.
21 |  */
22 |
23 | #include "NvOnnxParser.h"
24 | #include "ModelImporter.hpp"
25 |
26 | extern "C" void* createNvOnnxParser_INTERNAL(void* network_, void* logger_, int version, const std::vector<nvinfer1::Dims>& input_dims)
27 | {
28 |     auto network = static_cast<nvinfer1::INetworkDefinition*>(network_);
29 |     auto logger = static_cast<nvinfer1::ILogger*>(logger_);
30 |     return new onnx2trt::ModelImporter(network, logger, input_dims);
31 | }
32 |
33 | extern "C" int getNvOnnxParserVersion()
34 | {
35 |     return NV_ONNX_PARSER_VERSION;
36 | }
37 |
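// A minimal usage sketch of the factory above (the shape and model file are
// placeholders; NV_ONNX_PARSER_VERSION and nvinfer1::Dims4 are real symbols):
//
//   std::vector<nvinfer1::Dims> dims{nvinfer1::Dims4{-1, 3, 640, 640}};
//   auto* parser = static_cast<nvonnxparser::IParser*>(
//       createNvOnnxParser_INTERNAL(network, &logger, NV_ONNX_PARSER_VERSION, dims));
//   parser->parseFromFile("model.onnx", /*verbosity=*/1);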
--------------------------------------------------------------------------------
/onnx_parser/onnx_parser_7.x/OnnxAttrs.hpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 |  *
4 |  * Permission is hereby granted, free of charge, to any person obtaining a
5 |  * copy of this software and associated documentation files (the "Software"),
6 |  * to deal in the Software without restriction, including without limitation
7 |  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 |  * and/or sell copies of the Software, and to permit persons to whom the
9 |  * Software is furnished to do so, subject to the following conditions:
10 |  *
11 |  * The above copyright notice and this permission notice shall be included in
12 |  * all copies or substantial portions of the Software.
13 |  *
14 |  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 |  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 |  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17 |  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 |  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19 |  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20 |  * DEALINGS IN THE SOFTWARE.
21 |  */
22 |
23 | #pragma once
24 |
25 | #include <onnx/onnx_pb.h>
26 | #include <stdexcept>
27 | #include <string>
28 | #include <unordered_map>
29 |
30 | #include "ImporterContext.hpp"
31 |
32 | class OnnxAttrs
33 | {
34 |     template <typename T>
35 |     using string_map = std::unordered_map<std::string, T>;
36 |     typedef string_map<::onnx::AttributeProto const*> AttrMap;
37 |     AttrMap _attrs;
38 |     onnx2trt::IImporterContext* mCtx;
39 |
40 | public:
41 |     explicit OnnxAttrs(::onnx::NodeProto const& onnx_node, onnx2trt::IImporterContext* ctx)
42 |         : mCtx{ctx}
43 |     {
44 |         for (auto const& attr : onnx_node.attribute())
45 |         {
46 |             _attrs.insert({attr.name(), &attr});
47 |         }
48 |     }
49 |
50 |     bool count(const std::string& key) const
51 |     {
52 |         return _attrs.count(key);
53 |     }
54 |
55 |     ::onnx::AttributeProto const* at(std::string key) const
56 |     {
57 |         if (!_attrs.count(key))
58 |         {
59 |             throw std::out_of_range("Attribute not found: " + key);
60 |         }
61 |         return _attrs.at(key);
62 |     }
63 |
64 |     const ::onnx::AttributeProto::AttributeType type(const std::string& key) const
65 |     {
66 |         return this->at(key)->type();
67 |     }
68 |
69 |
70 |     template <typename T>
71 |     T get(const std::string& key) const;
72 |
73 |     template <typename T>
74 |     T get(const std::string& key, T const& default_value) const
75 |     {
76 |         return _attrs.count(key) ? this->get<T>(key) : default_value;
77 |     }
78 | };
79 |
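// A minimal usage sketch of the accessors above (`node` and `ctx` are
// hypothetical values supplied by the importer):
//
//   OnnxAttrs attrs(node, ctx);
//   int axis = attrs.get<int>("axis", 0);   // falls back to 0 when the attribute is absent
//   auto t   = attrs.type("perm");          // throws std::out_of_range if "perm" is missing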
--------------------------------------------------------------------------------
/onnx_parser/onnx_parser_7.x/RNNHelpers.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <string>
4 | #include <vector>
5 | #include <NvInfer.h>
6 |
7 | #include "TensorOrWeights.hpp"
8 | #include "ImporterContext.hpp"
9 |
10 | namespace onnx2trt
11 | {
12 |
13 | nvinfer1::ITensor* addRNNInput(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, std::vector<TensorOrWeights>& inputs, const std::string& direction);
14 |
15 | // Zeros out invalid timesteps in toMask. maxLen must be provided if reverse is true
16 | nvinfer1::ITensor* clearMissingSequenceElements(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* toMask, nvinfer1::ITensor* maxLen, bool reverse = false, nvinfer1::ITensor* counter = nullptr);
17 |
18 | // Returns a bool tensor which is true during valid timesteps
19 | nvinfer1::ITensor* getRaggedMask(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* maxLen = nullptr, bool reverse = false, nvinfer1::ITensor* counter = nullptr);
20 |
21 | // Selects between prevH and Ht to forward previous hidden state through invalid timesteps
22 | nvinfer1::ITensor* maskRNNHidden(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* prevH, nvinfer1::ITensor* Ht, nvinfer1::ITensor* maxLen = nullptr, bool reverse = false, nvinfer1::ITensor* counter = nullptr);
23 |
24 | // Splits a bidirectional hidden state into forward and reverse passes, masks each using maskRNNHidden, then concatenates
25 | nvinfer1::ITensor* maskBidirRNNHidden(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* maxLen, nvinfer1::ITensor* Ht1, nvinfer1::ITensor* Ht, nvinfer1::ITensor* singlePassShape);
26 |
27 | } // namespace onnx2trt
28 |
--------------------------------------------------------------------------------
/onnx_parser/onnx_parser_7.x/ShapedWeights.cpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 |  *
4 |  * Permission is hereby granted, free of charge, to any person obtaining a
5 |  * copy of this software and associated documentation files (the "Software"),
6 |  * to deal in the Software without restriction, including without limitation
7 |  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 |  * and/or sell copies of the Software, and to permit persons to whom the
9 |  * Software is furnished to do so, subject to the following conditions:
10 |  *
11 |  * The above copyright notice and this permission notice shall be included in
12 |  * all copies or substantial portions of the Software.
13 |  *
14 |  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 |  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 |  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17 |  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 |  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19 |  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20 |  * DEALINGS IN THE SOFTWARE.
21 |  */
22 |
23 | #include "ShapedWeights.hpp"
24 | #include "onnx2trt_utils.hpp"
25 | #include "trt_utils.hpp"
26 | #include <cstdlib>
27 | #include <cstring>
28 |
29 | namespace onnx2trt
30 | {
31 |
32 | size_t ShapedWeights::count() const
33 | {
34 |     if (this->values == nullptr && this->shape.nbDims <= 0)
35 |     {
36 |         return 0;
37 |     }
38 |     // TRT supports scalars, so 0D tensors should have a count of 1.
39 |     size_t c = 1;
40 |     for (int i = 0; i < this->shape.nbDims; ++i)
41 |     {
42 |         c *= this->shape.d[i];
43 |     }
44 |     return c;
45 | }
46 |
47 | ShapedWeights ShapedWeights::empty(DataType type)
48 | {
49 |     return ShapedWeights(type, nullptr, nvinfer1::Dims{0});
50 | }
51 |
52 | ShapedWeights::ShapedWeights()
53 |     : values(nullptr)
54 |     , shape{0}
55 | {
56 | }
57 |
58 | ShapedWeights::ShapedWeights(DataType type_, void* values_, nvinfer1::Dims shape_)
59 |     : type(type_)
60 |     , values(values_)
61 |     , shape(shape_)
62 | {
63 |     // Note: this->shape.type[] is not used
64 | }
65 |
66 | size_t ShapedWeights::size_bytes() const
67 | {
68 |     return this->count() * getDtypeSize(this->type);
69 | }
70 |
71 | const char* ShapedWeights::getName() const
72 | {
73 |     return this->name;
74 | }
75 |
76 | void ShapedWeights::setName(const char* name)
77 | {
78 |     this->name = name;
79 | }
80 |
81 | ShapedWeights::operator bool() const
82 | {
83 |     return (bool) this->values;
84 | }
85 |
86 | ShapedWeights::operator nvinfer1::Weights() const
87 | {
88 |     nvinfer1::Weights w{};
89 |     w.values = this->values;
90 |     bool supported_type = convertDtype(this->type, &w.type);
91 |     (void) supported_type;
92 |     assert(supported_type);
93 |     w.count = this->count();
94 |     return w;
95 | }
96 |
97 | template <typename DType>
98 | void transpose2DWeights(ShapedWeights const& weights, nvinfer1::Dims const& new_shape, ShapedWeights* result)
99 | {
100 |     DType const* src = reinterpret_cast<DType const*>(weights.values);
101 |     DType* dst = reinterpret_cast<DType*>(result->values);
102 |     int src_stride = weights.shape.d[1];
103 |     int dst_stride = result->shape.d[1];
104 |     for (int i = 0; i < new_shape.d[0]; ++i)
105 |     {
106 |         for (int j = 0; j < new_shape.d[1]; ++j)
107 |         {
108 |             dst[i * dst_stride + j] = src[j * src_stride + i];
109 |         }
110 |     }
111 | }
112 |
113 | bool transposeWeights(ShapedWeights const& weights, nvinfer1::Permutation const& perm, ShapedWeights* result)
114 | {
115 |     nvinfer1::Dims shape = weights.shape;
116 |     nvinfer1::Dims new_shape;
117 |     new_shape.nbDims = shape.nbDims;
118 |     for (int d = 0; d < shape.nbDims; ++d)
119 |     {
120 |         new_shape.d[d] = shape.d[perm.order[d]];
121 |         result->shape.d[d] = new_shape.d[d];
122 |     }
123 |     // TODO: Need to generalize this transpose implementation
124 |     assert(perm.order[0] == 1 && perm.order[1] == 0);
125 |
126 |     if (shape.nbDims == 2)
127 |     {
128 |         if (weights.type == ::onnx::TensorProto::FLOAT)
129 |         {
130 |             transpose2DWeights<float>(weights, new_shape, result);
131 |         }
132 |         else if (weights.type == ::onnx::TensorProto::FLOAT16)
133 |         {
134 |             transpose2DWeights<uint16_t>(weights, new_shape, result);
135 |         }
136 |         else
137 |         {
138 |             return false;
139 |         }
140 |     }
141 |     else
142 |     {
143 |         // TODO: Implement general transposes and multiple data types
144 |         // Unsupported weights transpose
145 |         return false;
146 |     }
147 |     return true;
148 | }
149 |
150 | } // namespace onnx2trt
151 |
--------------------------------------------------------------------------------
/onnx_parser/onnx_parser_7.x/ShapedWeights.hpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include 26 | #include 27 | 28 | namespace onnx2trt 29 | { 30 | 31 | class ShapedWeights 32 | { 33 | public: 34 | using DataType = int32_t; 35 | DataType type; 36 | void* values; 37 | nvinfer1::Dims shape; 38 | const char* name = nullptr; 39 | static ShapedWeights empty(DataType type); 40 | ShapedWeights(); 41 | explicit ShapedWeights(DataType type, void* values, nvinfer1::Dims shape_); 42 | size_t count() const; 43 | size_t size_bytes() const; 44 | const char* getName() const; 45 | void setName(const char* name); 46 | explicit operator bool() const; 47 | operator nvinfer1::Weights() const; 48 | }; 49 | 50 | bool transposeWeights(ShapedWeights const& weights, nvinfer1::Permutation const& perm, ShapedWeights* result); 51 | 52 | } // namespace onnx2trt 53 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_7.x/TensorOrWeights.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 
21 |  */
22 |
23 | #pragma once
24 |
25 | #include "ShapedWeights.hpp"
26 |
27 | #include <NvInfer.h>
28 | #include <cassert>
29 |
30 | namespace onnx2trt
31 | {
32 |
33 | class TensorOrWeights
34 | {
35 |     union
36 |     {
37 |         nvinfer1::ITensor* _tensor;
38 |         ShapedWeights _weights;
39 |     };
40 |     enum
41 |     {
42 |         NODE_TENSOR,
43 |         NODE_WEIGHTS
44 |     } _variant;
45 |
46 | public:
47 |     TensorOrWeights()
48 |         : _tensor(nullptr)
49 |         , _variant(NODE_TENSOR)
50 |     {
51 |     }
52 |     TensorOrWeights(nvinfer1::ITensor* tensor)
53 |         : _tensor(tensor)
54 |         , _variant(NODE_TENSOR)
55 |     {
56 |     }
57 |     TensorOrWeights(ShapedWeights const& weights)
58 |         : _weights(weights)
59 |         , _variant(NODE_WEIGHTS)
60 |     {
61 |     }
62 |     bool is_tensor() const
63 |     {
64 |         return _variant == NODE_TENSOR;
65 |     }
66 |     bool is_weights() const
67 |     {
68 |         return _variant == NODE_WEIGHTS;
69 |     }
70 |     bool isNullTensor() const
71 |     {
72 |         return is_tensor() && _tensor == nullptr;
73 |     }
74 |     nvinfer1::ITensor& tensor()
75 |     {
76 |         assert(!isNullTensor());
77 |         return *_tensor;
78 |     }
79 |     nvinfer1::ITensor const& tensor() const
80 |     {
81 |         assert(!isNullTensor());
82 |         return *_tensor;
83 |     }
84 |     ShapedWeights& weights()
85 |     {
86 |         assert(is_weights());
87 |         return _weights;
88 |     }
89 |     ShapedWeights const& weights() const
90 |     {
91 |         assert(is_weights());
92 |         return _weights;
93 |     }
94 |     nvinfer1::Dims shape() const
95 |     {
96 |         return is_tensor() ? _tensor->getDimensions() : _weights.shape;
97 |     }
98 |     explicit operator bool() const
99 |     {
100 |         return is_tensor() ? _tensor != nullptr : static_cast<bool>(_weights);
101 |     }
102 |     bool isInt32() const
103 |     {
104 |         return is_tensor() ? _tensor->getType() == nvinfer1::DataType::kINT32 : _weights.type == ::onnx::TensorProto_DataType_INT32;
105 |     }
106 |     bool isBool() const
107 |     {
108 |         return is_tensor() ? _tensor->getType() == nvinfer1::DataType::kBOOL : _weights.type == ::onnx::TensorProto_DataType_BOOL;
109 |     }
110 |     std::string getName() const
111 |     {
112 |         return is_tensor() ? _tensor->getName() : _weights.getName();
113 |     }
114 | };
115 |
116 | } // namespace onnx2trt
117 |
--------------------------------------------------------------------------------
/onnx_parser/onnx_parser_7.x/builtin_op_importers.hpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
3 |  *
4 |  * Permission is hereby granted, free of charge, to any person obtaining a
5 |  * copy of this software and associated documentation files (the "Software"),
6 |  * to deal in the Software without restriction, including without limitation
7 |  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 |  * and/or sell copies of the Software, and to permit persons to whom the
9 |  * Software is furnished to do so, subject to the following conditions:
10 |  *
11 |  * The above copyright notice and this permission notice shall be included in
12 |  * all copies or substantial portions of the Software.
13 |  *
14 |  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 |  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 |  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17 |  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 |  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19 |  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20 |  * DEALINGS IN THE SOFTWARE.
21 |  */
22 |
23 | #pragma once
24 |
25 | #include "onnx2trt.hpp"
26 | #include "utils.hpp"
27 |
28 | namespace onnx2trt
29 | {
30 |
31 | string_map<NodeImporter>& getBuiltinOpImporterMap();
32 |
33 | } // namespace onnx2trt
34 |
--------------------------------------------------------------------------------
/onnx_parser/onnx_parser_7.x/common.hpp:
--------------------------------------------------------------------------------
1 | /*
2 |  * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
3 |  *
4 |  * Permission is hereby granted, free of charge, to any person obtaining a
5 |  * copy of this software and associated documentation files (the "Software"),
6 |  * to deal in the Software without restriction, including without limitation
7 |  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 |  * and/or sell copies of the Software, and to permit persons to whom the
9 |  * Software is furnished to do so, subject to the following conditions:
10 |  *
11 |  * The above copyright notice and this permission notice shall be included in
12 |  * all copies or substantial portions of the Software.
13 |  *
14 |  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 |  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 |  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
17 |  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 |  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19 |  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
20 |  * DEALINGS IN THE SOFTWARE.
21 |  */
22 | #pragma once
23 |
24 | #include <NvInfer.h>
25 | #include <google/protobuf/io/coded_stream.h>
26 | #include <google/protobuf/io/zero_copy_stream_impl.h>
27 | #include <google/protobuf/text_format.h>
28 | #include <onnx/onnx_pb.h>
29 | #include <fcntl.h> // For ::open
30 | #include <ctime>
31 | #include <iostream>
32 | #include <limits>
33 | #include <memory>
34 |
35 | // Namespace for common functions used throughout onnx-trt
36 | namespace common
37 | {
38 | struct InferDeleter {
39 |   template <typename T>
40 |   void operator()(T* obj) const {
41 |     if( obj ) {
42 |       obj->destroy();
43 |     }
44 |   }
45 | };
46 |
47 | template <typename T>
48 | inline std::shared_ptr<T> infer_object(T* obj) {
49 |   if( !obj ) {
50 |     throw std::runtime_error("Failed to create object");
51 |   }
52 |   return std::shared_ptr<T>(obj, InferDeleter());
53 | }
54 |
55 | // Logger for TensorRT info/warning/errors
56 | class TRT_Logger : public nvinfer1::ILogger {
57 |   nvinfer1::ILogger::Severity _verbosity;
58 |   std::ostream* _ostream;
59 | public:
60 |   TRT_Logger(Severity verbosity=Severity::kWARNING,
61 |              std::ostream& ostream=std::cout)
62 |     : _verbosity(verbosity), _ostream(&ostream) {}
63 |   void log(Severity severity, const char* msg) override {
64 |     if( severity <= _verbosity ) {
65 |       time_t rawtime = std::time(0);
66 |       char buf[256];
67 |       strftime(&buf[0], 256,
68 |                "%Y-%m-%d %H:%M:%S",
69 |                std::gmtime(&rawtime));
70 |       const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? "    BUG" :
71 |                             severity == Severity::kERROR   ? "  ERROR" :
72 |                             severity == Severity::kWARNING ? "WARNING" :
73 |                             severity == Severity::kINFO    ?
" INFO" : 74 | "UNKNOWN"); 75 | (*_ostream) << "[" << buf << " " << sevstr << "] " 76 | << msg 77 | << std::endl; 78 | } 79 | } 80 | }; 81 | 82 | inline bool ParseFromFile_WAR(google::protobuf::Message* msg, 83 | const char* filename) { 84 | int fd = ::open(filename, O_RDONLY); 85 | google::protobuf::io::FileInputStream raw_input(fd); 86 | raw_input.SetCloseOnDelete(true); 87 | google::protobuf::io::CodedInputStream coded_input(&raw_input); 88 | // Note: This WARs the very low default size limit (64MB) 89 | coded_input.SetTotalBytesLimit(std::numeric_limits::max()); 90 | return msg->ParseFromCodedStream(&coded_input); 91 | } 92 | 93 | inline bool MessageToFile(const google::protobuf::Message* msg, 94 | const char* filename) { 95 | int fd = ::open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); 96 | google::protobuf::io::FileOutputStream raw_output(fd); 97 | raw_output.SetCloseOnDelete(true); 98 | google::protobuf::io::CodedOutputStream output(&raw_output); 99 | 100 | // Write the size. 101 | const int size = msg->ByteSize(); 102 | 103 | uint8_t* buffer = output.GetDirectBufferForNBytesAndAdvance(size); 104 | if (buffer != NULL) { 105 | // Optimization: The msg fits in one buffer, so use the faster 106 | // direct-to-array serialization path. 107 | msg->SerializeWithCachedSizesToArray(buffer); 108 | } else { 109 | // Slightly-slower path when the msg is multiple buffers. 110 | msg->SerializeWithCachedSizes(&output); 111 | if (output.HadError()) return false; 112 | } 113 | 114 | return true; 115 | } 116 | 117 | inline bool ParseFromTextFile(google::protobuf::Message* msg, 118 | const char* filename) { 119 | int fd = ::open(filename, O_RDONLY); 120 | google::protobuf::io::FileInputStream raw_input(fd); 121 | raw_input.SetCloseOnDelete(true); 122 | return google::protobuf::TextFormat::Parse(&raw_input, msg); 123 | } 124 | 125 | inline std::string onnx_ir_version_string(int64_t ir_version=::onnx::IR_VERSION) { 126 | int onnx_ir_major = ir_version / 1000000; 127 | int onnx_ir_minor = ir_version % 1000000 / 10000; 128 | int onnx_ir_patch = ir_version % 10000; 129 | return (std::to_string(onnx_ir_major) + "." + 130 | std::to_string(onnx_ir_minor) + "." + 131 | std::to_string(onnx_ir_patch)); 132 | } 133 | 134 | inline void print_version() { 135 | std::cout << "Parser built against:" << std::endl; 136 | std::cout << " ONNX IR version: " << onnx_ir_version_string(::onnx::IR_VERSION) << std::endl; 137 | std::cout << " TensorRT version: " 138 | << NV_TENSORRT_MAJOR << "." 139 | << NV_TENSORRT_MINOR << "." 140 | << NV_TENSORRT_PATCH << std::endl; 141 | } 142 | } // namespace common 143 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_7.x/onnx2trt.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 
13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include "NvOnnxParser.h" 26 | #include "ShapedWeights.hpp" 27 | #include "Status.hpp" 28 | #include "TensorOrWeights.hpp" 29 | 30 | #include <NvInfer.h> // NOTE: bracketed header names here were garbled in extraction; reconstructed from usage 31 | #include <functional> 32 | #include <memory> 33 | #include <string> 34 | #include <unordered_map> 35 | #include <unordered_set> 36 | #include <vector> 37 | 38 | using WeightsPair_t = std::pair<std::string, nvinfer1::WeightsRole>; 39 | 40 | using RefitMap_t = std::unordered_map<std::string, WeightsPair_t>; 41 | 42 | namespace onnx2trt 43 | { 44 | 45 | class IImporterContext; 46 | 47 | // TODO: Find ABI-safe alternative approach for this: 48 | // Can't use std::vector 49 | // Can't use ::onnx::NodeProto 50 | // Can't use std::function 51 | typedef ValueOrStatus<std::vector<TensorOrWeights>> NodeImportResult; 52 | typedef std::function<NodeImportResult( 53 | IImporterContext* ctx, ::onnx::NodeProto const& node, std::vector<TensorOrWeights>& inputs)> 54 | NodeImporter; 55 | 56 | template <typename T> 57 | using StringMap = std::unordered_map<std::string, T>; 58 | 59 | class IImporterContext 60 | { 61 | public: 62 | virtual nvinfer1::INetworkDefinition* network() = 0; 63 | virtual StringMap<TensorOrWeights>& tensors() = 0; 64 | virtual StringMap<nvinfer1::TensorLocation>& tensorLocations() = 0; 65 | virtual StringMap<float>& tensorRangeMins() = 0; 66 | virtual StringMap<float>& tensorRangeMaxes() = 0; 67 | virtual StringMap<nvinfer1::DataType>& layerPrecisions() = 0; 68 | virtual std::unordered_set<std::string>& unsupportedShapeTensors() = 0; 69 | virtual StringMap<std::string>& loopTensors() = 0; 70 | virtual void setOnnxFileLocation(std::string location) = 0; 71 | virtual std::string getOnnxFileLocation() = 0; 72 | virtual void registerTensor(TensorOrWeights tensor, const std::string& basename) = 0; 73 | virtual void registerLayer(nvinfer1::ILayer* layer, const std::string& basename) = 0; 74 | virtual ShapedWeights createTempWeights(ShapedWeights::DataType type, nvinfer1::Dims shape) = 0; 75 | virtual int64_t getOpsetVersion(const char* domain = "") const = 0; 76 | virtual nvinfer1::ILogger& logger() = 0; 77 | virtual void insertRefitMap(std::string weightsName, std::string layerName, nvinfer1::WeightsRole role) = 0; 78 | 79 | protected: 80 | virtual ~IImporterContext() 81 | { 82 | } 83 | }; 84 | 85 | } // namespace onnx2trt 86 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_7.x/onnx2trt_common.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software.
13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include 26 | #include 27 | 28 | #if NV_TENSORRT_MAJOR < 4 29 | namespace nvinfer1 30 | { 31 | 32 | enum class PluginFormat : uint8_t 33 | { 34 | kNCHW = 0, //!< NCHW 35 | kNC2HW2 = 1, //!< NCHW with 2-element packed channels 36 | kNHWC8 = 2 //!< NHWC with 8-element packed channels (C 37 | //! must be a multiple of 8) 38 | }; 39 | // from NvInfer.h 40 | class IPluginExt : public IPlugin 41 | { 42 | public: 43 | virtual int getTensorRTVersion() const 44 | { 45 | return NV_TENSORRT_VERSION; 46 | } 47 | virtual bool supportsFormat(DataType type, PluginFormat format) const = 0; 48 | virtual void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, 49 | DataType type, PluginFormat format, int maxBatchSize) 50 | = 0; 51 | 52 | protected: 53 | void configure(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, int maxBatchSize) final 54 | { 55 | DataType type = nvinfer1::DataType::kFLOAT; 56 | PluginFormat format = nvinfer1::PluginFormat::kNCHW; 57 | return this->configureWithFormat(inputDims, nbInputs, outputDims, nbOutputs, type, format, maxBatchSize); 58 | } 59 | virtual ~IPluginExt() 60 | { 61 | } 62 | }; 63 | 64 | } // namespace nvinfer1 65 | #endif 66 | 67 | namespace onnx2trt 68 | { 69 | 70 | struct IOwnable 71 | { 72 | virtual void destroy() = 0; 73 | 74 | protected: 75 | virtual ~IOwnable() 76 | { 77 | } 78 | }; 79 | 80 | struct OwnableDeleter 81 | { 82 | void operator()(IOwnable* obj) const 83 | { 84 | obj->destroy(); 85 | } 86 | }; 87 | 88 | using UniqueOwnable = std::unique_ptr; 89 | class Plugin; 90 | class PluginV2; 91 | 92 | } // namespace onnx2trt 93 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_7.x/onnx2trt_runtime.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include "onnx2trt_common.hpp" 26 | 27 | namespace onnx2trt 28 | { 29 | 30 | typedef Plugin* (*plugin_deserializer)(const void* serialData, size_t serialLength); 31 | 32 | } // namespace onnx2trt 33 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_7.x/toposort.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 21 | */ 22 | 23 | #pragma once 24 | 25 | #include 26 | #include 27 | 28 | #include 29 | using std::cout; 30 | using std::cerr; 31 | using std::endl; 32 | 33 | namespace 34 | { 35 | 36 | enum NodeState 37 | { 38 | NODE_UNVISITED, 39 | NODE_ACTIVE, 40 | NODE_VISITED 41 | }; 42 | 43 | template 44 | bool get_post_order(size_t node_idx, Container const& nodes, std::unordered_map const& node_map, 45 | std::vector* node_states, std::vector* order) 46 | { 47 | NodeState& node_state = node_states->at(node_idx); 48 | if (node_state == NODE_ACTIVE) 49 | { 50 | // Cycle detected! 51 | cerr << "ERROR: Graph contains a cycle" << endl; 52 | return false; 53 | } 54 | else if (node_state == NODE_VISITED) 55 | { 56 | return true; 57 | } 58 | else 59 | { 60 | node_state = NODE_ACTIVE; 61 | // TODO: This .Get().input() is highly specific to protobuf, should 62 | // generalise it somehow. 63 | for (auto const& input : nodes.Get(node_idx).input()) 64 | { 65 | if (!node_map.count(input)) 66 | { 67 | // Input node not found in graph! 
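// (node_map only records node *outputs*, so a name missing here is a graph input or an initializer;
// it has no producer to order against and is safely skipped below.)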
68 | // cerr << "ERROR: Input node not found in graph: " 69 | // << input << endl; 70 | // return false; 71 | continue; // Skip missing input edges 72 | } 73 | size_t input_node_idx = node_map.at(input); 74 | if (!get_post_order(input_node_idx, nodes, node_map, node_states, order)) 75 | { 76 | return false; 77 | } 78 | } 79 | node_state = NODE_VISITED; 80 | order->push_back(node_idx); 81 | } 82 | return true; 83 | } 84 | 85 | } // anonymous namespace 86 | 87 | template 88 | bool toposort(Container const& nodes, std::vector* order) 89 | { 90 | std::unordered_map node_map; 91 | for (size_t i = 0; i < (size_t) nodes.size(); ++i) 92 | { 93 | // TODO: This .Get().input() is highly specific to protobuf, should 94 | // generalise it somehow. 95 | for (auto const& output : nodes.Get(i).output()) 96 | { 97 | if (!node_map.emplace(output, i).second) 98 | { 99 | // Output name appears more than once in graph! 100 | cerr << "ERROR: Output name is not unique: " << output << endl; 101 | return false; 102 | } 103 | } 104 | } 105 | order->reserve(nodes.size()); 106 | std::vector node_states(nodes.size(), NODE_UNVISITED); 107 | for (size_t i = 0; i < (size_t) nodes.size(); ++i) 108 | { 109 | if (!get_post_order(i, nodes, node_map, &node_states, order)) 110 | { 111 | return false; 112 | } 113 | } 114 | return true; 115 | } 116 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_7.x/utils.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Permission is hereby granted, free of charge, to any person obtaining a 5 | * copy of this software and associated documentation files (the "Software"), 6 | * to deal in the Software without restriction, including without limitation 7 | * the rights to use, copy, modify, merge, publish, distribute, sublicense, 8 | * and/or sell copies of the Software, and to permit persons to whom the 9 | * Software is furnished to do so, subject to the following conditions: 10 | * 11 | * The above copyright notice and this permission notice shall be included in 12 | * all copies or substantial portions of the Software. 13 | * 14 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 17 | * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 19 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 20 | * DEALINGS IN THE SOFTWARE. 
21 | */ 22 | 23 | #pragma once 24 | 25 | #include 26 | 27 | template 28 | using string_map = std::unordered_map; 29 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/LoopHelpers.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "LoopHelpers.hpp" 6 | #include "onnx2trt_utils.hpp" 7 | 8 | namespace onnx2trt 9 | { 10 | 11 | nvinfer1::ITensor* addLoopCounter(IImporterContext* ctx, nvinfer1::ILoop* loop, int32_t initial) 12 | { 13 | nvinfer1::ITensor* initialTensor = addConstantScalar(ctx, initial, ::onnx::TensorProto::INT32, nvinfer1::Dims{1, 1})->getOutput(0); 14 | nvinfer1::ITensor* one = addConstantScalar(ctx, 1, ::onnx::TensorProto::INT32, nvinfer1::Dims{1, 1})->getOutput(0); 15 | 16 | auto counter = loop->addRecurrence(*initialTensor); 17 | nvinfer1::ITensor* addOne = ctx->network()->addElementWise(*counter->getOutput(0), *one, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0); 18 | counter->setInput(1, *addOne); 19 | return counter->getOutput(0); 20 | } 21 | 22 | } // namespace onnx2trt 23 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/LoopHelpers.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | #include "ImporterContext.hpp" 10 | 11 | namespace onnx2trt 12 | { 13 | 14 | nvinfer1::ITensor* addLoopCounter(IImporterContext* ctx, nvinfer1::ILoop* loop, int32_t initial = 0); 15 | 16 | } // namespace onnx2trt 17 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/ModelImporter.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "ImporterContext.hpp" 8 | #include "NvInferPlugin.h" 9 | #include "NvOnnxParser.h" 10 | #include "builtin_op_importers.hpp" 11 | #include "utils.hpp" 12 | 13 | namespace onnx2trt 14 | { 15 | 16 | Status parseGraph(IImporterContext* ctx, const ::onnx::GraphProto& graph, bool deserializingINetwork = false, int* currentNode = nullptr); 17 | 18 | class ModelImporter : public nvonnxparser::IParser 19 | { 20 | protected: 21 | string_map _op_importers; 22 | virtual Status importModel(::onnx::ModelProto const& model); 23 | 24 | private: 25 | ImporterContext _importer_ctx; 26 | std::list<::onnx::ModelProto> _onnx_models; // Needed for ownership of weights 27 | int _current_node; 28 | std::vector _errors; 29 | std::vector _input_dims; 30 | 31 | public: 32 | ModelImporter(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger, const std::vector& input_dims) 33 | : _op_importers(getBuiltinOpImporterMap()) 34 | , _importer_ctx(network, logger) 35 | , _input_dims(input_dims) 36 | { 37 | } 38 | bool parseWithWeightDescriptors(void const* serialized_onnx_model, size_t serialized_onnx_model_size) override; 39 | bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size, const char* model_path = nullptr) override; 40 | bool supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size, 41 | SubGraphCollection_t& sub_graph_collection, const char* model_path = nullptr) override; 42 | 43 | bool supportsOperator(const char* op_name) const override; 44 | void destroy() override 45 | { 
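// destroy()+delete-this is the pre-TensorRT-8 idiom for freeing objects across the shared-library
// boundary, so instances must be heap-allocated (as createNvOnnxParser_INTERNAL below does with new).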
46 | delete this; 47 | } 48 | // virtual void registerOpImporter(std::string op, 49 | // NodeImporter const &node_importer) override { 50 | // // Note: This allows existing importers to be replaced 51 | // _op_importers[op] = node_importer; 52 | //} 53 | // virtual Status const &setInput(const char *name, 54 | // nvinfer1::ITensor *input) override; 55 | // virtual Status const& setOutput(const char* name, nvinfer1::ITensor** output) override; 56 | int getNbErrors() const override 57 | { 58 | return _errors.size(); 59 | } 60 | nvonnxparser::IParserError const* getError(int index) const override 61 | { 62 | assert(0 <= index && index < (int) _errors.size()); 63 | return &_errors[index]; 64 | } 65 | void clearErrors() override 66 | { 67 | _errors.clear(); 68 | } 69 | 70 | //...LG: Move the implementation to .cpp 71 | bool parseFromFile(const char* onnxModelFile, int verbosity) override; 72 | bool parseFromData(const void* onnx_data, size_t size, int verbosity) override; 73 | }; 74 | 75 | } // namespace onnx2trt 76 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/NvOnnxParser.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "NvOnnxParser.h" 6 | #include "ModelImporter.hpp" 7 | 8 | extern "C" void* createNvOnnxParser_INTERNAL(void* network_, void* logger_, int version, const std::vector& input_dims) 9 | { 10 | auto network = static_cast(network_); 11 | auto logger = static_cast(logger_); 12 | return new onnx2trt::ModelImporter(network, logger, input_dims); 13 | } 14 | 15 | extern "C" int getNvOnnxParserVersion() 16 | { 17 | return NV_ONNX_PARSER_VERSION; 18 | } -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/OnnxAttrs.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "ImporterContext.hpp" 13 | 14 | class OnnxAttrs 15 | { 16 | template 17 | using string_map = std::unordered_map; 18 | typedef string_map<::onnx::AttributeProto const*> AttrMap; 19 | AttrMap _attrs; 20 | onnx2trt::IImporterContext* mCtx; 21 | 22 | public: 23 | explicit OnnxAttrs(::onnx::NodeProto const& onnx_node, onnx2trt::IImporterContext* ctx) 24 | : mCtx{ctx} 25 | { 26 | for (auto const& attr : onnx_node.attribute()) 27 | { 28 | _attrs.insert({attr.name(), &attr}); 29 | } 30 | } 31 | 32 | bool count(const std::string& key) const 33 | { 34 | return _attrs.count(key); 35 | } 36 | 37 | ::onnx::AttributeProto const* at(std::string key) const 38 | { 39 | if (!_attrs.count(key)) 40 | { 41 | throw std::out_of_range("Attribute not found: " + key); 42 | } 43 | return _attrs.at(key); 44 | } 45 | 46 | ::onnx::AttributeProto::AttributeType type(const std::string& key) const 47 | { 48 | return this->at(key)->type(); 49 | } 50 | 51 | 52 | template 53 | T get(const std::string& key) const; 54 | 55 | template 56 | T get(const std::string& key, T const& default_value) const 57 | { 58 | return _attrs.count(key) ? 
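// attribute present -> typed lookup via get<T>(key); absent -> return the caller-supplied default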
this->get(key) : default_value; 59 | } 60 | }; 61 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/RNNHelpers.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "TensorOrWeights.hpp" 12 | #include "ImporterContext.hpp" 13 | 14 | namespace onnx2trt 15 | { 16 | 17 | nvinfer1::ITensor* addRNNInput(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, std::vector& inputs, const std::string& direction); 18 | 19 | // Zeros out invalid timesteps in toMask. maxLen must be provided if reverse is true 20 | nvinfer1::ITensor* clearMissingSequenceElements(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* toMask, nvinfer1::ITensor* maxLen, bool reverse = false, nvinfer1::ITensor* counter = nullptr); 21 | 22 | // Returns a bool tensor which is true during valid timesteps 23 | nvinfer1::ITensor* getRaggedMask(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* maxLen = nullptr, bool reverse = false, nvinfer1::ITensor* counter = nullptr); 24 | 25 | // Selects between prevH and Ht to forward previous hidden state through invalid timesteps 26 | nvinfer1::ITensor* maskRNNHidden(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* prevH, nvinfer1::ITensor* Ht, nvinfer1::ITensor* maxLen = nullptr, bool reverse = false, nvinfer1::ITensor* counter = nullptr); 27 | 28 | // Splits a bidirectional hidden state into forward and reverse passes, masks each using maskRNNHidden, then concatenates 29 | nvinfer1::ITensor* maskBidirRNNHidden(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* maxLen, nvinfer1::ITensor* Ht1, nvinfer1::ITensor* Ht, nvinfer1::ITensor* singlePassShape); 30 | 31 | } // namespace onnx2trt 32 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/ShapedWeights.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "ShapedWeights.hpp" 6 | #include "onnx2trt_utils.hpp" 7 | #include "trt_utils.hpp" 8 | #include 9 | #include 10 | 11 | namespace onnx2trt 12 | { 13 | 14 | size_t ShapedWeights::count() const 15 | { 16 | if (this->values == nullptr && this->shape.nbDims <= 0) 17 | { 18 | return 0; 19 | } 20 | // TRT supports scalars, so 0D tensors should have a count of 1. 
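// e.g. shape {2,3,4} yields 2*3*4 = 24; for a 0-D (scalar) shape the loop below never executes,
// leaving the count at 1.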
21 | size_t c = 1; 22 | for (int i = 0; i < this->shape.nbDims; ++i) 23 | { 24 | c *= this->shape.d[i]; 25 | } 26 | return c; 27 | } 28 | 29 | ShapedWeights ShapedWeights::empty(DataType type) 30 | { 31 | return ShapedWeights(type, nullptr, nvinfer1::Dims{0}); 32 | } 33 | 34 | ShapedWeights::ShapedWeights() 35 | : values(nullptr) 36 | , shape{0} 37 | { 38 | } 39 | 40 | ShapedWeights::ShapedWeights(DataType type_, void* values_, nvinfer1::Dims shape_) 41 | : type(type_) 42 | , values(values_) 43 | , shape(shape_) 44 | { 45 | // Note: this->shape.type[] is not used 46 | } 47 | 48 | size_t ShapedWeights::size_bytes() const 49 | { 50 | return this->count() * getDtypeSize(this->type); 51 | } 52 | 53 | ShapedWeights::operator bool() const 54 | { 55 | return (bool) this->values; 56 | } 57 | 58 | ShapedWeights::operator nvinfer1::Weights() const 59 | { 60 | nvinfer1::Weights w{}; 61 | w.values = this->values; 62 | bool supported_type = convertDtype(this->type, &w.type); 63 | (void) supported_type; 64 | assert(supported_type); 65 | w.count = this->count(); 66 | return w; 67 | } 68 | 69 | const char* ShapedWeights::getName() const 70 | { 71 | return this->name; 72 | } 73 | 74 | void ShapedWeights::setName(const char* name) 75 | { 76 | this->name = name; 77 | } 78 | 79 | template 80 | void transpose4DWeights(ShapedWeights const& weights, nvinfer1::Permutation const perm, ShapedWeights* result) 81 | { 82 | nvinfer1::Dims original_shape = weights.shape; 83 | nvinfer1::Dims new_shape = result->shape; 84 | int nbDims = new_shape.nbDims; 85 | DType const* src = reinterpret_cast(weights.values); 86 | DType* dst = reinterpret_cast(result->values); 87 | 88 | nvinfer1::Dims expanded_original_shape{4, {1, 1, 1, 1}}; 89 | nvinfer1::Dims expanded_new_shape{4, {1, 1, 1, 1}}; 90 | nvinfer1::Permutation expanded_perm{0, 1, 2, 3}; 91 | 92 | int pad = 4 - nbDims; 93 | for (int i = 0; i < nbDims; ++i) 94 | { 95 | expanded_original_shape.d[pad + i] = original_shape.d[i]; 96 | expanded_new_shape.d[pad + i] = new_shape.d[i]; 97 | expanded_perm.order[pad + i] = perm.order[i] + pad; 98 | } 99 | 100 | 101 | int src_strides[4] = {1, 1, 1, 1}; 102 | int dst_strides[4] = {1, 1, 1, 1}; 103 | 104 | for (int i = 2; i >= 0; --i) 105 | { 106 | src_strides[i] = expanded_original_shape.d[i + 1] * src_strides[i + 1]; 107 | dst_strides[i] = expanded_new_shape.d[i + 1] * dst_strides[i + 1]; 108 | } 109 | 110 | for (int n = 0; n < expanded_original_shape.d[0]; ++n) 111 | { 112 | for (int c = 0; c < expanded_original_shape.d[1]; ++c) 113 | { 114 | for (int h = 0; h < expanded_original_shape.d[2]; ++h) 115 | { 116 | for (int w = 0; w < expanded_original_shape.d[3]; ++w) 117 | { 118 | int src_index = 0; 119 | int dst_index = 0; 120 | int src_coord[4] = {n, c, h, w}; 121 | int dst_coord[4]; 122 | for (int i = 0 ; i < 4; ++i) 123 | { 124 | dst_coord[i] = src_coord[expanded_perm.order[i]]; 125 | src_index += src_coord[i] * src_strides[i]; 126 | dst_index += dst_coord[i] * dst_strides[i]; 127 | } 128 | dst[dst_index] = src[src_index]; 129 | } 130 | } 131 | } 132 | } 133 | } 134 | 135 | bool transposeWeights(ShapedWeights const& weights, nvinfer1::Permutation const& perm, ShapedWeights* result, IImporterContext* ctx) 136 | { 137 | nvinfer1::Dims shape = weights.shape; 138 | int nbDims = shape.nbDims; 139 | nvinfer1::Dims new_shape; 140 | new_shape.nbDims = nbDims; 141 | for (int d = 0; d < nbDims; ++d) 142 | { 143 | new_shape.d[d] = shape.d[perm.order[d]]; 144 | result->shape.d[d] = new_shape.d[d]; 145 | } 146 | 147 | if (shape.nbDims <= 4) 148 
| { 149 | if (weights.type == ::onnx::TensorProto::FLOAT) 150 | { 151 | transpose4DWeights(weights, perm, result); 152 | } 153 | else if (weights.type == ::onnx::TensorProto::FLOAT16) 154 | { 155 | transpose4DWeights(weights, perm, result); 156 | } 157 | else 158 | { 159 | return false; 160 | } 161 | } 162 | else 163 | { 164 | // TODO: Implement general transposes and multiple data types 165 | // Unsupported weights transpose 166 | return false; 167 | } 168 | nvinfer1::Dims permDims{nbDims, {}}; 169 | std::copy_n(perm.order, nbDims, permDims.d); 170 | LOG_WARNING("Weights " 171 | << weights.getName() << " has been transposed with permutation of " << permDims 172 | << "! If you plan on overwriting the weights with the Refitter API, the new weights must be pre-transposed."); 173 | result->setName(weights.getName()); 174 | return true; 175 | } 176 | 177 | } // namespace onnx2trt 178 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/ShapedWeights.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | 10 | namespace onnx2trt 11 | { 12 | 13 | class ShapedWeights 14 | { 15 | public: 16 | using DataType = int32_t; 17 | 18 | static ShapedWeights empty(DataType type); 19 | 20 | ShapedWeights(); 21 | 22 | explicit ShapedWeights(DataType type, void* values, nvinfer1::Dims shape_); 23 | 24 | size_t count() const; 25 | 26 | size_t size_bytes() const; 27 | 28 | const char* getName() const; 29 | 30 | void setName(const char* name); 31 | 32 | explicit operator bool() const; 33 | 34 | operator nvinfer1::Weights() const; 35 | 36 | template 37 | T& at(size_t index) 38 | { 39 | assert(index >= 0 && (index * sizeof(T)) < size_bytes()); 40 | return static_cast(values)[index]; 41 | } 42 | 43 | template 44 | const T& at(size_t index) const 45 | { 46 | assert(index >= 0 && (index * sizeof(T)) < size_bytes()); 47 | return static_cast(values)[index]; 48 | } 49 | 50 | public: 51 | DataType type; 52 | void* values; 53 | nvinfer1::Dims shape; 54 | const char* name{}; 55 | }; 56 | 57 | class IImporterContext; 58 | bool transposeWeights(ShapedWeights const& weights, nvinfer1::Permutation const& perm, ShapedWeights* result, IImporterContext* ctx); 59 | 60 | } // namespace onnx2trt 61 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/TensorOrWeights.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "ShapedWeights.hpp" 8 | 9 | #include 10 | #include 11 | 12 | namespace onnx2trt 13 | { 14 | 15 | class TensorOrWeights 16 | { 17 | union 18 | { 19 | nvinfer1::ITensor* _tensor; 20 | ShapedWeights _weights; 21 | }; 22 | enum 23 | { 24 | NODE_TENSOR, 25 | NODE_WEIGHTS 26 | } _variant; 27 | 28 | public: 29 | TensorOrWeights() 30 | : _tensor(nullptr) 31 | , _variant(NODE_TENSOR) 32 | { 33 | } 34 | TensorOrWeights(nvinfer1::ITensor* tensor) 35 | : _tensor(tensor) 36 | , _variant(NODE_TENSOR) 37 | { 38 | } 39 | TensorOrWeights(ShapedWeights const& weights) 40 | : _weights(weights) 41 | , _variant(NODE_WEIGHTS) 42 | { 43 | } 44 | bool is_tensor() const 45 | { 46 | return _variant == NODE_TENSOR; 47 | } 48 | bool is_weights() const 49 | { 50 | return _variant == NODE_WEIGHTS; 51 | } 52 | bool isNullTensor() const 53 | { 54 | return is_tensor() && 
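// (a default-constructed TensorOrWeights is tagged NODE_TENSOR but holds a null pointer)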
_tensor == nullptr; 55 | } 56 | nvinfer1::ITensor& tensor() 57 | { 58 | assert(!isNullTensor()); 59 | return *_tensor; 60 | } 61 | nvinfer1::ITensor const& tensor() const 62 | { 63 | assert(!isNullTensor()); 64 | return *_tensor; 65 | } 66 | ShapedWeights& weights() 67 | { 68 | assert(is_weights()); 69 | return _weights; 70 | } 71 | ShapedWeights const& weights() const 72 | { 73 | assert(is_weights()); 74 | return _weights; 75 | } 76 | nvinfer1::Dims shape() const 77 | { 78 | return is_tensor() ? _tensor->getDimensions() : _weights.shape; 79 | } 80 | explicit operator bool() const 81 | { 82 | return is_tensor() ? _tensor != nullptr : static_cast(_weights); 83 | } 84 | bool isInt32() const 85 | { 86 | return is_tensor() ? _tensor->getType() == nvinfer1::DataType::kINT32 : _weights.type == ::onnx::TensorProto_DataType_INT32; 87 | } 88 | bool isBool() const 89 | { 90 | return is_tensor() ? _tensor->getType() == nvinfer1::DataType::kBOOL : _weights.type == ::onnx::TensorProto_DataType_BOOL; 91 | } 92 | std::string getName() const 93 | { 94 | return is_tensor() ? _tensor->getName() : _weights.getName(); 95 | } 96 | std::string getType() const 97 | { 98 | if (is_tensor()) 99 | { 100 | switch(_tensor->getType()) 101 | { 102 | case nvinfer1::DataType::kFLOAT:return "FLOAT"; 103 | case nvinfer1::DataType::kHALF: return "HALF"; 104 | case nvinfer1::DataType::kINT8: return "INT8"; 105 | case nvinfer1::DataType::kINT32: return "INT32"; 106 | case nvinfer1::DataType::kBOOL: return "BOOL"; 107 | default: return "UNKNOWN TYPE"; 108 | } 109 | } 110 | else 111 | { 112 | switch(_weights.type) 113 | { 114 | case ::onnx::TensorProto::DOUBLE: return "DOUBLE -> FLOAT"; 115 | case ::onnx::TensorProto::FLOAT: return "FLOAT"; 116 | case ::onnx::TensorProto::INT8: return "INT8"; 117 | case ::onnx::TensorProto::FLOAT16: return "HALF"; 118 | case ::onnx::TensorProto::BOOL: return "BOOL"; 119 | case ::onnx::TensorProto::INT32: return "INT32"; 120 | case ::onnx::TensorProto::INT64: return "INT64 -> INT32"; 121 | default: return "UNKNOWN TYPE"; 122 | } 123 | } 124 | } 125 | }; 126 | 127 | } // namespace onnx2trt 128 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/builtin_op_importers.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "onnx2trt.hpp" 8 | #include "utils.hpp" 9 | 10 | namespace onnx2trt 11 | { 12 | 13 | string_map& getBuiltinOpImporterMap(); 14 | 15 | } // namespace onnx2trt 16 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/onnx2trt.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "NvOnnxParser.h" 8 | #include "ShapedWeights.hpp" 9 | #include "Status.hpp" 10 | #include "TensorOrWeights.hpp" 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | namespace onnx2trt 21 | { 22 | 23 | class IImporterContext; 24 | 25 | // TODO: Find ABI-safe alternative approach for this: 26 | // Can't use std::vector 27 | // Can't use ::onnx::NodeProto 28 | // Can't use std::function 29 | typedef ValueOrStatus> NodeImportResult; 30 | typedef std::function& inputs)> 32 | NodeImporter; 33 | 34 | template 35 | using StringMap = std::unordered_map; 36 | 37 | class IImporterContext 38 | { 39 | 
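// Shared state handed to every op importer: the TRT network under construction, name->tensor maps,
// quantization ranges, opset lookup, logging, and error recording.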
public: 40 | virtual nvinfer1::INetworkDefinition* network() = 0; 41 | virtual StringMap<TensorOrWeights>& tensors() = 0; 42 | virtual StringMap<nvinfer1::TensorLocation>& tensorLocations() = 0; 43 | virtual StringMap<float>& tensorRangeMins() = 0; 44 | virtual StringMap<float>& tensorRangeMaxes() = 0; 45 | virtual StringMap<nvinfer1::DataType>& layerPrecisions() = 0; 46 | virtual std::unordered_set<std::string>& unsupportedShapeTensors() = 0; 47 | virtual StringMap<std::string>& loopTensors() = 0; 48 | virtual void setOnnxFileLocation(std::string location) = 0; 49 | virtual std::string getOnnxFileLocation() = 0; 50 | virtual void registerTensor(TensorOrWeights tensor, const std::string& basename) = 0; 51 | virtual void registerLayer(nvinfer1::ILayer* layer, const std::string& basename) = 0; 52 | virtual ShapedWeights createTempWeights(ShapedWeights::DataType type, nvinfer1::Dims shape, uint8_t value = 0) = 0; 53 | virtual int64_t getOpsetVersion(const char* domain = "") const = 0; 54 | virtual nvinfer1::ILogger& logger() = 0; 55 | virtual bool hasError() const = 0; 56 | virtual nvinfer1::IErrorRecorder* getErrorRecorder() const = 0; 57 | 58 | protected: 59 | virtual ~IImporterContext() 60 | { 61 | } 62 | }; 63 | 64 | } // namespace onnx2trt 65 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/onnx2trt_common.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include <NvInfer.h> // NOTE: bracketed header names garbled in extraction; reconstructed from usage 8 | #include <memory> 9 | 10 | #if NV_TENSORRT_MAJOR < 4 11 | namespace nvinfer1 12 | { 13 | 14 | enum class PluginFormat : uint8_t 15 | { 16 | kNCHW = 0, //!< NCHW 17 | kNC2HW2 = 1, //!< NCHW with 2-element packed channels 18 | kNHWC8 = 2 //!< NHWC with 8-element packed channels (C 19 | //! must be a multiple of 8) 20 | }; 21 | // from NvInfer.h 22 | class IPluginExt : public IPlugin 23 | { 24 | public: 25 | virtual int getTensorRTVersion() const noexcept 26 | { 27 | return NV_TENSORRT_VERSION; 28 | } 29 | virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept = 0; 30 | virtual void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, 31 | DataType type, PluginFormat format, int maxBatchSize) noexcept 32 | = 0; 33 | 34 | protected: 35 | void configure( 36 | const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, int maxBatchSize) noexcept final 37 | { 38 | try 39 | { 40 | DataType type = nvinfer1::DataType::kFLOAT; 41 | PluginFormat format = nvinfer1::PluginFormat::kLINEAR; 42 | return this->configureWithFormat(inputDims, nbInputs, outputDims, nbOutputs, type, format, maxBatchSize); 43 | } 44 | catch (const std::exception& e) 45 | { 46 | nvinfer1::getLogger()->log(nvinfer1::ILogger::Severity::kERROR, e.what()); 47 | } 48 | } 49 | virtual ~IPluginExt() 50 | { 51 | } 52 | }; 53 | 54 | } // namespace nvinfer1 55 | #endif 56 | 57 | namespace onnx2trt 58 | { 59 | 60 | struct IOwnable 61 | { 62 | virtual void destroy() = 0; 63 | 64 | protected: 65 | virtual ~IOwnable() 66 | { 67 | } 68 | }; 69 | 70 | struct OwnableDeleter 71 | { 72 | void operator()(IOwnable* obj) const 73 | { 74 | obj->destroy(); 75 | } 76 | }; 77 | 78 | using UniqueOwnable = std::unique_ptr<IOwnable, OwnableDeleter>; 79 | class Plugin; 80 | class PluginV2; 81 | 82 | } // namespace onnx2trt 83 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/onnx2trt_runtime.hpp: -------------------------------------------------------------------------------- 1 | /* 2 |
SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "onnx2trt_common.hpp" 8 | 9 | namespace onnx2trt 10 | { 11 | 12 | typedef Plugin* (*plugin_deserializer)(const void* serialData, size_t serialLength); 13 | 14 | } // namespace onnx2trt 15 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/onnxErrorRecorder.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "onnxErrorRecorder.hpp" 6 | #include 7 | 8 | namespace onnx2trt 9 | { 10 | 11 | 12 | ONNXParserErrorRecorder* ONNXParserErrorRecorder::create( 13 | nvinfer1::ILogger* logger, nvinfer1::IErrorRecorder* otherRecorder) 14 | { 15 | try 16 | { 17 | auto recorder = new ONNXParserErrorRecorder(logger, otherRecorder); 18 | if (recorder) 19 | { 20 | recorder->incRefCount(); 21 | } 22 | return recorder; 23 | } 24 | catch (const std::exception& e) 25 | { 26 | logError(logger, e.what()); 27 | return nullptr; 28 | } 29 | } 30 | 31 | void ONNXParserErrorRecorder::destroy(ONNXParserErrorRecorder*& recorder) 32 | { 33 | if (recorder) 34 | { 35 | recorder->decRefCount(); 36 | recorder = nullptr; 37 | } 38 | } 39 | 40 | void ONNXParserErrorRecorder::logError(nvinfer1::ILogger* logger, const char* str) 41 | { 42 | if (logger) 43 | { 44 | logger->log(ILogger::Severity::kERROR, str); 45 | } 46 | } 47 | 48 | ONNXParserErrorRecorder::ONNXParserErrorRecorder( 49 | nvinfer1::ILogger* logger, nvinfer1::IErrorRecorder* otherRecorder) 50 | : mUserRecorder(otherRecorder) 51 | , mLogger(logger) 52 | { 53 | if (mUserRecorder) 54 | { 55 | mUserRecorder->incRefCount(); 56 | } 57 | } 58 | 59 | ONNXParserErrorRecorder::~ONNXParserErrorRecorder() noexcept 60 | { 61 | if (mUserRecorder) 62 | { 63 | mUserRecorder->decRefCount(); 64 | } 65 | } 66 | 67 | void ONNXParserErrorRecorder::clear() noexcept 68 | { 69 | try 70 | { 71 | // grab a lock so that there is no addition while clearing. 72 | std::lock_guard guard(mStackLock); 73 | mErrorStack.clear(); 74 | } 75 | catch (const std::exception& e) 76 | { 77 | logError(mLogger, e.what()); 78 | } 79 | }; 80 | 81 | bool ONNXParserErrorRecorder::reportError( 82 | nvinfer1::ErrorCode val, nvinfer1::IErrorRecorder::ErrorDesc desc) noexcept 83 | { 84 | try 85 | { 86 | std::lock_guard guard(mStackLock); 87 | mErrorStack.push_back(errorPair(val, desc)); 88 | if (mUserRecorder) 89 | { 90 | mUserRecorder->reportError(val, desc); 91 | } 92 | else 93 | { 94 | logError(mLogger, desc); 95 | } 96 | } 97 | catch (const std::exception& e) 98 | { 99 | logError(mLogger, e.what()); 100 | } 101 | // All errors are considered fatal. 102 | return true; 103 | } 104 | 105 | nvinfer1::IErrorRecorder::RefCount ONNXParserErrorRecorder::incRefCount() noexcept 106 | { 107 | // Atomically increment or decrement the ref counter. 
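// (std::atomic makes the ++/-- below a single read-modify-write, so no mutex is needed for the count)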
108 | return ++mRefCount; 109 | } 110 | 111 | nvinfer1::IErrorRecorder::RefCount ONNXParserErrorRecorder::decRefCount() noexcept 112 | { 113 | auto newVal = --mRefCount; 114 | if (newVal == 0) 115 | { 116 | delete this; 117 | } 118 | return newVal; 119 | } 120 | 121 | } // namespace onnx2trt 122 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/onnxErrorRecorder.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "NvInferRuntimeCommon.h" 8 | #include "onnx2trt_utils.hpp" 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | namespace onnx2trt 16 | { 17 | 18 | //! 19 | //! A simple implementation of the IErrorRecorder interface for 20 | //! use by ONNX importer. 21 | //! ONNX-importer Error recorder is based on a vector that pairs the error 22 | //! code and the error string into a single element. It also uses 23 | //! standard mutex and atomics in order to make sure that the code 24 | //! works in a multi-threaded environment. 25 | //! 26 | class ONNXParserErrorRecorder : public nvinfer1::IErrorRecorder 27 | { 28 | using RefCount = nvinfer1::IErrorRecorder::RefCount; 29 | using ErrorDesc = nvinfer1::IErrorRecorder::ErrorDesc; 30 | using ErrorCode = nvinfer1::ErrorCode; 31 | using IErrorRecorder = nvinfer1::IErrorRecorder; 32 | using ILogger = nvinfer1::ILogger; 33 | 34 | using errorPair = std::pair; 35 | using errorStack = std::vector; 36 | 37 | public: 38 | static ONNXParserErrorRecorder* create( 39 | ILogger* logger, IErrorRecorder* otherRecorder = nullptr); 40 | 41 | static void destroy(ONNXParserErrorRecorder*& recorder); 42 | 43 | void clear() noexcept final; 44 | RefCount incRefCount() noexcept final; 45 | RefCount decRefCount() noexcept final; 46 | bool reportError(ErrorCode val, ErrorDesc desc) noexcept final; 47 | 48 | int32_t getNbErrors() const noexcept final 49 | { 50 | return mErrorStack.size(); 51 | } 52 | 53 | ErrorCode getErrorCode(int32_t errorIdx) const noexcept final 54 | { 55 | return invalidIndexCheck(errorIdx) ? ErrorCode::kINVALID_ARGUMENT : (*this)[errorIdx].first; 56 | } 57 | 58 | ErrorDesc getErrorDesc(int32_t errorIdx) const noexcept final 59 | { 60 | return invalidIndexCheck(errorIdx) ? "errorIdx out of range." : (*this)[errorIdx].second.c_str(); 61 | } 62 | 63 | bool hasOverflowed() const noexcept final 64 | { 65 | // This class can never overflow since we have dynamic resize via std::vector usage. 66 | return false; 67 | } 68 | 69 | protected: 70 | ONNXParserErrorRecorder(ILogger* logger, IErrorRecorder* otherRecorder = nullptr); 71 | 72 | virtual ~ONNXParserErrorRecorder() noexcept; 73 | 74 | static void logError(ILogger* logger, const char* str); 75 | 76 | // Simple helper functions. 77 | const errorPair& operator[](size_t index) const noexcept 78 | { 79 | return mErrorStack[index]; 80 | } 81 | 82 | bool invalidIndexCheck(int32_t index) const noexcept 83 | { 84 | // By converting signed to unsigned, we only need a single check since 85 | // negative numbers turn into large positive greater than the size. 86 | size_t sIndex = index; 87 | return sIndex >= mErrorStack.size(); 88 | } 89 | // Mutex to hold when locking mErrorStack. 90 | std::mutex mStackLock; 91 | 92 | // Reference count of the class. Destruction of the class when mRefCount 93 | // is not zero causes undefined behavior. 
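// decRefCount() above does `delete this` when the count reaches zero, so recorders must be
// heap-allocated, as create() guarantees.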
94 | std::atomic mRefCount{0}; 95 | 96 | // The error stack that holds the errors recorded by TensorRT. 97 | errorStack mErrorStack; 98 | 99 | // Original error recorder (set by user) 100 | IErrorRecorder* mUserRecorder{nullptr}; 101 | 102 | // logger 103 | ILogger* mLogger{nullptr}; 104 | }; // class ONNXParserErrorRecorder 105 | 106 | } // namespace onnx2trt 107 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/onnx_utils.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | 13 | #pragma once 14 | 15 | namespace 16 | { 17 | 18 | template 19 | bool convertOnnxDims(OnnxDims const& onnxDims, nvinfer1::Dims& trtDims) 20 | { 21 | std::vector onnxDims_vector; 22 | for (const auto& onnxDim : onnxDims) 23 | { 24 | const int dim = onnxDim.dim_param() == "" ? (onnxDim.dim_value() >= 0 ? onnxDim.dim_value() : -1) : -1; 25 | onnxDims_vector.emplace_back(dim); 26 | } 27 | trtDims.nbDims = onnxDims_vector.size(); 28 | assert(trtDims.nbDims <= nvinfer1::Dims::MAX_DIMS); 29 | std::copy(onnxDims_vector.begin(), onnxDims_vector.end(), trtDims.d); 30 | return true; 31 | } 32 | 33 | // Removes raw data from the text representation of an ONNX model 34 | void remove_raw_data_strings(std::string& s) 35 | { 36 | std::string::size_type beg = 0; 37 | const std::string key = "raw_data: \""; 38 | const std::string sub = "..."; 39 | while ((beg = s.find(key, beg)) != std::string::npos) 40 | { 41 | beg += key.length(); 42 | std::string::size_type end = beg - 1; 43 | // Note: Must skip over escaped end-quotes 44 | while (s[(end = s.find("\"", ++end)) - 1] == '\\') 45 | { 46 | } 47 | if (end - beg > 128) 48 | { // Only remove large data strings 49 | s.replace(beg, end - beg, "..."); 50 | } 51 | beg += sub.length(); 52 | } 53 | } 54 | 55 | // Removes float_data, int32_data etc. from the text representation of an ONNX model 56 | std::string remove_repeated_data_strings(std::string& s) 57 | { 58 | std::istringstream iss(s); 59 | std::ostringstream oss; 60 | bool is_repeat = false; 61 | for (std::string line; std::getline(iss, line);) 62 | { 63 | if (line.find("float_data:") != std::string::npos || line.find("int32_data:") != std::string::npos 64 | || line.find("int64_data:") != std::string::npos) 65 | { 66 | if (!is_repeat) 67 | { 68 | is_repeat = true; 69 | oss << line.substr(0, line.find(":") + 1) << " ...\n"; 70 | } 71 | } 72 | else 73 | { 74 | is_repeat = false; 75 | oss << line << "\n"; 76 | } 77 | } 78 | return oss.str(); 79 | } 80 | 81 | } // anonymous namespace 82 | 83 | inline std::string pretty_print_onnx_to_string(::google::protobuf::Message const& message) 84 | { 85 | std::string s; 86 | ::google::protobuf::TextFormat::PrintToString(message, &s); 87 | remove_raw_data_strings(s); 88 | s = remove_repeated_data_strings(s); 89 | return s; 90 | } 91 | 92 | inline std::ostream& operator<<(std::ostream& stream, ::onnx::ModelProto const& message) 93 | { 94 | stream << pretty_print_onnx_to_string(message); 95 | return stream; 96 | } 97 | 98 | inline std::ostream& operator<<(std::ostream& stream, ::onnx::NodeProto const& message) 99 | { 100 | stream << pretty_print_onnx_to_string(message); 101 | return stream; 102 | } 103 | 104 | //... 105 | //...Consider moving all of the below functions into a stand alone 106 | //... 
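// protobuf's CodedInputStream rejects messages larger than 64MB by default as a malformed-input guard;
// real ONNX model files easily exceed that, hence the SetTotalBytesLimit workaround below.
// Typical usage (sketch):
//   ::onnx::ModelProto model;
//   if (!ParseFromFile_WAR(&model, "model.onnx")) { /* handle parse failure */ }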
107 | 108 | inline bool ParseFromFile_WAR(google::protobuf::Message* msg, const char* filename) 109 | { 110 | 111 | std::ifstream stream(filename, std::ios::in | std::ios::binary); 112 | if (!stream) 113 | { 114 | std::cerr << "Could not open file " << std::string(filename) << std::endl; 115 | return false; 116 | } 117 | google::protobuf::io::IstreamInputStream rawInput(&stream); 118 | 119 | google::protobuf::io::CodedInputStream coded_input(&rawInput); 120 | // Note: This WARs the very low default size limit (64MB) 121 | coded_input.SetTotalBytesLimit(std::numeric_limits<int>::max()); 122 | return msg->ParseFromCodedStream(&coded_input); 123 | } 124 | 125 | inline bool ParseFromTextFile(google::protobuf::Message* msg, const char* filename) 126 | { 127 | std::ifstream stream(filename, std::ios::in); 128 | if (!stream) 129 | { 130 | std::cerr << "Could not open file " << std::string(filename) << std::endl; 131 | return false; 132 | } 133 | 134 | google::protobuf::io::IstreamInputStream rawInput(&stream); 135 | 136 | return google::protobuf::TextFormat::Parse(&rawInput, msg); 137 | } 138 | 139 | inline std::string onnx_ir_version_string(int64_t ir_version = ::onnx::IR_VERSION) 140 | { 141 | int onnx_ir_major = ir_version / 1000000; 142 | int onnx_ir_minor = ir_version % 1000000 / 10000; 143 | int onnx_ir_patch = ir_version % 10000; 144 | return (std::to_string(onnx_ir_major) + "." + std::to_string(onnx_ir_minor) + "." + std::to_string(onnx_ir_patch)); 145 | } 146 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/readme.md: -------------------------------------------------------------------------------- 1 | # ONNX Parser 2 | - These files are extracted from the official onnx-tensorrt; the Python bindings were removed, everything else is kept 3 | - Support for Plugin nodes has been added on top 4 | - https://github.com/onnx/onnx-tensorrt -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/toposort.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include <unordered_map> // NOTE: bracketed header names garbled in extraction; reconstructed from usage 8 | #include <vector> 9 | 10 | #include <iostream> 11 | using std::cout; 12 | using std::cerr; 13 | using std::endl; 14 | 15 | namespace 16 | { 17 | 18 | enum NodeState 19 | { 20 | NODE_UNVISITED, 21 | NODE_ACTIVE, 22 | NODE_VISITED 23 | }; 24 | 25 | template <class Container> 26 | bool get_post_order(size_t node_idx, Container const& nodes, std::unordered_map<std::string, size_t> const& node_map, 27 | std::vector<NodeState>* node_states, std::vector<size_t>* order) 28 | { 29 | NodeState& node_state = node_states->at(node_idx); 30 | if (node_state == NODE_ACTIVE) 31 | { 32 | // Cycle detected! 33 | cerr << "ERROR: Graph contains a cycle" << endl; 34 | return false; 35 | } 36 | else if (node_state == NODE_VISITED) 37 | { 38 | return true; 39 | } 40 | else 41 | { 42 | node_state = NODE_ACTIVE; 43 | // TODO: This .Get().input() is highly specific to protobuf, should 44 | // generalise it somehow. 45 | for (auto const& input : nodes.Get(node_idx).input()) 46 | { 47 | if (!node_map.count(input)) 48 | { 49 | // Input node not found in graph!
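// As in the 7.x copy of this file: only node outputs live in node_map, so a missing name is a
// graph input or initializer and is skipped rather than treated as an error.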
50 | // cerr << "ERROR: Input node not found in graph: " 51 | // << input << endl; 52 | // return false; 53 | continue; // Skip missing input edges 54 | } 55 | size_t input_node_idx = node_map.at(input); 56 | if (!get_post_order(input_node_idx, nodes, node_map, node_states, order)) 57 | { 58 | return false; 59 | } 60 | } 61 | node_state = NODE_VISITED; 62 | order->push_back(node_idx); 63 | } 64 | return true; 65 | } 66 | 67 | } // anonymous namespace 68 | 69 | template 70 | bool toposort(Container const& nodes, std::vector* order) 71 | { 72 | std::unordered_map node_map; 73 | for (size_t i = 0; i < (size_t) nodes.size(); ++i) 74 | { 75 | // TODO: This .Get().input() is highly specific to protobuf, should 76 | // generalise it somehow. 77 | for (auto const& output : nodes.Get(i).output()) 78 | { 79 | if (!node_map.emplace(output, i).second) 80 | { 81 | // Output name appears more than once in graph! 82 | cerr << "ERROR: Output name is not unique: " << output << endl; 83 | return false; 84 | } 85 | } 86 | } 87 | order->reserve(nodes.size()); 88 | std::vector node_states(nodes.size(), NODE_UNVISITED); 89 | for (size_t i = 0; i < (size_t) nodes.size(); ++i) 90 | { 91 | if (!get_post_order(i, nodes, node_map, &node_states, order)) 92 | { 93 | return false; 94 | } 95 | } 96 | return true; 97 | } 98 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/trt_utils.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "Status.hpp" 8 | #include "TensorOrWeights.hpp" 9 | #include "onnx2trt.hpp" 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | namespace onnx2trt 17 | { 18 | 19 | inline int getDtypeSize(nvinfer1::DataType trtDtype) 20 | { 21 | switch (trtDtype) 22 | { 23 | case nvinfer1::DataType::kFLOAT: return 4; 24 | case nvinfer1::DataType::kINT8: return 1; 25 | case nvinfer1::DataType::kHALF: return 2; 26 | case nvinfer1::DataType::kINT32: 27 | return 4; 28 | // TRT does not support booleans as a native type, so we treat them like int32 values. 
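// (hence the 4-byte answer below for kBOOL, matching the int32-sized storage the importer uses for bools)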
29 | case nvinfer1::DataType::kBOOL: 30 | return 4; 31 | // TODO: Some sort of error handling 32 | default: return -1; 33 | } 34 | } 35 | 36 | inline nvinfer1::Dims insert_dim(nvinfer1::Dims const& dims, int idx, int value) 37 | { 38 | assert(idx < dims.nbDims + 1); 39 | nvinfer1::Dims new_dims; 40 | new_dims.nbDims = dims.nbDims + 1; 41 | for (int i = 0; i < idx; ++i) 42 | { 43 | new_dims.d[i] = dims.d[i]; 44 | } 45 | new_dims.d[idx] = value; 46 | for (int i = idx + 1; i < new_dims.nbDims; ++i) 47 | { 48 | new_dims.d[i] = dims.d[i - 1]; 49 | } 50 | return new_dims; 51 | } 52 | 53 | inline nvinfer1::Dims remove_dim(nvinfer1::Dims const& dims, int idx) 54 | { 55 | assert(idx < dims.nbDims); 56 | nvinfer1::Dims new_dims; 57 | new_dims.nbDims = dims.nbDims - 1; 58 | for (int i = 0; i < idx; ++i) 59 | { 60 | new_dims.d[i] = dims.d[i]; 61 | } 62 | for (int i = idx; i < new_dims.nbDims; ++i) 63 | { 64 | new_dims.d[i] = dims.d[i + 1]; 65 | } 66 | // Special case for scalar result (i.e., there was only one dim originally) 67 | if (new_dims.nbDims == 0) 68 | { 69 | new_dims.nbDims = 1; 70 | new_dims.d[0] = 1; 71 | } 72 | return new_dims; 73 | } 74 | 75 | // Adds unitary dimensions on the left 76 | inline nvinfer1::Dims expand_dims(nvinfer1::Dims const& dims, int ndim_new) 77 | { 78 | assert(dims.nbDims <= ndim_new); 79 | nvinfer1::Dims new_dims; 80 | new_dims.nbDims = ndim_new; 81 | int j = 0; 82 | for (; j < ndim_new - dims.nbDims; ++j) 83 | { 84 | new_dims.d[j] = 1; 85 | } 86 | for (int i = 0; i < dims.nbDims; ++i, ++j) 87 | { 88 | new_dims.d[j] = dims.d[i]; 89 | } 90 | return new_dims; 91 | } 92 | 93 | inline nvinfer1::Permutation remove_first_dim(nvinfer1::Permutation const& perm) 94 | { 95 | assert(perm.order[0] == 0); 96 | nvinfer1::Permutation new_perm; 97 | int ndim = nvinfer1::Dims::MAX_DIMS; 98 | for (int i = 0; i < ndim - 1; ++i) 99 | { 100 | new_perm.order[i] = perm.order[i + 1] - 1; 101 | } 102 | return new_perm; 103 | } 104 | 105 | inline nvinfer1::Dims squeeze_trailing_dims(nvinfer1::Dims const& dims) 106 | { 107 | nvinfer1::Dims new_dims = dims; 108 | // Note: TRT requires at least one dimension, so we don't squeeze [1]->[] 109 | while (new_dims.nbDims > 1 && new_dims.d[new_dims.nbDims - 1] == 1) 110 | { 111 | --new_dims.nbDims; 112 | } 113 | return new_dims; 114 | } 115 | 116 | inline nvinfer1::Dims squeeze_leading_dims(const nvinfer1::Dims& dims) 117 | { 118 | nvinfer1::Dims newDims; 119 | // Copy dims only if a non-1 has been seen already. 120 | bool non1Seen{false}; 121 | newDims.nbDims = std::copy_if(dims.d, dims.d + dims.nbDims, newDims.d, 122 | [&non1Seen](int x) { 123 | non1Seen = (x != 1) ? 
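// once a non-1 dimension has been seen, every later dimension is copied, e.g. {1,1,3,1,2} -> {3,1,2}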
true : non1Seen; 124 | return non1Seen; 125 | }) 126 | - newDims.d; 127 | return newDims; 128 | } 129 | 130 | inline nvinfer1::DimsHW operator-(nvinfer1::DimsHW dims) 131 | { 132 | return nvinfer1::DimsHW(-dims.h(), -dims.w()); 133 | } 134 | 135 | // Note: These are used for checking beg_padding == end_padding 136 | inline bool operator==(nvinfer1::Dims const& a, nvinfer1::Dims const& b) 137 | { 138 | if (a.nbDims != b.nbDims) 139 | { 140 | return false; 141 | } 142 | for (int i = 0; i < a.nbDims; ++i) 143 | { 144 | if (a.d[i] != b.d[i]) 145 | { 146 | return false; 147 | } 148 | } 149 | return true; 150 | } 151 | inline bool operator!=(nvinfer1::Dims const& a, nvinfer1::Dims const& b) 152 | { 153 | return !(a == b); 154 | } 155 | 156 | inline TensorOrWeights identity(IImporterContext* ctx, TensorOrWeights input) 157 | { 158 | if (input.is_weights()) 159 | { 160 | return input; 161 | } 162 | else 163 | { 164 | auto* layer = ctx->network()->addIdentity(input.tensor()); 165 | if (!layer) 166 | { 167 | return nullptr; 168 | } 169 | return layer->getOutput(0); 170 | } 171 | } 172 | 173 | inline ::onnx::TensorProto_DataType trtDataTypeToONNX(nvinfer1::DataType dt) 174 | { 175 | switch (dt) 176 | { 177 | case nvinfer1::DataType::kFLOAT: return ::onnx::TensorProto::FLOAT; 178 | case nvinfer1::DataType::kHALF: return ::onnx::TensorProto::FLOAT16; 179 | case nvinfer1::DataType::kINT32: return ::onnx::TensorProto::INT32; 180 | case nvinfer1::DataType::kINT8: return ::onnx::TensorProto::INT8; 181 | case nvinfer1::DataType::kBOOL: return ::onnx::TensorProto::BOOL; 182 | default: return ::onnx::TensorProto_DataType_UNDEFINED; 183 | } 184 | throw std::runtime_error{"Unreachable"}; 185 | } 186 | 187 | } // namespace onnx2trt 188 | -------------------------------------------------------------------------------- /onnx_parser/onnx_parser_8.x/utils.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | 9 | template 10 | using string_map = std::unordered_map; 11 | -------------------------------------------------------------------------------- /onnx_parser/readme.md: -------------------------------------------------------------------------------- 1 | # Onnx parser for 7.x/8.x 2 | - Origin Code 7.x: https://github.com/onnx/onnx-tensorrt/releases/tag/release%2F7.2.1 3 | - Origin Code 8.x: https://github.com/onnx/onnx-tensorrt/releases/tag/release%2F8.0 4 | 5 | # TensorRT 7.x support 6 | 1. Replace onnx_parser_for_7.x/onnx_parser to src/tensorRT/onnx_parser 7 | - `rm -rf src/tensorRT/onnx_parser` 8 | - `cp -r onnx_parser/onnx_parser_7.x src/tensorRT/onnx_parser` 9 | - or execute `bash onnx_parser/use_tensorrt_7.x.sh` 10 | 2. Configure Makefile/CMakeLists.txt path to TensorRT7.x 11 | 3. Execute `make yolo -j64` 12 | 13 | # TensorRT 8.x support 14 | 1. Replace onnx_parser_for_8.x/onnx_parser to src/tensorRT/onnx_parser 15 | - `rm -rf src/tensorRT/onnx_parser` 16 | - `cp -r onnx_parser/onnx_parser_8.x src/tensorRT/onnx_parser` 17 | - or execute `bash onnx_parser/use_tensorrt_8.x.sh` 18 | 2. Configure Makefile/CMakeLists.txt path to TensorRT8.x 19 | 3. 
--------------------------------------------------------------------------------
/onnx_parser/onnx_parser_8.x/utils.hpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-License-Identifier: Apache-2.0
3 | */
4 |
5 | #pragma once
6 |
7 | #include <unordered_map>
8 |
9 | template <typename T>
10 | using string_map = std::unordered_map<std::string, T>;
--------------------------------------------------------------------------------
/onnx_parser/readme.md:
--------------------------------------------------------------------------------
1 | # Onnx parser for 7.x/8.x
2 | - Original code 7.x: https://github.com/onnx/onnx-tensorrt/releases/tag/release%2F7.2.1
3 | - Original code 8.x: https://github.com/onnx/onnx-tensorrt/releases/tag/release%2F8.0
4 |
5 | # TensorRT 7.x support
6 | 1. Replace src/tensorRT/onnx_parser with onnx_parser/onnx_parser_7.x
7 | - `rm -rf src/tensorRT/onnx_parser`
8 | - `cp -r onnx_parser/onnx_parser_7.x src/tensorRT/onnx_parser`
9 | - or execute `bash onnx_parser/use_tensorrt_7.x.sh`
10 | 2. Configure the Makefile/CMakeLists.txt path to TensorRT 7.x
11 | 3. Execute `make yolo -j64`
12 |
13 | # TensorRT 8.x support
14 | 1. Replace src/tensorRT/onnx_parser with onnx_parser/onnx_parser_8.x
15 | - `rm -rf src/tensorRT/onnx_parser`
16 | - `cp -r onnx_parser/onnx_parser_8.x src/tensorRT/onnx_parser`
17 | - or execute `bash onnx_parser/use_tensorrt_8.x.sh`
18 | 2. Configure the Makefile/CMakeLists.txt path to TensorRT 8.x
19 | 3. Execute `make yolo -j64`
20 |
21 | # TensorRT versions below 7.x are not supported
--------------------------------------------------------------------------------
/onnx_parser/use_tensorrt_7.x.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo Remove src/tensorRT/onnx_parser
4 | rm -rf src/tensorRT/onnx_parser
5 |
6 | echo Copy [onnx_parser/onnx_parser_7.x] to [src/tensorRT/onnx_parser]
7 | cp -r onnx_parser/onnx_parser_7.x src/tensorRT/onnx_parser
8 |
9 | echo Configure your TensorRT path to 7.x
10 | echo After that, you can execute the command 'make yolo -j64'
--------------------------------------------------------------------------------
/onnx_parser/use_tensorrt_8.x.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo Remove src/tensorRT/onnx_parser
4 | rm -rf src/tensorRT/onnx_parser
5 |
6 | echo Copy [onnx_parser/onnx_parser_8.x] to [src/tensorRT/onnx_parser]
7 | cp -r onnx_parser/onnx_parser_8.x src/tensorRT/onnx_parser
8 |
9 | echo Configure your TensorRT path to 8.x
10 | echo After that, you can execute the command 'make yolo -j64'
--------------------------------------------------------------------------------
/src/application/app_demuxer.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include <ffhdd/ffmpeg_demuxer.hpp>
3 | #include <ffhdd/nalu.hpp>
4 | #include <common/ilogger.hpp>
5 |
6 | using namespace std;
7 |
8 | static void test_demuxer(){
9 |
10 | auto demuxer = FFHDDemuxer::create_ffmpeg_demuxer("exp/fall_video.mp4");
11 | if(demuxer == nullptr){
12 | INFOE("demuxer create failed");
13 | return;
14 | }
15 |
16 | INFO("demuxer create done.");
17 |
18 | uint8_t* packet_data = nullptr;
19 | int packet_size = 0;
20 | int64_t pts = 0;
21 |
22 | demuxer->get_extra_data(&packet_data, &packet_size);
23 |
24 | vector<uint8_t> extra_data(packet_size + 3);
25 | memcpy(extra_data.data() + 3, packet_data, packet_size);
26 |
27 | int ipacket = 0;
28 | auto frame_type = NALU::format_nalu_type(NALU::find_all_nalu_info(extra_data.data(), packet_size, 0));
29 | INFO("Extra Data size: %d, type: %s", packet_size, frame_type.c_str());
30 |
31 | do{
32 | demuxer->demux(&packet_data, &packet_size, &pts);
33 |
34 | frame_type = "Empty";
35 | if(packet_size > 0){
36 | frame_type = NALU::format_nalu_frame_type(NALU::find_all_nalu_info(packet_data, packet_size, 0));
37 | }
38 |
39 | INFO("Packet %d NALU size: %d, pts = %lld, type = %s",
40 | ipacket++,
41 | packet_size,
42 | pts,
43 | frame_type.c_str()
44 | );
45 |
46 | }while(packet_size > 0);
47 | }
48 |
49 | /*
50 | A GOP is one group of N frames,
51 | where N = I + B/P * M and M = N - 1 (one I frame plus M B/P frames).
52 |
53 | The GOP is the smallest unit of H.264. Once this is understood, it is easy to split an H.264 stream into segments; both storing and decoding them become straightforward.
54 | */
55 | int app_demuxer(){
56 |
57 | test_demuxer();
58 | //INFO("%s", NALU::slice_type_string(NALU::get_slice_type_from_slice_header(0x00D8E002)));
59 | return 0;
60 | }
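The GOP note above translates directly into code: since every GOP opens with a keyframe, the iskey_frame output of FFmpegDemuxer::demux (declared in src/ffhdd/ffmpeg_demuxer.hpp) is enough to segment a stream. A minimal sketch, assuming the same FFHDDemuxer API used above:

static void count_gops(const std::string& uri){
    auto demuxer = FFHDDemuxer::create_ffmpeg_demuxer(uri);
    if(demuxer == nullptr) return;

    uint8_t* data = nullptr; int size = 0;
    int64_t pts = 0; bool iskey = false;
    int igop = -1, nframe = 0;
    while(demuxer->demux(&data, &size, &pts, &iskey) && size > 0){
        if(iskey){                          // a keyframe starts a new GOP
            if(igop >= 0) INFO("GOP %d has %d frames", igop, nframe);
            igop++; nframe = 0;
        }
        nframe++;
    }
    if(igop >= 0) INFO("GOP %d has %d frames", igop, nframe);
}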
--------------------------------------------------------------------------------
/src/application/app_hard_decode.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include <ffhdd/cuvid_decoder.hpp>
3 | #include <ffhdd/ffmpeg_demuxer.hpp>
4 | #include <common/ilogger.hpp>
5 | #include <opencv2/opencv.hpp>
6 |
7 |
8 | using namespace std;
9 |
10 | static void test_hard_decode(){
11 |
12 | auto demuxer = FFHDDemuxer::create_ffmpeg_demuxer("exp/number100.mp4");
13 | if(demuxer == nullptr){
14 | INFOE("demuxer create failed");
15 | return;
16 | }
17 |
18 | auto decoder = FFHDDecoder::create_cuvid_decoder(
19 | false, FFHDDecoder::ffmpeg2NvCodecId(demuxer->get_video_codec()), -1, 0
20 | );
21 |
22 | if(decoder == nullptr){
23 | INFOE("decoder create failed");
24 | return;
25 | }
26 |
27 | uint8_t* packet_data = nullptr;
28 | int packet_size = 0;
29 | int64_t pts = 0;
30 |
31 | demuxer->get_extra_data(&packet_data, &packet_size);
32 | decoder->decode(packet_data, packet_size);
33 |
34 | iLogger::rmtree("imgs");
35 | iLogger::mkdir("imgs");
36 |
37 | do{
38 | demuxer->demux(&packet_data, &packet_size, &pts);
39 | int ndecoded_frame = decoder->decode(packet_data, packet_size, pts);
40 | for(int i = 0; i < ndecoded_frame; ++i){
41 | unsigned int frame_index = 0;
42 |
43 | /* The frame memory fetched from the decoder is in YUV-NV12 format and occupies [height * 1.5] * width bytes.
44 | So we construct a Mat of size (height * 1.5) x width,
45 | then let OpenCV convert YUV-NV12 to BGR; the converted image is a normal height x width CV_8UC3.
46 | */
47 | cv::Mat image(decoder->get_height() * 1.5, decoder->get_width(), CV_8U, decoder->get_frame(&pts, &frame_index));
48 | cv::cvtColor(image, image, cv::COLOR_YUV2BGR_NV12);
49 |
50 | frame_index = frame_index + 1;
51 | INFO("write imgs/img_%05d.jpg %dx%d", frame_index, image.cols, image.rows);
52 | cv::imwrite(cv::format("imgs/img_%05d.jpg", frame_index), image);
53 | }
54 | }while(packet_size > 0);
55 | }
56 |
57 | int app_hard_decode(){
58 |
59 | test_hard_decode();
60 | return 0;
61 | }
--------------------------------------------------------------------------------
/src/application/app_yolo/yolo.hpp:
--------------------------------------------------------------------------------
1 | #ifndef YOLO_HPP
2 | #define YOLO_HPP
3 |
4 | #include <vector>
5 | #include <memory>
6 | #include <future>
7 | #include <opencv2/opencv.hpp>
8 | #include <common/trt_tensor.hpp>
9 | #include <common/object_detector.hpp>
10 |
11 |
12 | /**
13 | * @brief Tuned for maximum performance.
14 | * Supports YoloX and YoloV5
15 | */
16 | namespace Yolo{
17 |
18 | using namespace std;
19 | using namespace ObjectDetector;
20 |
21 | enum class Type : int{
22 | V5 = 0,
23 | X = 1
24 | };
25 |
26 | enum class NMSMethod : int{
27 | CPU = 0, // General; use when estimating mAP
28 | FastGPU = 1 // Fast NMS with a small loss of accuracy in corner cases
29 | };
30 |
31 | enum class ImageType : int{
32 | CVMat = 0,
33 | GPUYUV = 1 // nv12
34 | };
35 |
36 | struct Image{
37 | ImageType type = ImageType::CVMat;
38 | cv::Mat cvmat;
39 |
40 | // GPU YUV image
41 | TRT::CUStream stream = nullptr;
42 | uint8_t* device_data = nullptr;
43 | int width = 0, height = 0;
44 | int device_id = 0;
45 |
46 | Image() = default;
47 | Image(const cv::Mat& cvmat):cvmat(cvmat), type(ImageType::CVMat){}
48 | Image(uint8_t* yuvdata_device, int width, int height, int device_id, TRT::CUStream stream)
49 | :device_data(yuvdata_device), width(width), height(height), device_id(device_id), stream(stream), type(ImageType::GPUYUV){}
50 |
51 | int get_width() const{return type == ImageType::CVMat ? cvmat.cols : width;}
52 | int get_height() const{return type == ImageType::CVMat ? cvmat.rows : height;}
53 | cv::Size get_size() const{return cv::Size(get_width(), get_height());}
54 | bool empty() const{return type == ImageType::CVMat ? cvmat.empty() : (device_data == nullptr || width < 1 || height < 1);}
55 | };
56 |
57 | void image_to_tensor(const cv::Mat& image, shared_ptr<TRT::Tensor>& tensor, Type type, int ibatch);
58 |
59 | class Infer{
60 | public:
61 | virtual shared_future<BoxArray> commit(const Image& image) = 0;
62 | virtual vector<shared_future<BoxArray>> commits(const vector<Image>& images) = 0;
63 | };
64 |
65 | shared_ptr<Infer> create_infer(
66 | const string& engine_file, Type type, int gpuid,
67 | float confidence_threshold=0.25f, float nms_threshold=0.5f,
68 | NMSMethod nms_method = NMSMethod::FastGPU, int max_objects = 1024,
69 | bool use_multi_preprocess_stream = false
70 | );
71 | const char* type_name(Type type);
72 |
73 | }; // namespace Yolo
74 |
75 | #endif // YOLO_HPP
--------------------------------------------------------------------------------
/src/application/app_yolo/yolo_decode.cu:
--------------------------------------------------------------------------------
1 |
2 |
3 | #include <common/cuda_tools.hpp>
4 |
5 | namespace Yolo{
6 |
7 | const int NUM_BOX_ELEMENT = 7; // left, top, right, bottom, confidence, class, keepflag
8 | static __device__ void affine_project(float* matrix, float x, float y, float* ox, float* oy){
9 | *ox = matrix[0] * x + matrix[1] * y + matrix[2];
10 | *oy = matrix[3] * x + matrix[4] * y + matrix[5];
11 | }
12 |
13 | static __global__ void decode_kernel(float* predict, int num_bboxes, int num_classes, float confidence_threshold, float* invert_affine_matrix, float* parray, int max_objects){
14 |
15 | int position = blockDim.x * blockIdx.x + threadIdx.x;
16 | if (position >= num_bboxes) return;
17 |
18 | float* pitem = predict + (5 + num_classes) * position;
19 | float objectness = pitem[4];
20 | if(objectness < confidence_threshold)
21 | return;
22 |
23 | float* class_confidence = pitem + 5;
24 | float confidence = *class_confidence++;
25 | int label = 0;
26 | for(int i = 1; i < num_classes; ++i, ++class_confidence){
27 | if(*class_confidence > confidence){
28 | confidence = *class_confidence;
29 | label = i;
30 | }
31 | }
32 |
33 | confidence *= objectness;
34 | if(confidence < confidence_threshold)
35 | return;
36 |
37 | int index = atomicAdd(parray, 1);
38 | if(index >= max_objects)
39 | return;
40 |
41 | float cx = *pitem++;
42 | float cy = *pitem++;
43 | float width = *pitem++;
44 | float height = *pitem++;
45 | float left = cx - width * 0.5f;
46 | float top = cy - height * 0.5f;
47 | float right = cx + width * 0.5f;
48 | float bottom = cy + height * 0.5f;
49 | affine_project(invert_affine_matrix, left, top, &left, &top);
50 | affine_project(invert_affine_matrix, right, bottom, &right, &bottom);
51 |
52 | float* pout_item = parray + 1 + index * NUM_BOX_ELEMENT;
53 | *pout_item++ = left;
54 | *pout_item++ = top;
55 | *pout_item++ = right;
56 | *pout_item++ = bottom;
57 | *pout_item++ = confidence;
58 | *pout_item++ = label;
59 | *pout_item++ = 1; // 1 = keep, 0 = ignore
60 | }
61 |
62 | static __device__ float box_iou(
63 | float aleft, float atop, float aright, float abottom,
64 | float bleft, float btop, float bright, float bbottom
65 | ){
66 |
67 | float cleft = max(aleft, bleft);
68 | float ctop = max(atop, btop);
69 | float cright = min(aright, bright);
70 | float cbottom = min(abottom, bbottom);
71 |
72 | float c_area = max(cright - cleft, 0.0f) * max(cbottom - ctop, 0.0f);
73 | if(c_area == 0.0f)
74 | return 0.0f;
75 |
76 | float a_area = max(0.0f, aright - aleft) * max(0.0f, abottom - atop);
77 | float b_area = max(0.0f, bright - bleft) * max(0.0f, bbottom - btop);
78 | return c_area / (a_area + b_area - c_area);
79 | }
80 |
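// Output layout written by decode_kernel above (and consumed by nms_kernel below):
//   parray[0]                          candidate count, incremented with atomicAdd
//   parray[1 + i * NUM_BOX_ELEMENT]    box i as {left, top, right, bottom, confidence, class, keepflag}
// Boxes beyond max_objects are dropped; nms_kernel only flips keepflag to 0, it never compacts the array.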
81 | static __global__ void nms_kernel(float* bboxes, int max_objects, float threshold){
82 |
83 | int position = (blockDim.x * blockIdx.x + threadIdx.x);
84 | int count = min((int)*bboxes, max_objects);
85 | if (position >= count)
86 | return;
87 |
88 | // left, top, right, bottom, confidence, class, keepflag
89 | float* pcurrent = bboxes + 1 + position * NUM_BOX_ELEMENT;
90 | for(int i = 0; i < count; ++i){
91 | float* pitem = bboxes + 1 + i * NUM_BOX_ELEMENT;
92 | if(i == position || pcurrent[5] != pitem[5]) continue;
93 |
94 | if(pitem[4] >= pcurrent[4]){
95 | if(pitem[4] == pcurrent[4] && i < position)
96 | continue;
97 |
98 | float iou = box_iou(
99 | pcurrent[0], pcurrent[1], pcurrent[2], pcurrent[3],
100 | pitem[0], pitem[1], pitem[2], pitem[3]
101 | );
102 |
103 | if(iou > threshold){
104 | pcurrent[6] = 0; // 1=keep, 0=ignore
105 | return;
106 | }
107 | }
108 | }
109 | }
110 |
111 | void decode_kernel_invoker(float* predict, int num_bboxes, int num_classes, float confidence_threshold, float nms_threshold, float* invert_affine_matrix, float* parray, int max_objects, cudaStream_t stream){
112 |
113 | auto grid = CUDATools::grid_dims(num_bboxes);
114 | auto block = CUDATools::block_dims(num_bboxes);
115 | checkCudaKernel(decode_kernel<<<grid, block, 0, stream>>>(predict, num_bboxes, num_classes, confidence_threshold, invert_affine_matrix, parray, max_objects));
116 | }
117 |
118 | void nms_kernel_invoker(float* parray, float nms_threshold, int max_objects, cudaStream_t stream){
119 |
120 | auto grid = CUDATools::grid_dims(max_objects);
121 | auto block = CUDATools::block_dims(max_objects);
122 | checkCudaKernel(nms_kernel<<<grid, block, 0, stream>>>(parray, max_objects, nms_threshold));
123 | }
124 | };
--------------------------------------------------------------------------------
/src/application/common/object_detector.hpp:
--------------------------------------------------------------------------------
1 | #ifndef OBJECT_DETECTOR_HPP
2 | #define OBJECT_DETECTOR_HPP
3 |
4 | #include <vector>
5 |
6 | namespace ObjectDetector{
7 |
8 | struct Box{
9 | float left, top, right, bottom, confidence;
10 | int class_label;
11 |
12 | Box() = default;
13 |
14 | Box(float left, float top, float right, float bottom, float confidence, int class_label)
15 | :left(left), top(top), right(right), bottom(bottom), confidence(confidence), class_label(class_label){}
16 | };
17 |
18 | typedef std::vector<Box> BoxArray;
19 | };
20 |
21 |
22 | #endif // OBJECT_DETECTOR_HPP
--------------------------------------------------------------------------------
/src/application/tools/auto_download.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include <common/ilogger.hpp>
3 | #include <stdlib.h>
4 |
5 | using namespace std;
6 |
7 | bool requires(const char* name) {
8 |
9 | auto onnx_file = iLogger::format("%s.onnx", name);
10 | if (not iLogger::exists(onnx_file)) {
11 | INFO("Auto download %s", onnx_file.c_str());
12 | system(iLogger::format("wget http://zifuture.com:1556/fs/25.shared/%s", onnx_file.c_str()).c_str());
13 | }
14 |
15 | bool exists = iLogger::exists(onnx_file);
16 | if (not exists) {
17 | INFOE("Download %s failed", onnx_file.c_str());
18 | }
19 | return exists;
20 | }
--------------------------------------------------------------------------------
/src/application/tools/zmq_remote_show.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include "zmq_remote_show.hpp"
3 | #include "zmq_u.hpp"
4 | #include <common/ilogger.hpp>
5 |
6 | using namespace std;
7 |
8 | class ZMQRemoteShowImpl : public ZMQRemoteShow{
9 | public:
10 | bool listen(const char* url){
11 | try{
12 | context_.reset(new zmq::context_t());
13 | socket_.reset(new zmq::socket_t(*context_.get(), zmq::socket_type::rep));
14 | socket_->bind(url);
15 | return true;
16 | }catch(zmq::error_t err){
17 | INFOE("ZMQ exception: %s", err.what());
18 | socket_.reset();
19 | context_.reset();
20 | }
21 | return false;
22 | }
23 |
24 | virtual void post(const void* data, int size) override{
25 |
26 | if(size < 1 || data == nullptr){
27 | INFOE("Null data to post");
28 | return;
29 | }
30 |
31 | zmq::message_t msg;
32 | socket_->recv(msg);
33 | socket_->send(zmq::message_t(data, size));
34 | }
35 |
36 | virtual void post(const cv::Mat& image) override{
37 |
38 | vector<uint8_t> data;
39 | cv::imencode(".jpg", image, data);
40 | post(data.data(), data.size());
41 | }
42 |
43 | private:
44 | shared_ptr<zmq::context_t> context_;
45 | shared_ptr<zmq::socket_t> socket_;
46 | };
47 |
48 | std::shared_ptr<ZMQRemoteShow> create_zmq_remote_show(const char* listen){
49 |
50 | shared_ptr<ZMQRemoteShowImpl> instance(new ZMQRemoteShowImpl());
51 | if(!instance->listen(listen)){
52 | instance.reset();
53 | }
54 | return instance;
55 | }
56 |
--------------------------------------------------------------------------------
/src/application/tools/zmq_remote_show.hpp:
--------------------------------------------------------------------------------
1 |
2 |
3 | #ifndef ZMQ_REMOTE_SHOW_HPP
4 | #define ZMQ_REMOTE_SHOW_HPP
5 |
6 | #include <memory>
7 | #include <opencv2/opencv.hpp>
8 |
9 | class ZMQRemoteShow{
10 | public:
11 | virtual void post(const void* data, int size) = 0;
12 | virtual void post(const cv::Mat& image) = 0;
13 | };
14 |
15 | std::shared_ptr<ZMQRemoteShow> create_zmq_remote_show(const char* listen="tcp://0.0.0.0:15556");
16 |
17 | #endif // ZMQ_REMOTE_SHOW_HPP
--------------------------------------------------------------------------------
/src/ffhdd/cuvid_decoder.hpp:
--------------------------------------------------------------------------------
1 |
2 | #ifndef CUVID_DECODER_HPP
3 | #define CUVID_DECODER_HPP
4 |
5 | #include <stdint.h>
6 | #include <memory>
7 | // CUstream_st is forward-declared below so that cuda_runtime.h does not have to be included here
8 |
9 | struct CUstream_st;
10 |
11 | namespace FFHDDecoder{
12 |
13 | #define IcudaVideoCodec_H264 4
14 |
15 | typedef CUstream_st* ICUStream;
16 | typedef unsigned int IcudaVideoCodec;
17 |
18 | struct CropRect {
19 | int l, t, r, b;
20 | };
21 |
22 | struct ResizeDim {
23 | int w, h;
24 | };
25 |
26 | class CUVIDDecoder{
27 | public:
28 | virtual int get_frame_size() = 0;
29 | virtual int get_width() = 0;
30 | virtual int get_height() = 0;
31 | virtual unsigned int get_frame_index() = 0;
32 | virtual unsigned int get_num_decoded_frame() = 0;
33 | virtual uint8_t* get_frame(int64_t* pTimestamp = nullptr, unsigned int* pFrameIndex = nullptr) = 0;
34 | virtual int decode(const uint8_t *pData, int nSize, int64_t nTimestamp=0) = 0;
35 | virtual ICUStream get_stream() = 0;
36 | };
37 |
38 | IcudaVideoCodec ffmpeg2NvCodecId(int ffmpeg_codec_id);
39 |
40 | /* When max_cache is -1, the cache is unbounded and grows as needed; in practice it rarely exceeds 5 frames */
41 | // gpu_id = -1 means the current device
42 | std::shared_ptr<CUVIDDecoder> create_cuvid_decoder(
43 | bool use_device_frame, IcudaVideoCodec codec, int max_cache = -1, int gpu_id = -1,
44 | const CropRect *crop_rect = nullptr, const ResizeDim *resize_dim = nullptr
45 | );
46 | }; // FFHDDecoder
47 |
48 | #endif // CUVID_DECODER_HPP
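A minimal usage sketch for the factory above (illustrative values only; IcudaVideoCodec_H264, CropRect and ResizeDim come from this header):

static std::shared_ptr<FFHDDecoder::CUVIDDecoder> make_decoder_demo(){
    FFHDDecoder::CropRect crop{0, 0, 1280, 720};   // l, t, r, b
    FFHDDecoder::ResizeDim resize{640, 360};       // w, h
    return FFHDDecoder::create_cuvid_decoder(
        true,                   // use_device_frame: keep decoded frames on the GPU
        IcudaVideoCodec_H264,   // codec, e.g. from ffmpeg2NvCodecId(...)
        -1,                     // max_cache = -1: cache as needed (rarely over 5 frames)
        -1,                     // gpu_id = -1: use the current device
        &crop, &resize);        // optional GPU-side crop and resize
}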
--------------------------------------------------------------------------------
/src/ffhdd/ffmpeg_demuxer.hpp:
--------------------------------------------------------------------------------
1 |
2 | #ifndef FFMPEG_DEMUXER_HPP
3 | #define FFMPEG_DEMUXER_HPP
4 |
5 | #include <stdint.h>
6 | #include <memory>
7 | #include <string>
8 |
9 |
10 | namespace FFHDDemuxer{
11 |
12 | typedef int IAVCodecID;
13 | typedef int IAVPixelFormat;
14 |
15 | class DataProvider {
16 | public:
17 | virtual int get_data(uint8_t *pBuf, int nBuf) = 0;
18 | };
19 |
20 | class FFmpegDemuxer{
21 | public:
22 | virtual IAVCodecID get_video_codec() = 0;
23 | virtual IAVPixelFormat get_chroma_format() = 0;
24 | virtual int get_width() = 0;
25 | virtual int get_height() = 0;
26 | virtual int get_bit_depth() = 0;
27 | virtual int get_fps() = 0;
28 | virtual int get_total_frames() = 0;
29 | virtual void get_extra_data(uint8_t **ppData, int *bytes) = 0;
30 | virtual bool isreboot() = 0;
31 | virtual void reset_reboot_flag() = 0;
32 | virtual bool demux(uint8_t **ppVideo, int *pnVideoBytes, int64_t *pts = nullptr, bool *iskey_frame = nullptr) = 0;
33 | virtual bool reopen() = 0;
34 | };
35 |
36 | std::shared_ptr<FFmpegDemuxer> create_ffmpeg_demuxer(const std::string& uri, bool auto_reboot = false);
37 | std::shared_ptr<FFmpegDemuxer> create_ffmpeg_demuxer(std::shared_ptr<DataProvider> provider);
38 | }; // namespace FFHDDemuxer
39 |
40 | #endif // FFMPEG_DEMUXER_HPP
--------------------------------------------------------------------------------
/src/main.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include <stdio.h>
3 | #include <string.h>
4 |
5 |
6 | int app_demuxer();
7 | int app_hard_decode();
8 | int app_yolo();
9 |
10 | int main(int argc, char** argv){
11 |
12 | const char* method = "yolo";
13 | if(argc > 1){
14 | method = argv[1];
15 | }
16 |
17 | if(strcmp(method, "demuxer") == 0){
18 | app_demuxer();
19 | }else if(strcmp(method, "hard_decode") == 0){
20 | app_hard_decode();
21 | }else if(strcmp(method, "yolo") == 0){
22 | app_yolo();
23 | }else{
24 | printf("Unknown method: %s\n", method);
25 | printf(
26 | "Help: \n"
27 | " ./pro demuxer\n"
28 | " ./pro hard_decode\n"
29 | " ./pro yolo\n"
30 | );
31 | }
32 | return 0;
33 | }
34 |
--------------------------------------------------------------------------------
/src/tensorRT/builder/trt_builder.hpp:
--------------------------------------------------------------------------------
1 |
2 |
3 | #ifndef TRT_BUILDER_HPP
4 | #define TRT_BUILDER_HPP
5 |
6 | #include <string>
7 | #include <vector>
8 | #include <functional>
9 | #include <common/trt_tensor.hpp>
10 |
11 | namespace TRT {
12 |
13 | typedef std::function<void(int current, int count, const std::vector<std::string>& files, std::shared_ptr<Tensor>& tensor)> Int8Process;
14 | typedef std::function<std::vector<int64_t>(const std::string& name, const std::vector<int64_t>& shape)> LayerHookFuncReshape;
15 |
16 | enum class ModelSourceType : int{
17 | OnnX,
18 | OnnXData
19 | };
20 |
21 | class ModelSource {
22 | public:
23 | ModelSource() = default;
24 | ModelSource(const std::string& onnxmodel);
25 | ModelSource(const char* onnxmodel);
26 | ModelSourceType type() const;
27 | std::string onnxmodel() const;
28 | std::string descript() const;
29 | const void* onnx_data() const;
30 | size_t onnx_data_size() const;
31 |
32 | static ModelSource onnx(const std::string& file){
33 | ModelSource output;
34 | output.onnxmodel_ = file;
35 | output.type_ = ModelSourceType::OnnX;
36 | return output;
37 | }
38 |
39 | static ModelSource onnx_data(const void* ptr, size_t size){
40 | ModelSource output;
41 | output.onnx_data_ = ptr;
42 | output.onnx_data_size_ = size;
43 | output.type_ = ModelSourceType::OnnXData;
44 | return output;
45 | }
46 |
47 | private:
48 | std::string onnxmodel_;
49 | const void* onnx_data_ = nullptr;
50 | size_t onnx_data_size_ = 0;
51 | ModelSourceType type_;
52 | };
53 |
54 | enum class CompileOutputType : int{
55 | File,
56 | Memory
57 | };
58 |
59 | class CompileOutput{
60 | public:
61 | CompileOutput(CompileOutputType type = CompileOutputType::Memory);
62 | CompileOutput(const std::string& file);
63 | CompileOutput(const char* file);
64 | void set_data(const std::vector<uint8_t>& data);
65 | void set_data(std::vector<uint8_t>&& data);
66 |
67 | const std::vector<uint8_t>& data() const{return data_;};
68 | CompileOutputType type() const{return type_;}
69 | std::string file() const{return file_;}
70 |
71 | private:
72 | CompileOutputType type_ = CompileOutputType::Memory;
73 | std::vector<uint8_t> data_;
74 | std::string file_;
75 | };
76 |
77 | class InputDims {
78 | public:
79 | InputDims() = default;
80 |
81 | // A value of -1 keeps the corresponding dimension of the network as it was at import time
82 | InputDims(const std::initializer_list<int>& dims);
83 | InputDims(const std::vector<int>& dims);
84 |
85 | const std::vector<int>& dims() const;
86 |
87 | private:
88 | std::vector<int> dims_;
89 | };
90 |
91 | enum class Mode : int {
92 | FP32,
93 | FP16,
94 | INT8
95 | };
96 |
97 | const char* mode_string(Mode type);
98 |
99 | void set_layer_hook_reshape(const LayerHookFuncReshape& func);
100 |
101 | /** In INT8 mode, int8process must be provided, and only one of int8ImageDirectory
102 | // and int8EntropyCalibratorFile needs to be specified.
103 | // If int8EntropyCalibratorFile is specified on the first build, the calibrator is saved to that file.
104 | // If it has already been generated and int8EntropyCalibratorFile is specified, the calibrator is loaded
105 | // from that file instead of reading the images from int8ImageDirectory and regenerating it.
106 | // In FP32 or FP16 mode, int8process, int8ImageDirectory and int8EntropyCalibratorFile are all unnecessary **/
107 | bool compile(
108 | Mode mode,
109 | unsigned int maxBatchSize,
110 | const ModelSource& source,
111 | const CompileOutput& saveto,
112 | const std::vector<InputDims> inputsDimsSetup = {},
113 | Int8Process int8process = nullptr,
114 | const std::string& int8ImageDirectory = "",
115 | const std::string& int8EntropyCalibratorFile = "");
116 | };
117 |
118 | #endif //TRT_BUILDER_HPP
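A minimal sketch of driving compile() above for an FP16 build (the file names are placeholders):

static bool build_demo(){
    // In INT8 mode you would also pass int8process plus int8ImageDirectory
    // or int8EntropyCalibratorFile, as the comment above describes.
    return TRT::compile(
        TRT::Mode::FP16,                           // see Mode above
        16,                                        // maxBatchSize
        TRT::ModelSource::onnx("model.onnx"),      // or ModelSource::onnx_data(ptr, size)
        TRT::CompileOutput("model.FP16.trtmodel")  // a path = file output; default = in-memory
    );
}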
--------------------------------------------------------------------------------
/src/tensorRT/common/cuda_tools.cpp:
--------------------------------------------------------------------------------
1 |
2 | /*
3 | * CUDA utility functions used across the framework
4 | */
5 |
6 |
7 | #include "cuda_tools.hpp"
8 |
9 | namespace CUDATools{
10 | bool check_driver(CUresult e, const char* call, int line, const char *file) {
11 | if (e != CUDA_SUCCESS) {
12 |
13 | const char* message = nullptr;
14 | const char* name = nullptr;
15 | cuGetErrorString(e, &message);
16 | cuGetErrorName(e, &name);
17 | INFOE("CUDA Driver error %s # %s, code = %s [ %d ] in file %s:%d", call, message, name, e, file, line);
18 | return false;
19 | }
20 | return true;
21 | }
22 |
23 | bool check_runtime(cudaError_t e, const char* call, int line, const char *file){
24 | if (e != cudaSuccess) {
25 | INFOE("CUDA Runtime error %s # %s, code = %s [ %d ] in file %s:%d", call, cudaGetErrorString(e), cudaGetErrorName(e), e, file, line);
26 | return false;
27 | }
28 | return true;
29 | }
30 |
31 | bool check_device_id(int device_id){
32 | int device_count = -1;
33 | checkCudaRuntime(cudaGetDeviceCount(&device_count));
34 | if(device_id < 0 || device_id >= device_count){
35 | INFOE("Invalid device id: %d, count = %d", device_id, device_count);
36 | return false;
37 | }
38 | return true;
39 | }
40 |
41 | dim3 grid_dims(int numJobs) {
42 | int numBlockThreads = numJobs < GPU_BLOCK_THREADS ? numJobs : GPU_BLOCK_THREADS;
43 | return dim3(((numJobs + numBlockThreads - 1) / (float)numBlockThreads));
44 | }
45 |
46 | dim3 block_dims(int numJobs) {
47 | return numJobs < GPU_BLOCK_THREADS ? numJobs : GPU_BLOCK_THREADS;
48 | }
49 |
50 | std::string device_capability(int device_id){
51 | cudaDeviceProp prop;
52 | checkCudaRuntime(cudaGetDeviceProperties(&prop, device_id));
53 | return iLogger::format("%d.%d", prop.major, prop.minor);
54 | }
55 |
56 | AutoDevice::AutoDevice(int device_id){
57 |
58 | cudaGetDevice(&old_);
59 | if(old_ != device_id && device_id != -1){
60 | checkCudaRuntime(cudaSetDevice(device_id));
61 | return;
62 | }
63 |
64 | CUcontext context = nullptr;
65 | cuCtxGetCurrent(&context);
66 | if(context == nullptr){
67 | checkCudaRuntime(cudaSetDevice(device_id));
68 | return;
69 | }
70 | }
71 |
72 | AutoDevice::~AutoDevice(){
73 | if(old_ != -1){
74 | checkCudaRuntime(cudaSetDevice(old_));
75 | }
76 | }
77 | }
--------------------------------------------------------------------------------
/src/tensorRT/common/cuda_tools.hpp:
--------------------------------------------------------------------------------
1 | #ifndef CUDA_TOOLS_HPP
2 | #define CUDA_TOOLS_HPP
3 |
4 |
5 | /*
6 | * CUDA utility functions used across the framework
7 | */
8 |
9 | #include <cuda.h>
10 | #include <cuda_runtime.h>
11 | #include "ilogger.hpp"
12 |
13 | #define GPU_BLOCK_THREADS 512
14 |
15 |
16 | #define KernelPositionBlock \
17 | int position = (blockDim.x * blockIdx.x + threadIdx.x); \
18 | if (position >= (edge)) return;
19 |
20 |
21 | #define checkCudaDriver(call) CUDATools::check_driver(call, #call, __LINE__, __FILE__)
22 | #define checkCudaRuntime(call) CUDATools::check_runtime(call, #call, __LINE__, __FILE__)
23 |
24 | #define checkCudaKernel(...) \
25 | __VA_ARGS__; \
26 | do{cudaError_t cudaStatus = cudaPeekAtLastError(); \
27 | if (cudaStatus != cudaSuccess){ \
28 | INFOE("launch failed: %s", cudaGetErrorString(cudaStatus)); \
29 | }} while(0);
30 |
31 |
32 | #define Assert(op) \
33 | do{ \
34 | bool cond = !(!(op)); \
35 | if(!cond){ \
36 | INFOF("Assert failed, " #op); \
37 | } \
38 | }while(false)
39 |
40 |
41 | struct CUctx_st;
42 | struct CUstream_st;
43 |
44 | typedef CUstream_st* ICUStream;
45 | typedef CUctx_st* ICUContext;
46 | typedef void* ICUDeviceptr;
47 | typedef int DeviceID;
48 |
49 | namespace CUDATools{
50 | bool check_driver(CUresult e, const char* call, int iLine, const char *szFile);
51 | bool check_runtime(cudaError_t e, const char* call, int iLine, const char *szFile);
52 | bool check_device_id(int device_id);
53 |
54 | dim3 grid_dims(int numJobs);
55 | dim3 block_dims(int numJobs);
56 |
57 | // return 8.6 etc.
58 | std::string device_capability(int device_id);
59 |
60 | class AutoDevice{
61 | public:
62 | AutoDevice(int device_id = 0);
63 | virtual ~AutoDevice();
64 |
65 | private:
66 | int old_ = -1;
67 | };
68 | }
69 |
70 |
71 | #endif // CUDA_TOOLS_HPP
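A minimal sketch of the launch pattern these helpers are designed for (the kernel itself is illustrative):

__global__ void scale_kernel(float* data, float alpha, int edge){
    KernelPositionBlock;        // expands to the position/bounds-check boilerplate above
    data[position] *= alpha;
}

static void scale_invoker(float* data, float alpha, int num, cudaStream_t stream){
    auto grid  = CUDATools::grid_dims(num);     // ceil(num / GPU_BLOCK_THREADS) blocks
    auto block = CUDATools::block_dims(num);    // up to GPU_BLOCK_THREADS threads per block
    checkCudaKernel(scale_kernel<<<grid, block, 0, stream>>>(data, alpha, num));
}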
--------------------------------------------------------------------------------
/src/tensorRT/common/ilogger.hpp:
--------------------------------------------------------------------------------
1 |
2 | #ifndef ILOGGER_HPP
3 | #define ILOGGER_HPP
4 |
5 |
6 | #include <stdio.h>
7 | #include <stdint.h>
8 | #include <string>
9 | #include <vector>
10 | #include <tuple>
11 |
12 |
13 | #if defined(_WIN32)
14 | # define U_OS_WINDOWS
15 | #else
16 | # define U_OS_LINUX
17 | #endif
18 |
19 |
20 | namespace iLogger{
21 |
22 | using namespace std;
23 |
24 | enum class LogLevel : int{
25 | Debug = 5,
26 | Verbose = 4,
27 | Info = 3,
28 | Warning = 2,
29 | Error = 1,
30 | Fatal = 0
31 | };
32 |
33 | #define INFOD(...) iLogger::__log_func(__FILE__, __LINE__, iLogger::LogLevel::Debug, __VA_ARGS__)
34 | #define INFOV(...) iLogger::__log_func(__FILE__, __LINE__, iLogger::LogLevel::Verbose, __VA_ARGS__)
35 | #define INFO(...) iLogger::__log_func(__FILE__, __LINE__, iLogger::LogLevel::Info, __VA_ARGS__)
36 | #define INFOW(...) iLogger::__log_func(__FILE__, __LINE__, iLogger::LogLevel::Warning, __VA_ARGS__)
37 | #define INFOE(...) iLogger::__log_func(__FILE__, __LINE__, iLogger::LogLevel::Error, __VA_ARGS__)
38 | #define INFOF(...) iLogger::__log_func(__FILE__, __LINE__, iLogger::LogLevel::Fatal, __VA_ARGS__)
39 |
40 | string date_now();
41 | string time_now();
42 | string gmtime_now();
43 | string gmtime(time_t t);
44 | time_t gmtime2ctime(const string& gmt);
45 | void sleep(int ms);
46 |
47 | bool isfile(const string& file);
48 | bool mkdir(const string& path);
49 | bool mkdirs(const string& path);
50 | bool delete_file(const string& path);
51 | bool rmtree(const string& directory, bool ignore_fail=false);
52 | bool exists(const string& path);
53 | string format(const char* fmt, ...);
54 | FILE* fopen_mkdirs(const string& path, const string& mode);
55 | string file_name(const string& path, bool include_suffix=true);
56 | string directory(const string& path);
57 | long long timestamp_now();
58 | double timestamp_now_float();
59 | time_t last_modify(const string& file);
60 | vector<uint8_t> load_file(const string& file);
61 | string load_text_file(const string& file);
62 | size_t file_size(const string& file);
63 |
64 | bool begin_with(const string& str, const string& with);
65 | bool end_with(const string& str, const string& with);
66 | vector<string> split_string(const string& str, const std::string& spstr);
67 | string replace_string(const string& str, const string& token, const string& value, int nreplace=-1, int* out_num_replace=nullptr);
68 |
69 | // h[0-1], s[0-1], v[0-1]
70 | // return, 0-255, 0-255, 0-255
71 | tuple<uint8_t, uint8_t, uint8_t> hsv2rgb(float h, float s, float v);
72 | tuple<uint8_t, uint8_t, uint8_t> random_color(int id);
73 |
74 | // abcdefg.pnga *.png > false
75 | // abcdefg.png *.png > true
76 | // abcdefg.png a?cdefg.png > true
77 | bool pattern_match(const char* str, const char* matcher, bool igrnoe_case = true);
78 | vector<string> find_files(
79 | const string& directory,
80 | const string& filter = "*", bool findDirectory = false, bool includeSubDirectory = false);
81 |
82 | string align_blank(const string& input, int align_size, char blank=' ');
83 | bool save_file(const string& file, const vector<uint8_t>& data, bool mk_dirs = true);
84 | bool save_file(const string& file, const string& data, bool mk_dirs = true);
85 | bool save_file(const string& file, const void* data, size_t length, bool mk_dirs = true);
86 |
87 | // Catches SIGINT(2) and SIGQUIT(3)
88 | int while_loop();
89 |
90 | // Logger-related APIs
91 | const char* level_string(LogLevel level);
92 | void set_logger_save_directory(const string& loggerDirectory);
93 |
94 | void set_log_level(LogLevel level);
95 | LogLevel get_log_level();
96 | void __log_func(const char* file, int line, LogLevel level, const char* fmt, ...);
97 | void destroy_logger();
98 |
99 | string base64_decode(const string& base64);
100 | string base64_encode(const void* data, size_t size);
101 |
102 | inline int upbound(int n, int align = 32){return (n + align - 1) / align * align;}
103 | string join_dims(const vector<int64_t>& dims);
104 | };
105 |
106 |
107 | #endif // ILOGGER_HPP
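A minimal usage sketch for the logging and filesystem helpers above:

static void ilogger_demo(){
    iLogger::set_log_level(iLogger::LogLevel::Info);
    if(!iLogger::exists("imgs"))
        iLogger::mkdirs("imgs");
    auto path = iLogger::format("imgs/img_%05d.jpg", 1);
    INFO("will save to %s", path.c_str());                // Info level, printf-style
    INFOE("the macro family records __FILE__:__LINE__");  // Error level
}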
#include "monopoly_allocator.hpp" 13 | 14 | template, class JobAdditional=int> 15 | class InferController{ 16 | public: 17 | struct Job{ 18 | Input input; 19 | Output output; 20 | JobAdditional additional; 21 | MonopolyAllocator::MonopolyDataPointer mono_tensor; 22 | std::shared_ptr> pro; 23 | }; 24 | 25 | virtual ~InferController(){ 26 | stop(); 27 | } 28 | 29 | void stop(){ 30 | run_ = false; 31 | cond_.notify_all(); 32 | 33 | ////////////////////////////////////////// cleanup jobs 34 | { 35 | std::unique_lock l(jobs_lock_); 36 | while(!jobs_.empty()){ 37 | auto& item = jobs_.front(); 38 | if(item.pro) 39 | item.pro->set_value(Output()); 40 | jobs_.pop(); 41 | } 42 | }; 43 | 44 | if(worker_){ 45 | worker_->join(); 46 | worker_.reset(); 47 | } 48 | } 49 | 50 | bool startup(const StartParam& param){ 51 | run_ = true; 52 | 53 | std::promise pro; 54 | start_param_ = param; 55 | worker_ = std::make_shared(&InferController::worker, this, std::ref(pro)); 56 | return pro.get_future().get(); 57 | } 58 | 59 | virtual std::shared_future commit(const Input& input){ 60 | 61 | Job job; 62 | job.pro = std::make_shared>(); 63 | if(!preprocess(job, input)){ 64 | job.pro->set_value(Output()); 65 | return job.pro->get_future(); 66 | } 67 | 68 | /////////////////////////////////////////////////////////// 69 | { 70 | std::unique_lock l(jobs_lock_); 71 | jobs_.push(job); 72 | }; 73 | cond_.notify_one(); 74 | return job.pro->get_future(); 75 | } 76 | 77 | virtual std::vector> commits(const std::vector& inputs){ 78 | 79 | int batch_size = std::min((int)inputs.size(), this->tensor_allocator_->capacity()); 80 | std::vector jobs(inputs.size()); 81 | std::vector> results(inputs.size()); 82 | 83 | int nepoch = (inputs.size() + batch_size - 1) / batch_size; 84 | for(int epoch = 0; epoch < nepoch; ++epoch){ 85 | int begin = epoch * batch_size; 86 | int end = std::min((int)inputs.size(), begin + batch_size); 87 | 88 | for(int i = begin; i < end; ++i){ 89 | Job& job = jobs[i]; 90 | job.pro = std::make_shared>(); 91 | if(!preprocess(job, inputs[i])){ 92 | job.pro->set_value(Output()); 93 | } 94 | results[i] = job.pro->get_future(); 95 | } 96 | 97 | /////////////////////////////////////////////////////////// 98 | { 99 | std::unique_lock l(jobs_lock_); 100 | for(int i = begin; i < end; ++i){ 101 | jobs_.emplace(std::move(jobs[i])); 102 | }; 103 | } 104 | cond_.notify_one(); 105 | } 106 | return results; 107 | } 108 | 109 | protected: 110 | virtual void worker(std::promise& result) = 0; 111 | virtual bool preprocess(Job& job, const Input& input) = 0; 112 | 113 | virtual bool get_jobs_and_wait(std::vector& fetch_jobs, int max_size){ 114 | 115 | std::unique_lock l(jobs_lock_); 116 | cond_.wait(l, [&](){ 117 | return !run_ || !jobs_.empty(); 118 | }); 119 | 120 | if(!run_) return false; 121 | 122 | fetch_jobs.clear(); 123 | for(int i = 0; i < max_size && !jobs_.empty(); ++i){ 124 | fetch_jobs.emplace_back(std::move(jobs_.front())); 125 | jobs_.pop(); 126 | } 127 | return true; 128 | } 129 | 130 | virtual bool get_job_and_wait(Job& fetch_job){ 131 | 132 | std::unique_lock l(jobs_lock_); 133 | cond_.wait(l, [&](){ 134 | return !run_ || !jobs_.empty(); 135 | }); 136 | 137 | if(!run_) return false; 138 | 139 | fetch_job = std::move(jobs_.front()); 140 | jobs_.pop(); 141 | return true; 142 | } 143 | 144 | protected: 145 | StartParam start_param_; 146 | std::atomic run_; 147 | std::mutex jobs_lock_; 148 | std::queue jobs_; 149 | std::shared_ptr worker_; 150 | std::condition_variable cond_; 151 | std::shared_ptr> tensor_allocator_; 
--------------------------------------------------------------------------------
/src/tensorRT/common/monopoly_allocator.hpp:
--------------------------------------------------------------------------------
1 | /**
2 | * Monopoly allocator: an exclusive-ownership allocator
3 | * It solves the following problems:
4 | * 1. Tensor reuse
5 | * 2. Running the two phases of tensor usage in parallel so that their time overlaps:
6 | *    phase one: preprocessing
7 | *    phase two: model inference
8 | *
9 | * Design idea:
10 | * Take eating hotpot at Haidilao as an analogy. There are two kinds of seats: dining seats inside the hall and waiting seats outside.
11 | *
12 | * 1. Initially there are 10 seats inside and 10 seats outside, all empty
13 | * 2. 30 people arrive to eat hotpot
14 | * 3. First, 10 people take the waiting seats outside while the other 20 queue up
15 | * 4. Since the hall is empty, the 10 people waiting outside are ushered inside and start eating; the 10 outside seats become empty
16 | * 5. Since nobody is waiting outside now, 10 of the 20 queued people can take the outside seats
17 | * 6. The state is now: 10 people inside, 10 outside, 10 queueing
18 | * 7. After 60 minutes the 10 people inside finish eating, and step 4 runs again
19 | *
20 | * In real workloads, image input usually goes through preprocessing and inference.
21 | * Our goal is to overlap preprocessing time with inference time, so a buffer is designed, much like the outside waiting area:
22 | * when images come in, there is 2x-batch worth of space to cache preprocessed data,
23 | * while the engine takes 1 batch of data per inference.
24 | * When the engine is slow and preprocessing is fast, incoming images inevitably have to wait; otherwise the cache queue would keep growing.
25 | * These points are the main goals of this design.
26 | **/
27 |
28 | #ifndef MONOPOLY_ALLOCATOR_HPP
29 | #define MONOPOLY_ALLOCATOR_HPP
30 |
31 | #include <memory>
32 | #include <vector>
33 | #include <mutex>
34 | #include <chrono>
35 | #include <algorithm>
36 | #include <condition_variable>
37 |
38 | template<class _ItemType>
39 | class MonopolyAllocator{
40 | public:
41 | /* MonopolyData is the data container class.
42 | An item obtained from query() can call item->release() to give up its ownership so the object can be reused.
43 | Use item->data() to get a pointer to the stored object.
44 | */
45 | class MonopolyData{
46 | public:
47 | std::shared_ptr<_ItemType>& data(){ return data_; }
48 | void release(){manager_->release_one(this);}
49 |
50 | private:
51 | MonopolyData(MonopolyAllocator* pmanager){manager_ = pmanager;}
52 |
53 | private:
54 | friend class MonopolyAllocator;
55 | MonopolyAllocator* manager_ = nullptr;
56 | std::shared_ptr<_ItemType> data_;
57 | bool available_ = true;
58 | };
59 | typedef std::shared_ptr<MonopolyData> MonopolyDataPointer;
60 |
61 | MonopolyAllocator(int size){
62 | capacity_ = size;
63 | num_available_ = size;
64 | datas_.resize(size);
65 |
66 | for(int i = 0; i < size; ++i)
67 | datas_[i] = std::shared_ptr<MonopolyData>(new MonopolyData(this));
68 | }
69 |
70 | virtual ~MonopolyAllocator(){
71 | run_ = false;
72 | cv_.notify_all();
73 |
74 | std::unique_lock<std::mutex> l(lock_);
75 | cv_exit_.wait(l, [&](){
76 | return num_wait_thread_ == 0;
77 | });
78 | }
79 |
80 | /* Get an available object.
81 | timeout: if no object is available, the call blocks and waits; nullptr is returned once the wait times out.
82 | After an object is obtained it stays occupied until release() is executed to give up its ownership.
83 | */
84 | MonopolyDataPointer query(int timeout = 10000){
85 |
86 | std::unique_lock<std::mutex> l(lock_);
87 | if(!run_) return nullptr;
88 |
89 | if(num_available_ == 0){
90 | num_wait_thread_++;
91 |
92 | auto state = cv_.wait_for(l, std::chrono::milliseconds(timeout), [&](){
93 | return num_available_ > 0 || !run_;
94 | });
95 |
96 | num_wait_thread_--;
97 | cv_exit_.notify_one();
98 |
99 | // timeout, no available, exit program
100 | if(!state || num_available_ == 0 || !run_)
101 | return nullptr;
102 | }
103 |
104 | auto item = std::find_if(datas_.begin(), datas_.end(), [](MonopolyDataPointer& item){return item->available_;});
105 | if(item == datas_.end())
106 | return nullptr;
107 |
108 | (*item)->available_ = false;
109 | num_available_--;
110 | return *item;
111 | }
112 |
113 | int num_available(){
114 | return num_available_;
115 | }
116 |
117 | int capacity(){
118 | return capacity_;
119 | }
120 |
121 | private:
122 | void release_one(MonopolyData* prq){
123 | std::unique_lock<std::mutex> l(lock_);
124 | if(!prq->available_){
125 | prq->available_ = true;
126 | num_available_++;
127 | cv_.notify_one();
128 | }
129 | }
130 |
131 | private:
132 | std::mutex lock_;
133 | std::condition_variable cv_;
134 | std::condition_variable cv_exit_;
135 | std::vector<MonopolyDataPointer> datas_;
136 | int capacity_ = 0;
137 | volatile int num_available_ = 0;
138 | volatile int num_wait_thread_ = 0;
139 | volatile bool run_ = true;
140 | };
141 |
142 | #endif // MONOPOLY_ALLOCATOR_HPP
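A minimal usage sketch for MonopolyAllocator (illustrative only; int is used as the item type for brevity):

#include <cstdio>

static void monopoly_demo(){
    MonopolyAllocator<int> pool(2);            // two exclusive slots
    auto a = pool.query();                     // takes slot 1
    auto b = pool.query();                     // takes slot 2
    a->data() = std::make_shared<int>(123);    // fill the owned item
    // another pool.query() here would block up to its timeout: 0 of 2 slots free
    printf("available = %d / %d\n", pool.num_available(), pool.capacity());
    a->release();                              // the slot becomes reusable
    auto c = pool.query();                     // succeeds immediately
}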
--------------------------------------------------------------------------------
/src/tensorRT/common/preprocess_kernel.cuh:
--------------------------------------------------------------------------------
1 | #ifndef PREPROCESS_KERNEL_CUH
2 | #define PREPROCESS_KERNEL_CUH
3 |
4 | #include "cuda_tools.hpp"
5 |
6 | namespace CUDAKernel{
7 |
8 | enum class NormType : int{
9 | None = 0,
10 | MeanStd = 1,
11 | AlphaBeta = 2
12 | };
13 |
14 | enum class ChannelType : int{
15 | None = 0,
16 | Invert = 1
17 | };
18 |
19 | struct Norm{
20 | float mean[3];
21 | float std[3];
22 | float alpha, beta;
23 | NormType type = NormType::None;
24 | ChannelType channel_type = ChannelType::None;
25 |
26 | // out = (x * alpha - mean) / std
27 | static Norm mean_std(const float mean[3], const float std[3], float alpha = 1/255.0f, ChannelType channel_type=ChannelType::None);
28 |
29 | // out = x * alpha + beta
30 | static Norm alpha_beta(float alpha, float beta = 0, ChannelType channel_type=ChannelType::None);
31 |
32 | // None
33 | static Norm None();
34 | };
35 |
36 | void resize_bilinear_and_normalize(
37 | uint8_t* src, int src_line_size, int src_width, int src_height, float* dst, int dst_width, int dst_height,
38 | const Norm& norm,
39 | cudaStream_t stream);
40 |
41 | void warp_affine_bilinear_and_normalize_plane(
42 | uint8_t* src, int src_line_size, int src_width, int src_height,
43 | float* dst , int dst_width, int dst_height,
44 | float* matrix_2_3, uint8_t const_value, const Norm& norm,
45 | cudaStream_t stream);
46 |
47 | void warp_affine_bilinear_and_normalize_focus(
48 | uint8_t* src, int src_line_size, int src_width, int src_height,
49 | float* dst , int dst_width, int dst_height,
50 | float* matrix_2_3, uint8_t const_value, const Norm& norm,
51 | cudaStream_t stream);
52 |
53 | void norm_feature(
54 | float* feature_array, int num_feature, int feature_length,
55 | cudaStream_t stream
56 | );
57 |
58 | void convert_nv12_to_bgr_invoke(
59 | const uint8_t* y, const uint8_t* uv, int width, int height,
60 | int linesize, uint8_t* dst,
61 | cudaStream_t stream);
62 | };
63 |
64 | #endif // PREPROCESS_KERNEL_CUH
--------------------------------------------------------------------------------
/src/tensorRT/import_lib.cpp:
--------------------------------------------------------------------------------
1 |
2 | #if defined(_WIN32)
3 | # define U_OS_WINDOWS
4 | #else
5 | # define U_OS_LINUX
6 | #endif
7 |
8 | #ifdef U_OS_WINDOWS
9 | #if defined(_DEBUG)
10 | # pragma comment(lib, "opencv_world346d.lib")
11 | #else
12 | # pragma comment(lib, "opencv_world346.lib")
13 | #endif
14 |
15 | // link CUDA libraries
16 | #pragma comment(lib, "cuda.lib")
17 | #pragma comment(lib, "cudart.lib")
18 | #pragma comment(lib, "cublas.lib")
19 | #pragma comment(lib, "cudnn.lib")
20 |
21 | // link TensorRT libraries
22 | #pragma comment(lib, "nvinfer.lib")
23 | #pragma comment(lib, "nvinfer_plugin.lib")
24 | //#pragma comment(lib, "nvparsers.lib")
25 |
26 | #if defined(_DEBUG)
27 | #pragma comment(lib, "libprotobufd.lib")
28 | #else
29 | #pragma comment(lib, "libprotobuf.lib")
30 | #endif
31 |
32 | #ifdef HAS_PYTHON
33 | #pragma comment(lib, "python37.lib")
34 | #endif
35 |
36 | #endif // U_OS_WINDOWS
--------------------------------------------------------------------------------
/src/tensorRT/infer/trt_infer.hpp:
--------------------------------------------------------------------------------
1 | 2 | 3 | #ifndef TRT_INFER_HPP 4 | #define TRT_INFER_HPP 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace TRT { 13 | 14 | class Infer { 15 | public: 16 | virtual void forward(bool sync = true) = 0; 17 | virtual int get_max_batch_size() = 0; 18 | virtual void set_stream(CUStream stream) = 0; 19 | virtual CUStream get_stream() = 0; 20 | virtual void synchronize() = 0; 21 | virtual size_t get_device_memory_size() = 0; 22 | virtual std::shared_ptr get_workspace() = 0; 23 | virtual std::shared_ptr input (int index = 0) = 0; 24 | virtual std::shared_ptr output(int index = 0) = 0; 25 | virtual std::shared_ptr tensor(const std::string& name) = 0; 26 | virtual std::string get_input_name (int index = 0) = 0; 27 | virtual std::string get_output_name(int index = 0) = 0; 28 | virtual bool is_output_name(const std::string& name) = 0; 29 | virtual bool is_input_name (const std::string& name) = 0; 30 | virtual int num_output() = 0; 31 | virtual int num_input() = 0; 32 | virtual void print() = 0; 33 | virtual int device() = 0; 34 | virtual void set_input (int index, std::shared_ptr tensor) = 0; 35 | virtual void set_output(int index, std::shared_ptr tensor) = 0; 36 | virtual std::shared_ptr> serial_engine() = 0; 37 | }; 38 | 39 | struct DeviceMemorySummary { 40 | size_t total; 41 | size_t available; 42 | }; 43 | 44 | DeviceMemorySummary get_current_device_summary(); 45 | int get_device_count(); 46 | int get_device(); 47 | 48 | void set_device(int device_id); 49 | std::shared_ptr load_infer_from_memory(const void* pdata, size_t size); 50 | std::shared_ptr load_infer(const std::string& file); 51 | bool init_nv_plugins(); 52 | 53 | }; //TRTInfer 54 | 55 | 56 | #endif //TRT_INFER_HPP -------------------------------------------------------------------------------- /src/tensorRT/onnx/onnx_pb.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) ONNX Project Contributors. 2 | // Licensed under the MIT license. 3 | 4 | #ifndef ONNX_ONNX_PB_H 5 | #define ONNX_ONNX_PB_H 6 | 7 | // Defines ONNX_EXPORT and ONNX_IMPORT. On Windows, this corresponds to 8 | // different declarations (dllexport and dllimport). On Linux/Mac, it just 9 | // resolves to the same "default visibility" setting. 10 | #if defined(_MSC_VER) 11 | #if defined(ONNX_BUILD_SHARED_LIBS) || defined(ONNX_BUILD_MAIN_LIB) 12 | #define ONNX_EXPORT __declspec(dllexport) 13 | #define ONNX_IMPORT __declspec(dllimport) 14 | #else 15 | #define ONNX_EXPORT 16 | #define ONNX_IMPORT 17 | #endif 18 | #else 19 | #if defined(__GNUC__) 20 | #define ONNX_EXPORT __attribute__((__visibility__("default"))) 21 | #else 22 | #define ONNX_EXPORT 23 | #endif 24 | #define ONNX_IMPORT ONNX_EXPORT 25 | #endif 26 | 27 | // ONNX_API is a macro that, depends on whether you are building the 28 | // main ONNX library or not, resolves to either ONNX_EXPORT or 29 | // ONNX_IMPORT. 30 | // 31 | // This is used in e.g. ONNX's protobuf files: when building the main library, 32 | // it is defined as ONNX_EXPORT to fix a Windows global-variable-in-dll 33 | // issue, and for anyone dependent on ONNX it will be defined as 34 | // ONNX_IMPORT. ONNX_BUILD_MAIN_LIB can also be set when being built 35 | // statically if ONNX is being linked into a shared library that wants 36 | // to export the ONNX APIs and classes. 
37 | //
38 | // More details on Windows dllimport / dllexport can be found at
39 | // https://msdn.microsoft.com/en-us/library/3y1sfaz2.aspx
40 | //
41 | // This solution is similar to
42 | // https://github.com/pytorch/pytorch/blob/master/caffe2/core/common.h
43 | #define ONNX_API
44 | #include "onnx-ml.pb.h"
45 |
46 | #endif // ! ONNX_ONNX_PB_H
47 |
--------------------------------------------------------------------------------
/src/tensorRT/onnx/readme.md:
--------------------------------------------------------------------------------
1 | # ONNX
2 | - These files are the C++ sources generated by protoc, extracted from a build of ONNX
3 | - https://github.com/onnx/onnx
--------------------------------------------------------------------------------
/src/tensorRT/onnx_parser/LoopHelpers.cpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-License-Identifier: Apache-2.0
3 | */
4 |
5 | #include "LoopHelpers.hpp"
6 | #include "onnx2trt_utils.hpp"
7 |
8 | namespace onnx2trt
9 | {
10 |
11 | nvinfer1::ITensor* addLoopCounter(IImporterContext* ctx, nvinfer1::ILoop* loop, int32_t initial)
12 | {
13 | nvinfer1::ITensor* initialTensor = addConstantScalar(ctx, initial, ::onnx::TensorProto::INT32, nvinfer1::Dims{1, 1})->getOutput(0);
14 | nvinfer1::ITensor* one = addConstantScalar(ctx, 1, ::onnx::TensorProto::INT32, nvinfer1::Dims{1, 1})->getOutput(0);
15 |
16 | auto counter = loop->addRecurrence(*initialTensor);
17 | nvinfer1::ITensor* addOne = ctx->network()->addElementWise(*counter->getOutput(0), *one, nvinfer1::ElementWiseOperation::kSUM)->getOutput(0);
18 | counter->setInput(1, *addOne);
19 | return counter->getOutput(0);
20 | }
21 |
22 | } // namespace onnx2trt
23 |
--------------------------------------------------------------------------------
/src/tensorRT/onnx_parser/LoopHelpers.hpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-License-Identifier: Apache-2.0
3 | */
4 |
5 | #pragma once
6 |
7 | #include <NvInfer.h>
8 |
9 | #include "ImporterContext.hpp"
10 |
11 | namespace onnx2trt
12 | {
13 |
14 | nvinfer1::ITensor* addLoopCounter(IImporterContext* ctx, nvinfer1::ILoop* loop, int32_t initial = 0);
15 |
16 | } // namespace onnx2trt
17 |
--------------------------------------------------------------------------------
/src/tensorRT/onnx_parser/ModelImporter.hpp:
--------------------------------------------------------------------------------
1 | /*
2 | * SPDX-License-Identifier: Apache-2.0
3 | */
4 |
5 | #pragma once
6 |
7 | #include "ImporterContext.hpp"
8 | #include "NvInferPlugin.h"
9 | #include "NvOnnxParser.h"
10 | #include "builtin_op_importers.hpp"
11 | #include "utils.hpp"
12 |
13 | namespace onnx2trt
14 | {
15 |
16 | Status parseGraph(IImporterContext* ctx, const ::onnx::GraphProto& graph, bool deserializingINetwork = false, int* currentNode = nullptr);
17 |
18 | class ModelImporter : public nvonnxparser::IParser
19 | {
20 | protected:
21 | string_map<NodeImporter> _op_importers;
22 | virtual Status importModel(::onnx::ModelProto const& model);
23 |
24 | private:
25 | ImporterContext _importer_ctx;
26 | std::list<::onnx::ModelProto> _onnx_models; // Needed for ownership of weights
27 | int _current_node;
28 | std::vector<Status> _errors;
29 | std::vector<nvinfer1::Dims> _input_dims;
30 |
31 | public:
32 | ModelImporter(nvinfer1::INetworkDefinition* network, nvinfer1::ILogger* logger, const std::vector<nvinfer1::Dims>& input_dims)
33 | : _op_importers(getBuiltinOpImporterMap())
34 | , _importer_ctx(network, logger)
35 | , _input_dims(input_dims)
36 | {
37 | }
38 | bool
parseWithWeightDescriptors(void const* serialized_onnx_model, size_t serialized_onnx_model_size) override; 39 | bool parse(void const* serialized_onnx_model, size_t serialized_onnx_model_size, const char* model_path = nullptr) override; 40 | bool supportsModel(void const* serialized_onnx_model, size_t serialized_onnx_model_size, 41 | SubGraphCollection_t& sub_graph_collection, const char* model_path = nullptr) override; 42 | 43 | bool supportsOperator(const char* op_name) const override; 44 | void destroy() override 45 | { 46 | delete this; 47 | } 48 | // virtual void registerOpImporter(std::string op, 49 | // NodeImporter const &node_importer) override { 50 | // // Note: This allows existing importers to be replaced 51 | // _op_importers[op] = node_importer; 52 | //} 53 | // virtual Status const &setInput(const char *name, 54 | // nvinfer1::ITensor *input) override; 55 | // virtual Status const& setOutput(const char* name, nvinfer1::ITensor** output) override; 56 | int getNbErrors() const override 57 | { 58 | return _errors.size(); 59 | } 60 | nvonnxparser::IParserError const* getError(int index) const override 61 | { 62 | assert(0 <= index && index < (int) _errors.size()); 63 | return &_errors[index]; 64 | } 65 | void clearErrors() override 66 | { 67 | _errors.clear(); 68 | } 69 | 70 | //...LG: Move the implementation to .cpp 71 | bool parseFromFile(const char* onnxModelFile, int verbosity) override; 72 | bool parseFromData(const void* onnx_data, size_t size, int verbosity) override; 73 | }; 74 | 75 | } // namespace onnx2trt 76 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/NvOnnxParser.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "NvOnnxParser.h" 6 | #include "ModelImporter.hpp" 7 | 8 | extern "C" void* createNvOnnxParser_INTERNAL(void* network_, void* logger_, int version, const std::vector& input_dims) 9 | { 10 | auto network = static_cast(network_); 11 | auto logger = static_cast(logger_); 12 | return new onnx2trt::ModelImporter(network, logger, input_dims); 13 | } 14 | 15 | extern "C" int getNvOnnxParserVersion() 16 | { 17 | return NV_ONNX_PARSER_VERSION; 18 | } -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/OnnxAttrs.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "ImporterContext.hpp" 13 | 14 | class OnnxAttrs 15 | { 16 | template 17 | using string_map = std::unordered_map; 18 | typedef string_map<::onnx::AttributeProto const*> AttrMap; 19 | AttrMap _attrs; 20 | onnx2trt::IImporterContext* mCtx; 21 | 22 | public: 23 | explicit OnnxAttrs(::onnx::NodeProto const& onnx_node, onnx2trt::IImporterContext* ctx) 24 | : mCtx{ctx} 25 | { 26 | for (auto const& attr : onnx_node.attribute()) 27 | { 28 | _attrs.insert({attr.name(), &attr}); 29 | } 30 | } 31 | 32 | bool count(const std::string& key) const 33 | { 34 | return _attrs.count(key); 35 | } 36 | 37 | ::onnx::AttributeProto const* at(std::string key) const 38 | { 39 | if (!_attrs.count(key)) 40 | { 41 | throw std::out_of_range("Attribute not found: " + key); 42 | } 43 | return _attrs.at(key); 44 | } 45 | 46 | ::onnx::AttributeProto::AttributeType type(const std::string& key) const 47 | { 48 | return 
this->at(key)->type(); 49 | } 50 | 51 | 52 | template 53 | T get(const std::string& key) const; 54 | 55 | template 56 | T get(const std::string& key, T const& default_value) const 57 | { 58 | return _attrs.count(key) ? this->get(key) : default_value; 59 | } 60 | }; 61 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/RNNHelpers.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include 8 | #include 9 | #include 10 | 11 | #include "TensorOrWeights.hpp" 12 | #include "ImporterContext.hpp" 13 | 14 | namespace onnx2trt 15 | { 16 | 17 | nvinfer1::ITensor* addRNNInput(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, std::vector& inputs, const std::string& direction); 18 | 19 | // Zeros out invalid timesteps in toMask. maxLen must be provided if reverse is true 20 | nvinfer1::ITensor* clearMissingSequenceElements(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* toMask, nvinfer1::ITensor* maxLen, bool reverse = false, nvinfer1::ITensor* counter = nullptr); 21 | 22 | // Returns a bool tensor which is true during valid timesteps 23 | nvinfer1::ITensor* getRaggedMask(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* maxLen = nullptr, bool reverse = false, nvinfer1::ITensor* counter = nullptr); 24 | 25 | // Selects between prevH and Ht to forward previous hidden state through invalid timesteps 26 | nvinfer1::ITensor* maskRNNHidden(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* prevH, nvinfer1::ITensor* Ht, nvinfer1::ITensor* maxLen = nullptr, bool reverse = false, nvinfer1::ITensor* counter = nullptr); 27 | 28 | // Splits a bidirectional hidden state into forward and reverse passes, masks each using maskRNNHidden, then concatenates 29 | nvinfer1::ITensor* maskBidirRNNHidden(IImporterContext* ctx, const ::onnx::NodeProto& node, nvinfer1::ILoop* loop, nvinfer1::ITensor* seqLens, nvinfer1::ITensor* maxLen, nvinfer1::ITensor* Ht1, nvinfer1::ITensor* Ht, nvinfer1::ITensor* singlePassShape); 30 | 31 | } // namespace onnx2trt 32 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/ShapedWeights.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "ShapedWeights.hpp" 6 | #include "onnx2trt_utils.hpp" 7 | #include "trt_utils.hpp" 8 | #include 9 | #include 10 | 11 | namespace onnx2trt 12 | { 13 | 14 | size_t ShapedWeights::count() const 15 | { 16 | if (this->values == nullptr && this->shape.nbDims <= 0) 17 | { 18 | return 0; 19 | } 20 | // TRT supports scalars, so 0D tensors should have a count of 1. 
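// e.g. shape [2,3] gives 2*3 = 6; a 0D shape runs the loop below zero times,
// so the initial value 1 is returned, which matches the scalar case.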
21 | size_t c = 1; 22 | for (int i = 0; i < this->shape.nbDims; ++i) 23 | { 24 | c *= this->shape.d[i]; 25 | } 26 | return c; 27 | } 28 | 29 | ShapedWeights ShapedWeights::empty(DataType type) 30 | { 31 | return ShapedWeights(type, nullptr, nvinfer1::Dims{0}); 32 | } 33 | 34 | ShapedWeights::ShapedWeights() 35 | : values(nullptr) 36 | , shape{0} 37 | { 38 | } 39 | 40 | ShapedWeights::ShapedWeights(DataType type_, void* values_, nvinfer1::Dims shape_) 41 | : type(type_) 42 | , values(values_) 43 | , shape(shape_) 44 | { 45 | // Note: this->shape.type[] is not used 46 | } 47 | 48 | size_t ShapedWeights::size_bytes() const 49 | { 50 | return this->count() * getDtypeSize(this->type); 51 | } 52 | 53 | ShapedWeights::operator bool() const 54 | { 55 | return (bool) this->values; 56 | } 57 | 58 | ShapedWeights::operator nvinfer1::Weights() const 59 | { 60 | nvinfer1::Weights w{}; 61 | w.values = this->values; 62 | bool supported_type = convertDtype(this->type, &w.type); 63 | (void) supported_type; 64 | assert(supported_type); 65 | w.count = this->count(); 66 | return w; 67 | } 68 | 69 | const char* ShapedWeights::getName() const 70 | { 71 | return this->name; 72 | } 73 | 74 | void ShapedWeights::setName(const char* name) 75 | { 76 | this->name = name; 77 | } 78 | 79 | template 80 | void transpose4DWeights(ShapedWeights const& weights, nvinfer1::Permutation const perm, ShapedWeights* result) 81 | { 82 | nvinfer1::Dims original_shape = weights.shape; 83 | nvinfer1::Dims new_shape = result->shape; 84 | int nbDims = new_shape.nbDims; 85 | DType const* src = reinterpret_cast(weights.values); 86 | DType* dst = reinterpret_cast(result->values); 87 | 88 | nvinfer1::Dims expanded_original_shape{4, {1, 1, 1, 1}}; 89 | nvinfer1::Dims expanded_new_shape{4, {1, 1, 1, 1}}; 90 | nvinfer1::Permutation expanded_perm{0, 1, 2, 3}; 91 | 92 | int pad = 4 - nbDims; 93 | for (int i = 0; i < nbDims; ++i) 94 | { 95 | expanded_original_shape.d[pad + i] = original_shape.d[i]; 96 | expanded_new_shape.d[pad + i] = new_shape.d[i]; 97 | expanded_perm.order[pad + i] = perm.order[i] + pad; 98 | } 99 | 100 | 101 | int src_strides[4] = {1, 1, 1, 1}; 102 | int dst_strides[4] = {1, 1, 1, 1}; 103 | 104 | for (int i = 2; i >= 0; --i) 105 | { 106 | src_strides[i] = expanded_original_shape.d[i + 1] * src_strides[i + 1]; 107 | dst_strides[i] = expanded_new_shape.d[i + 1] * dst_strides[i + 1]; 108 | } 109 | 110 | for (int n = 0; n < expanded_original_shape.d[0]; ++n) 111 | { 112 | for (int c = 0; c < expanded_original_shape.d[1]; ++c) 113 | { 114 | for (int h = 0; h < expanded_original_shape.d[2]; ++h) 115 | { 116 | for (int w = 0; w < expanded_original_shape.d[3]; ++w) 117 | { 118 | int src_index = 0; 119 | int dst_index = 0; 120 | int src_coord[4] = {n, c, h, w}; 121 | int dst_coord[4]; 122 | for (int i = 0 ; i < 4; ++i) 123 | { 124 | dst_coord[i] = src_coord[expanded_perm.order[i]]; 125 | src_index += src_coord[i] * src_strides[i]; 126 | dst_index += dst_coord[i] * dst_strides[i]; 127 | } 128 | dst[dst_index] = src[src_index]; 129 | } 130 | } 131 | } 132 | } 133 | } 134 | 135 | bool transposeWeights(ShapedWeights const& weights, nvinfer1::Permutation const& perm, ShapedWeights* result, IImporterContext* ctx) 136 | { 137 | nvinfer1::Dims shape = weights.shape; 138 | int nbDims = shape.nbDims; 139 | nvinfer1::Dims new_shape; 140 | new_shape.nbDims = nbDims; 141 | for (int d = 0; d < nbDims; ++d) 142 | { 143 | new_shape.d[d] = shape.d[perm.order[d]]; 144 | result->shape.d[d] = new_shape.d[d]; 145 | } 146 | 147 | if (shape.nbDims <= 4) 148 
| { 149 | if (weights.type == ::onnx::TensorProto::FLOAT) 150 | { 151 | transpose4DWeights<float>(weights, perm, result); 152 | } 153 | else if (weights.type == ::onnx::TensorProto::FLOAT16) 154 | { 155 | transpose4DWeights<uint16_t>(weights, perm, result); 156 | } 157 | else 158 | { 159 | return false; 160 | } 161 | } 162 | else 163 | { 164 | // TODO: Implement general transposes and multiple data types 165 | // Unsupported weights transpose 166 | return false; 167 | } 168 | nvinfer1::Dims permDims{nbDims, {}}; 169 | std::copy_n(perm.order, nbDims, permDims.d); 170 | LOG_WARNING("Weights " 171 | << weights.getName() << " has been transposed with permutation of " << permDims 172 | << "! If you plan on overwriting the weights with the Refitter API, the new weights must be pre-transposed."); 173 | result->setName(weights.getName()); 174 | return true; 175 | } 176 | 177 | } // namespace onnx2trt 178 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/ShapedWeights.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include <NvInfer.h> 8 | #include <cassert> 9 | 10 | namespace onnx2trt 11 | { 12 | 13 | class ShapedWeights 14 | { 15 | public: 16 | using DataType = int32_t; 17 | 18 | static ShapedWeights empty(DataType type); 19 | 20 | ShapedWeights(); 21 | 22 | explicit ShapedWeights(DataType type, void* values, nvinfer1::Dims shape_); 23 | 24 | size_t count() const; 25 | 26 | size_t size_bytes() const; 27 | 28 | const char* getName() const; 29 | 30 | void setName(const char* name); 31 | 32 | explicit operator bool() const; 33 | 34 | operator nvinfer1::Weights() const; 35 | 36 | template <typename T> 37 | T& at(size_t index) 38 | { 39 | assert(index >= 0 && (index * sizeof(T)) < size_bytes()); 40 | return static_cast<T*>(values)[index]; 41 | } 42 | 43 | template <typename T> 44 | const T& at(size_t index) const 45 | { 46 | assert(index >= 0 && (index * sizeof(T)) < size_bytes()); 47 | return static_cast<const T*>(values)[index]; 48 | } 49 | 50 | public: 51 | DataType type; 52 | void* values; 53 | nvinfer1::Dims shape; 54 | const char* name{}; 55 | }; 56 | 57 | class IImporterContext; 58 | bool transposeWeights(ShapedWeights const& weights, nvinfer1::Permutation const& perm, ShapedWeights* result, IImporterContext* ctx); 59 | 60 | } // namespace onnx2trt 61 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/TensorOrWeights.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "ShapedWeights.hpp" 8 | 9 | #include <NvInfer.h> 10 | #include <cassert> 11 | 12 | namespace onnx2trt 13 | { 14 | 15 | class TensorOrWeights 16 | { 17 | union 18 | { 19 | nvinfer1::ITensor* _tensor; 20 | ShapedWeights _weights; 21 | }; 22 | enum 23 | { 24 | NODE_TENSOR, 25 | NODE_WEIGHTS 26 | } _variant; 27 | 28 | public: 29 | TensorOrWeights() 30 | : _tensor(nullptr) 31 | , _variant(NODE_TENSOR) 32 | { 33 | } 34 | TensorOrWeights(nvinfer1::ITensor* tensor) 35 | : _tensor(tensor) 36 | , _variant(NODE_TENSOR) 37 | { 38 | } 39 | TensorOrWeights(ShapedWeights const& weights) 40 | : _weights(weights) 41 | , _variant(NODE_WEIGHTS) 42 | { 43 | } 44 | bool is_tensor() const 45 | { 46 | return _variant == NODE_TENSOR; 47 | } 48 | bool is_weights() const 49 | { 50 | return _variant == NODE_WEIGHTS; 51 | } 52 | bool isNullTensor() const 53 | { 54 | return is_tensor() && _tensor
== nullptr; 55 | } 56 | nvinfer1::ITensor& tensor() 57 | { 58 | assert(!isNullTensor()); 59 | return *_tensor; 60 | } 61 | nvinfer1::ITensor const& tensor() const 62 | { 63 | assert(!isNullTensor()); 64 | return *_tensor; 65 | } 66 | ShapedWeights& weights() 67 | { 68 | assert(is_weights()); 69 | return _weights; 70 | } 71 | ShapedWeights const& weights() const 72 | { 73 | assert(is_weights()); 74 | return _weights; 75 | } 76 | nvinfer1::Dims shape() const 77 | { 78 | return is_tensor() ? _tensor->getDimensions() : _weights.shape; 79 | } 80 | explicit operator bool() const 81 | { 82 | return is_tensor() ? _tensor != nullptr : static_cast<bool>(_weights); 83 | } 84 | bool isInt32() const 85 | { 86 | return is_tensor() ? _tensor->getType() == nvinfer1::DataType::kINT32 : _weights.type == ::onnx::TensorProto_DataType_INT32; 87 | } 88 | bool isBool() const 89 | { 90 | return is_tensor() ? _tensor->getType() == nvinfer1::DataType::kBOOL : _weights.type == ::onnx::TensorProto_DataType_BOOL; 91 | } 92 | std::string getName() const 93 | { 94 | return is_tensor() ? _tensor->getName() : _weights.getName(); 95 | } 96 | std::string getType() const 97 | { 98 | if (is_tensor()) 99 | { 100 | switch(_tensor->getType()) 101 | { 102 | case nvinfer1::DataType::kFLOAT: return "FLOAT"; 103 | case nvinfer1::DataType::kHALF: return "HALF"; 104 | case nvinfer1::DataType::kINT8: return "INT8"; 105 | case nvinfer1::DataType::kINT32: return "INT32"; 106 | case nvinfer1::DataType::kBOOL: return "BOOL"; 107 | default: return "UNKNOWN TYPE"; 108 | } 109 | } 110 | else 111 | { 112 | switch(_weights.type) 113 | { 114 | case ::onnx::TensorProto::DOUBLE: return "DOUBLE -> FLOAT"; 115 | case ::onnx::TensorProto::FLOAT: return "FLOAT"; 116 | case ::onnx::TensorProto::INT8: return "INT8"; 117 | case ::onnx::TensorProto::FLOAT16: return "HALF"; 118 | case ::onnx::TensorProto::BOOL: return "BOOL"; 119 | case ::onnx::TensorProto::INT32: return "INT32"; 120 | case ::onnx::TensorProto::INT64: return "INT64 -> INT32"; 121 | default: return "UNKNOWN TYPE"; 122 | } 123 | } 124 | } 125 | }; 126 | 127 | } // namespace onnx2trt 128 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/builtin_op_importers.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "onnx2trt.hpp" 8 | #include "utils.hpp" 9 | 10 | namespace onnx2trt 11 | { 12 | 13 | string_map<NodeImporter>& getBuiltinOpImporterMap(); 14 | 15 | } // namespace onnx2trt 16 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/onnx2trt.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "NvOnnxParser.h" 8 | #include "ShapedWeights.hpp" 9 | #include "Status.hpp" 10 | #include "TensorOrWeights.hpp" 11 | 12 | #include <NvInfer.h> 13 | #include <functional> 14 | #include <memory> 15 | #include <string> 16 | #include <unordered_map> 17 | #include <unordered_set> 18 | #include <vector> 19 | 20 | namespace onnx2trt 21 | { 22 | 23 | class IImporterContext; 24 | 25 | // TODO: Find ABI-safe alternative approach for this: 26 | // Can't use std::vector 27 | // Can't use ::onnx::NodeProto 28 | // Can't use std::function 29 | typedef ValueOrStatus<std::vector<TensorOrWeights>> NodeImportResult; 30 | typedef std::function<NodeImportResult( 31 | IImporterContext* ctx, ::onnx::NodeProto const& node, std::vector<TensorOrWeights>& inputs)> 32 | NodeImporter; 33 | 34 | template <typename T> 35 | using StringMap = std::unordered_map<std::string, T>; 36 | 37 | class IImporterContext 38 | { 39 | public: 40 |
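// Shared state visible to every op importer while a model is converted: the network under construction, name-to-tensor maps, quantization ranges, loop tensors, and temporary weight allocation.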
virtual nvinfer1::INetworkDefinition* network() = 0; 41 | virtual StringMap<TensorOrWeights>& tensors() = 0; 42 | virtual StringMap<nvinfer1::TensorLocation>& tensorLocations() = 0; 43 | virtual StringMap<float>& tensorRangeMins() = 0; 44 | virtual StringMap<float>& tensorRangeMaxes() = 0; 45 | virtual StringMap<nvinfer1::DataType>& layerPrecisions() = 0; 46 | virtual std::unordered_set<std::string>& unsupportedShapeTensors() = 0; 47 | virtual StringMap<std::string>& loopTensors() = 0; 48 | virtual void setOnnxFileLocation(std::string location) = 0; 49 | virtual std::string getOnnxFileLocation() = 0; 50 | virtual void registerTensor(TensorOrWeights tensor, const std::string& basename) = 0; 51 | virtual void registerLayer(nvinfer1::ILayer* layer, const std::string& basename) = 0; 52 | virtual ShapedWeights createTempWeights(ShapedWeights::DataType type, nvinfer1::Dims shape, uint8_t value = 0) = 0; 53 | virtual int64_t getOpsetVersion(const char* domain = "") const = 0; 54 | virtual nvinfer1::ILogger& logger() = 0; 55 | virtual bool hasError() const = 0; 56 | virtual nvinfer1::IErrorRecorder* getErrorRecorder() const = 0; 57 | 58 | protected: 59 | virtual ~IImporterContext() 60 | { 61 | } 62 | }; 63 | 64 | } // namespace onnx2trt 65 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/onnx2trt_common.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include <NvInferPlugin.h> 8 | #include <memory> 9 | 10 | #if NV_TENSORRT_MAJOR < 4 11 | namespace nvinfer1 12 | { 13 | 14 | enum class PluginFormat : uint8_t 15 | { 16 | kNCHW = 0, //!< NCHW 17 | kNC2HW2 = 1, //!< NCHW with 2-element packed channels 18 | kNHWC8 = 2 //!< NHWC with 8-element packed channels (C 19 | //! must be a multiple of 8) 20 | }; 21 | // from NvInfer.h 22 | class IPluginExt : public IPlugin 23 | { 24 | public: 25 | virtual int getTensorRTVersion() const noexcept 26 | { 27 | return NV_TENSORRT_VERSION; 28 | } 29 | virtual bool supportsFormat(DataType type, PluginFormat format) const noexcept = 0; 30 | virtual void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, 31 | DataType type, PluginFormat format, int maxBatchSize) noexcept 32 | = 0; 33 | 34 | protected: 35 | void configure( 36 | const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, int maxBatchSize) noexcept final 37 | { 38 | try 39 | { 40 | DataType type = nvinfer1::DataType::kFLOAT; 41 | PluginFormat format = nvinfer1::PluginFormat::kLINEAR; 42 | return this->configureWithFormat(inputDims, nbInputs, outputDims, nbOutputs, type, format, maxBatchSize); 43 | } 44 | catch (const std::exception& e) 45 | { 46 | nvinfer1::getLogger()->log(nvinfer1::ILogger::Severity::kERROR, e.what()); 47 | } 48 | } 49 | virtual ~IPluginExt() 50 | { 51 | } 52 | }; 53 | 54 | } // namespace nvinfer1 55 | #endif 56 | 57 | namespace onnx2trt 58 | { 59 | 60 | struct IOwnable 61 | { 62 | virtual void destroy() = 0; 63 | 64 | protected: 65 | virtual ~IOwnable() 66 | { 67 | } 68 | }; 69 | 70 | struct OwnableDeleter 71 | { 72 | void operator()(IOwnable* obj) const 73 | { 74 | obj->destroy(); 75 | } 76 | }; 77 | 78 | using UniqueOwnable = std::unique_ptr<IOwnable, OwnableDeleter>; 79 | class Plugin; 80 | class PluginV2; 81 | 82 | } // namespace onnx2trt 83 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/onnx2trt_runtime.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier:
Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "onnx2trt_common.hpp" 8 | 9 | namespace onnx2trt 10 | { 11 | 12 | typedef Plugin* (*plugin_deserializer)(const void* serialData, size_t serialLength); 13 | 14 | } // namespace onnx2trt 15 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/onnxErrorRecorder.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include "onnxErrorRecorder.hpp" 6 | #include <exception> 7 | 8 | namespace onnx2trt 9 | { 10 | 11 | 12 | ONNXParserErrorRecorder* ONNXParserErrorRecorder::create( 13 | nvinfer1::ILogger* logger, nvinfer1::IErrorRecorder* otherRecorder) 14 | { 15 | try 16 | { 17 | auto recorder = new ONNXParserErrorRecorder(logger, otherRecorder); 18 | if (recorder) 19 | { 20 | recorder->incRefCount(); 21 | } 22 | return recorder; 23 | } 24 | catch (const std::exception& e) 25 | { 26 | logError(logger, e.what()); 27 | return nullptr; 28 | } 29 | } 30 | 31 | void ONNXParserErrorRecorder::destroy(ONNXParserErrorRecorder*& recorder) 32 | { 33 | if (recorder) 34 | { 35 | recorder->decRefCount(); 36 | recorder = nullptr; 37 | } 38 | } 39 | 40 | void ONNXParserErrorRecorder::logError(nvinfer1::ILogger* logger, const char* str) 41 | { 42 | if (logger) 43 | { 44 | logger->log(ILogger::Severity::kERROR, str); 45 | } 46 | } 47 | 48 | ONNXParserErrorRecorder::ONNXParserErrorRecorder( 49 | nvinfer1::ILogger* logger, nvinfer1::IErrorRecorder* otherRecorder) 50 | : mUserRecorder(otherRecorder) 51 | , mLogger(logger) 52 | { 53 | if (mUserRecorder) 54 | { 55 | mUserRecorder->incRefCount(); 56 | } 57 | } 58 | 59 | ONNXParserErrorRecorder::~ONNXParserErrorRecorder() noexcept 60 | { 61 | if (mUserRecorder) 62 | { 63 | mUserRecorder->decRefCount(); 64 | } 65 | } 66 | 67 | void ONNXParserErrorRecorder::clear() noexcept 68 | { 69 | try 70 | { 71 | // grab a lock so that there is no addition while clearing. 72 | std::lock_guard<std::mutex> guard(mStackLock); 73 | mErrorStack.clear(); 74 | } 75 | catch (const std::exception& e) 76 | { 77 | logError(mLogger, e.what()); 78 | } 79 | }; 80 | 81 | bool ONNXParserErrorRecorder::reportError( 82 | nvinfer1::ErrorCode val, nvinfer1::IErrorRecorder::ErrorDesc desc) noexcept 83 | { 84 | try 85 | { 86 | std::lock_guard<std::mutex> guard(mStackLock); 87 | mErrorStack.push_back(errorPair(val, desc)); 88 | if (mUserRecorder) 89 | { 90 | mUserRecorder->reportError(val, desc); 91 | } 92 | else 93 | { 94 | logError(mLogger, desc); 95 | } 96 | } 97 | catch (const std::exception& e) 98 | { 99 | logError(mLogger, e.what()); 100 | } 101 | // All errors are considered fatal. 102 | return true; 103 | } 104 | 105 | nvinfer1::IErrorRecorder::RefCount ONNXParserErrorRecorder::incRefCount() noexcept 106 | { 107 | // Atomically increment or decrement the ref counter.
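// Note: decRefCount() below deletes the recorder once the count reaches zero, so instances must come from create() (heap-allocated) and never live on the stack.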
108 | return ++mRefCount; 109 | } 110 | 111 | nvinfer1::IErrorRecorder::RefCount ONNXParserErrorRecorder::decRefCount() noexcept 112 | { 113 | auto newVal = --mRefCount; 114 | if (newVal == 0) 115 | { 116 | delete this; 117 | } 118 | return newVal; 119 | } 120 | 121 | } // namespace onnx2trt 122 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/onnxErrorRecorder.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "NvInferRuntimeCommon.h" 8 | #include "onnx2trt_utils.hpp" 9 | #include <atomic> 10 | #include <cstdint> 11 | #include <exception> 12 | #include <mutex> 13 | #include <vector> 14 | 15 | namespace onnx2trt 16 | { 17 | 18 | //! 19 | //! A simple implementation of the IErrorRecorder interface for 20 | //! use by ONNX importer. 21 | //! ONNX-importer Error recorder is based on a vector that pairs the error 22 | //! code and the error string into a single element. It also uses 23 | //! standard mutex and atomics in order to make sure that the code 24 | //! works in a multi-threaded environment. 25 | //! 26 | class ONNXParserErrorRecorder : public nvinfer1::IErrorRecorder 27 | { 28 | using RefCount = nvinfer1::IErrorRecorder::RefCount; 29 | using ErrorDesc = nvinfer1::IErrorRecorder::ErrorDesc; 30 | using ErrorCode = nvinfer1::ErrorCode; 31 | using IErrorRecorder = nvinfer1::IErrorRecorder; 32 | using ILogger = nvinfer1::ILogger; 33 | 34 | using errorPair = std::pair<ErrorCode, std::string>; 35 | using errorStack = std::vector<errorPair>; 36 | 37 | public: 38 | static ONNXParserErrorRecorder* create( 39 | ILogger* logger, IErrorRecorder* otherRecorder = nullptr); 40 | 41 | static void destroy(ONNXParserErrorRecorder*& recorder); 42 | 43 | void clear() noexcept final; 44 | RefCount incRefCount() noexcept final; 45 | RefCount decRefCount() noexcept final; 46 | bool reportError(ErrorCode val, ErrorDesc desc) noexcept final; 47 | 48 | int32_t getNbErrors() const noexcept final 49 | { 50 | return mErrorStack.size(); 51 | } 52 | 53 | ErrorCode getErrorCode(int32_t errorIdx) const noexcept final 54 | { 55 | return invalidIndexCheck(errorIdx) ? ErrorCode::kINVALID_ARGUMENT : (*this)[errorIdx].first; 56 | } 57 | 58 | ErrorDesc getErrorDesc(int32_t errorIdx) const noexcept final 59 | { 60 | return invalidIndexCheck(errorIdx) ? "errorIdx out of range." : (*this)[errorIdx].second.c_str(); 61 | } 62 | 63 | bool hasOverflowed() const noexcept final 64 | { 65 | // This class can never overflow since we have dynamic resize via std::vector usage. 66 | return false; 67 | } 68 | 69 | protected: 70 | ONNXParserErrorRecorder(ILogger* logger, IErrorRecorder* otherRecorder = nullptr); 71 | 72 | virtual ~ONNXParserErrorRecorder() noexcept; 73 | 74 | static void logError(ILogger* logger, const char* str); 75 | 76 | // Simple helper functions. 77 | const errorPair& operator[](size_t index) const noexcept 78 | { 79 | return mErrorStack[index]; 80 | } 81 | 82 | bool invalidIndexCheck(int32_t index) const noexcept 83 | { 84 | // By converting signed to unsigned, we only need a single check since 85 | // negative numbers turn into large positive greater than the size. 86 | size_t sIndex = index; 87 | return sIndex >= mErrorStack.size(); 88 | } 89 | // Mutex to hold when locking mErrorStack. 90 | std::mutex mStackLock; 91 | 92 | // Reference count of the class. Destruction of the class when mRefCount 93 | // is not zero causes undefined behavior.
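// Usage sketch (illustrative only; the parser wires this up internally):
//   auto* recorder = ONNXParserErrorRecorder::create(&logger);
//   ... parse; failures accumulate through reportError() ...
//   for (int32_t i = 0; i < recorder->getNbErrors(); ++i)
//       logger.log(nvinfer1::ILogger::Severity::kERROR, recorder->getErrorDesc(i));
//   ONNXParserErrorRecorder::destroy(recorder);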
94 | std::atomic<int32_t> mRefCount{0}; 95 | 96 | // The error stack that holds the errors recorded by TensorRT. 97 | errorStack mErrorStack; 98 | 99 | // Original error recorder (set by user) 100 | IErrorRecorder* mUserRecorder{nullptr}; 101 | 102 | // logger 103 | ILogger* mLogger{nullptr}; 104 | }; // class ONNXParserErrorRecorder 105 | 106 | } // namespace onnx2trt 107 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/onnx_utils.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #include <google/protobuf/io/coded_stream.h> 6 | #include <google/protobuf/io/zero_copy_stream_impl.h> 7 | #include <google/protobuf/text_format.h> 8 | #include <fstream> 9 | #include <iostream> 10 | #include <limits> 11 | #include <sstream> 12 | 13 | #pragma once 14 | 15 | namespace 16 | { 17 | 18 | template <typename OnnxDims> 19 | bool convertOnnxDims(OnnxDims const& onnxDims, nvinfer1::Dims& trtDims) 20 | { 21 | std::vector<int> onnxDims_vector; 22 | for (const auto& onnxDim : onnxDims) 23 | { 24 | const int dim = onnxDim.dim_param() == "" ? (onnxDim.dim_value() >= 0 ? onnxDim.dim_value() : -1) : -1; 25 | onnxDims_vector.emplace_back(dim); 26 | } 27 | trtDims.nbDims = onnxDims_vector.size(); 28 | assert(trtDims.nbDims <= nvinfer1::Dims::MAX_DIMS); 29 | std::copy(onnxDims_vector.begin(), onnxDims_vector.end(), trtDims.d); 30 | return true; 31 | } 32 | 33 | // Removes raw data from the text representation of an ONNX model 34 | void remove_raw_data_strings(std::string& s) 35 | { 36 | std::string::size_type beg = 0; 37 | const std::string key = "raw_data: \""; 38 | const std::string sub = "..."; 39 | while ((beg = s.find(key, beg)) != std::string::npos) 40 | { 41 | beg += key.length(); 42 | std::string::size_type end = beg - 1; 43 | // Note: Must skip over escaped end-quotes 44 | while (s[(end = s.find("\"", ++end)) - 1] == '\\') 45 | { 46 | } 47 | if (end - beg > 128) 48 | { // Only remove large data strings 49 | s.replace(beg, end - beg, "..."); 50 | } 51 | beg += sub.length(); 52 | } 53 | } 54 | 55 | // Removes float_data, int32_data etc. from the text representation of an ONNX model 56 | std::string remove_repeated_data_strings(std::string& s) 57 | { 58 | std::istringstream iss(s); 59 | std::ostringstream oss; 60 | bool is_repeat = false; 61 | for (std::string line; std::getline(iss, line);) 62 | { 63 | if (line.find("float_data:") != std::string::npos || line.find("int32_data:") != std::string::npos 64 | || line.find("int64_data:") != std::string::npos) 65 | { 66 | if (!is_repeat) 67 | { 68 | is_repeat = true; 69 | oss << line.substr(0, line.find(":") + 1) << " ...\n"; 70 | } 71 | } 72 | else 73 | { 74 | is_repeat = false; 75 | oss << line << "\n"; 76 | } 77 | } 78 | return oss.str(); 79 | } 80 | 81 | } // anonymous namespace 82 | 83 | inline std::string pretty_print_onnx_to_string(::google::protobuf::Message const& message) 84 | { 85 | std::string s; 86 | ::google::protobuf::TextFormat::PrintToString(message, &s); 87 | remove_raw_data_strings(s); 88 | s = remove_repeated_data_strings(s); 89 | return s; 90 | } 91 | 92 | inline std::ostream& operator<<(std::ostream& stream, ::onnx::ModelProto const& message) 93 | { 94 | stream << pretty_print_onnx_to_string(message); 95 | return stream; 96 | } 97 | 98 | inline std::ostream& operator<<(std::ostream& stream, ::onnx::NodeProto const& message) 99 | { 100 | stream << pretty_print_onnx_to_string(message); 101 | return stream; 102 | } 103 | 104 | //... 105 | //...Consider moving all of the below functions into a stand alone 106 | //...
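// Example (sketch; "model.onnx" is a placeholder path). The WAR below exists because protobuf's default 64MB limit rejects large ONNX files:
//   ::onnx::ModelProto model;
//   if (ParseFromFile_WAR(&model, "model.onnx"))
//       std::cout << "IR version: " << onnx_ir_version_string(model.ir_version()) << std::endl;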
107 | 108 | inline bool ParseFromFile_WAR(google::protobuf::Message* msg, const char* filename) 109 | { 110 | 111 | std::ifstream stream(filename, std::ios::in | std::ios::binary); 112 | if (!stream) 113 | { 114 | std::cerr << "Could not open file " << std::string(filename) << std::endl; 115 | return false; 116 | } 117 | google::protobuf::io::IstreamInputStream rawInput(&stream); 118 | 119 | google::protobuf::io::CodedInputStream coded_input(&rawInput); 120 | // Note: This WARs the very low default size limit (64MB) 121 | coded_input.SetTotalBytesLimit(std::numeric_limits<int>::max()); 122 | return msg->ParseFromCodedStream(&coded_input); 123 | } 124 | 125 | inline bool ParseFromTextFile(google::protobuf::Message* msg, const char* filename) 126 | { 127 | std::ifstream stream(filename, std::ios::in); 128 | if (!stream) 129 | { 130 | std::cerr << "Could not open file " << std::string(filename) << std::endl; 131 | return false; 132 | } 133 | 134 | google::protobuf::io::IstreamInputStream rawInput(&stream); 135 | 136 | return google::protobuf::TextFormat::Parse(&rawInput, msg); 137 | } 138 | 139 | inline std::string onnx_ir_version_string(int64_t ir_version = ::onnx::IR_VERSION) 140 | { 141 | int onnx_ir_major = ir_version / 1000000; 142 | int onnx_ir_minor = ir_version % 1000000 / 10000; 143 | int onnx_ir_patch = ir_version % 10000; 144 | return (std::to_string(onnx_ir_major) + "." + std::to_string(onnx_ir_minor) + "." + std::to_string(onnx_ir_patch)); 145 | } 146 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/readme.md: -------------------------------------------------------------------------------- 1 | # ONNX Parser 2 | - These files are extracted from the official onnx-tensorrt project; the Python-related parts were removed, everything else is kept 3 | - Support for Plugin nodes has been added on top of that 4 | - https://github.com/onnx/onnx-tensorrt -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/toposort.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include <unordered_map> 8 | #include <vector> 9 | 10 | #include <iostream> 11 | using std::cout; 12 | using std::cerr; 13 | using std::endl; 14 | 15 | namespace 16 | { 17 | 18 | enum NodeState 19 | { 20 | NODE_UNVISITED, 21 | NODE_ACTIVE, 22 | NODE_VISITED 23 | }; 24 | 25 | template <class Container> 26 | bool get_post_order(size_t node_idx, Container const& nodes, std::unordered_map<std::string, size_t> const& node_map, 27 | std::vector<NodeState>* node_states, std::vector<size_t>* order) 28 | { 29 | NodeState& node_state = node_states->at(node_idx); 30 | if (node_state == NODE_ACTIVE) 31 | { 32 | // Cycle detected! 33 | cerr << "ERROR: Graph contains a cycle" << endl; 34 | return false; 35 | } 36 | else if (node_state == NODE_VISITED) 37 | { 38 | return true; 39 | } 40 | else 41 | { 42 | node_state = NODE_ACTIVE; 43 | // TODO: This .Get().input() is highly specific to protobuf, should 44 | // generalise it somehow. 45 | for (auto const& input : nodes.Get(node_idx).input()) 46 | { 47 | if (!node_map.count(input)) 48 | { 49 | // Input node not found in graph!
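// Missing producers are expected rather than fatal here: graph inputs and initializers are not nodes, so edges coming from them never have an entry in node_map.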
50 | // cerr << "ERROR: Input node not found in graph: " 51 | // << input << endl; 52 | // return false; 53 | continue; // Skip missing input edges 54 | } 55 | size_t input_node_idx = node_map.at(input); 56 | if (!get_post_order(input_node_idx, nodes, node_map, node_states, order)) 57 | { 58 | return false; 59 | } 60 | } 61 | node_state = NODE_VISITED; 62 | order->push_back(node_idx); 63 | } 64 | return true; 65 | } 66 | 67 | } // anonymous namespace 68 | 69 | template <class Container> 70 | bool toposort(Container const& nodes, std::vector<size_t>* order) 71 | { 72 | std::unordered_map<std::string, size_t> node_map; 73 | for (size_t i = 0; i < (size_t) nodes.size(); ++i) 74 | { 75 | // TODO: This .Get().input() is highly specific to protobuf, should 76 | // generalise it somehow. 77 | for (auto const& output : nodes.Get(i).output()) 78 | { 79 | if (!node_map.emplace(output, i).second) 80 | { 81 | // Output name appears more than once in graph! 82 | cerr << "ERROR: Output name is not unique: " << output << endl; 83 | return false; 84 | } 85 | } 86 | } 87 | order->reserve(nodes.size()); 88 | std::vector<NodeState> node_states(nodes.size(), NODE_UNVISITED); 89 | for (size_t i = 0; i < (size_t) nodes.size(); ++i) 90 | { 91 | if (!get_post_order(i, nodes, node_map, &node_states, order)) 92 | { 93 | return false; 94 | } 95 | } 96 | return true; 97 | } 98 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/trt_utils.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include "Status.hpp" 8 | #include "TensorOrWeights.hpp" 9 | #include "onnx2trt.hpp" 10 | 11 | #include <NvInfer.h> 12 | #include <algorithm> 13 | #include <cassert> 14 | #include <stdexcept> 15 | 16 | namespace onnx2trt 17 | { 18 | 19 | inline int getDtypeSize(nvinfer1::DataType trtDtype) 20 | { 21 | switch (trtDtype) 22 | { 23 | case nvinfer1::DataType::kFLOAT: return 4; 24 | case nvinfer1::DataType::kINT8: return 1; 25 | case nvinfer1::DataType::kHALF: return 2; 26 | case nvinfer1::DataType::kINT32: 27 | return 4; 28 | // TRT does not support booleans as a native type, so we treat them like int32 values.
29 | case nvinfer1::DataType::kBOOL: 30 | return 4; 31 | // TODO: Some sort of error handling 32 | default: return -1; 33 | } 34 | } 35 | 36 | inline nvinfer1::Dims insert_dim(nvinfer1::Dims const& dims, int idx, int value) 37 | { 38 | assert(idx < dims.nbDims + 1); 39 | nvinfer1::Dims new_dims; 40 | new_dims.nbDims = dims.nbDims + 1; 41 | for (int i = 0; i < idx; ++i) 42 | { 43 | new_dims.d[i] = dims.d[i]; 44 | } 45 | new_dims.d[idx] = value; 46 | for (int i = idx + 1; i < new_dims.nbDims; ++i) 47 | { 48 | new_dims.d[i] = dims.d[i - 1]; 49 | } 50 | return new_dims; 51 | } 52 | 53 | inline nvinfer1::Dims remove_dim(nvinfer1::Dims const& dims, int idx) 54 | { 55 | assert(idx < dims.nbDims); 56 | nvinfer1::Dims new_dims; 57 | new_dims.nbDims = dims.nbDims - 1; 58 | for (int i = 0; i < idx; ++i) 59 | { 60 | new_dims.d[i] = dims.d[i]; 61 | } 62 | for (int i = idx; i < new_dims.nbDims; ++i) 63 | { 64 | new_dims.d[i] = dims.d[i + 1]; 65 | } 66 | // Special case for scalar result (i.e., there was only one dim originally) 67 | if (new_dims.nbDims == 0) 68 | { 69 | new_dims.nbDims = 1; 70 | new_dims.d[0] = 1; 71 | } 72 | return new_dims; 73 | } 74 | 75 | // Adds unitary dimensions on the left 76 | inline nvinfer1::Dims expand_dims(nvinfer1::Dims const& dims, int ndim_new) 77 | { 78 | assert(dims.nbDims <= ndim_new); 79 | nvinfer1::Dims new_dims; 80 | new_dims.nbDims = ndim_new; 81 | int j = 0; 82 | for (; j < ndim_new - dims.nbDims; ++j) 83 | { 84 | new_dims.d[j] = 1; 85 | } 86 | for (int i = 0; i < dims.nbDims; ++i, ++j) 87 | { 88 | new_dims.d[j] = dims.d[i]; 89 | } 90 | return new_dims; 91 | } 92 | 93 | inline nvinfer1::Permutation remove_first_dim(nvinfer1::Permutation const& perm) 94 | { 95 | assert(perm.order[0] == 0); 96 | nvinfer1::Permutation new_perm; 97 | int ndim = nvinfer1::Dims::MAX_DIMS; 98 | for (int i = 0; i < ndim - 1; ++i) 99 | { 100 | new_perm.order[i] = perm.order[i + 1] - 1; 101 | } 102 | return new_perm; 103 | } 104 | 105 | inline nvinfer1::Dims squeeze_trailing_dims(nvinfer1::Dims const& dims) 106 | { 107 | nvinfer1::Dims new_dims = dims; 108 | // Note: TRT requires at least one dimension, so we don't squeeze [1]->[] 109 | while (new_dims.nbDims > 1 && new_dims.d[new_dims.nbDims - 1] == 1) 110 | { 111 | --new_dims.nbDims; 112 | } 113 | return new_dims; 114 | } 115 | 116 | inline nvinfer1::Dims squeeze_leading_dims(const nvinfer1::Dims& dims) 117 | { 118 | nvinfer1::Dims newDims; 119 | // Copy dims only if a non-1 has been seen already. 120 | bool non1Seen{false}; 121 | newDims.nbDims = std::copy_if(dims.d, dims.d + dims.nbDims, newDims.d, 122 | [&non1Seen](int x) { 123 | non1Seen = (x != 1) ? 
true : non1Seen; 124 | return non1Seen; 125 | }) 126 | - newDims.d; 127 | return newDims; 128 | } 129 | 130 | inline nvinfer1::DimsHW operator-(nvinfer1::DimsHW dims) 131 | { 132 | return nvinfer1::DimsHW(-dims.h(), -dims.w()); 133 | } 134 | 135 | // Note: These are used for checking beg_padding == end_padding 136 | inline bool operator==(nvinfer1::Dims const& a, nvinfer1::Dims const& b) 137 | { 138 | if (a.nbDims != b.nbDims) 139 | { 140 | return false; 141 | } 142 | for (int i = 0; i < a.nbDims; ++i) 143 | { 144 | if (a.d[i] != b.d[i]) 145 | { 146 | return false; 147 | } 148 | } 149 | return true; 150 | } 151 | inline bool operator!=(nvinfer1::Dims const& a, nvinfer1::Dims const& b) 152 | { 153 | return !(a == b); 154 | } 155 | 156 | inline TensorOrWeights identity(IImporterContext* ctx, TensorOrWeights input) 157 | { 158 | if (input.is_weights()) 159 | { 160 | return input; 161 | } 162 | else 163 | { 164 | auto* layer = ctx->network()->addIdentity(input.tensor()); 165 | if (!layer) 166 | { 167 | return nullptr; 168 | } 169 | return layer->getOutput(0); 170 | } 171 | } 172 | 173 | inline ::onnx::TensorProto_DataType trtDataTypeToONNX(nvinfer1::DataType dt) 174 | { 175 | switch (dt) 176 | { 177 | case nvinfer1::DataType::kFLOAT: return ::onnx::TensorProto::FLOAT; 178 | case nvinfer1::DataType::kHALF: return ::onnx::TensorProto::FLOAT16; 179 | case nvinfer1::DataType::kINT32: return ::onnx::TensorProto::INT32; 180 | case nvinfer1::DataType::kINT8: return ::onnx::TensorProto::INT8; 181 | case nvinfer1::DataType::kBOOL: return ::onnx::TensorProto::BOOL; 182 | default: return ::onnx::TensorProto_DataType_UNDEFINED; 183 | } 184 | throw std::runtime_error{"Unreachable"}; 185 | } 186 | 187 | } // namespace onnx2trt 188 | -------------------------------------------------------------------------------- /src/tensorRT/onnx_parser/utils.hpp: -------------------------------------------------------------------------------- 1 | /* 2 | * SPDX-License-Identifier: Apache-2.0 3 | */ 4 | 5 | #pragma once 6 | 7 | #include <unordered_map> 8 | 9 | template <typename T> 10 | using string_map = std::unordered_map<std::string, T>; 11 | -------------------------------------------------------------------------------- /src/tensorRT/onnxplugin/plugin_binary_io.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "plugin_binary_io.hpp" 3 | #include "ilogger.hpp" 4 | #include <string.h> 5 | 6 | namespace Plugin{ 7 | 8 | using namespace std; 9 | 10 | BinIO::~BinIO(){ 11 | close(); 12 | } 13 | 14 | bool BinIO::opened(){ 15 | if (flag_ == MemoryRead) 16 | return memoryRead_ != nullptr; 17 | else if (flag_ == MemoryWrite) 18 | return true; 19 | return false; 20 | } 21 | 22 | void BinIO::close(){ 23 | if (flag_ == MemoryRead) { 24 | memoryRead_ = nullptr; 25 | memoryCursor_ = 0; 26 | memoryLength_ = -1; 27 | } 28 | else if (flag_ == MemoryWrite) { 29 | memoryWrite_.clear(); 30 | memoryCursor_ = 0; 31 | memoryLength_ = -1; 32 | } 33 | } 34 | 35 | string BinIO::readData(int numBytes){ 36 | string output; 37 | output.resize(numBytes); 38 | 39 | int readlen = read((void*)output.data(), output.size()); 40 | output.resize(readlen < 0 ? 0 : readlen); // read() returns -1 on failure; guard against resizing to (size_t)-1 41 | return output; 42 | } 43 | 44 | int BinIO::read(void* pdata, size_t length){ 45 | 46 | if (flag_ == MemoryRead) { 47 | if (memoryLength_ != -1) { 48 | 49 | if (memoryLength_ < memoryCursor_ + length) { 50 | int remain = memoryLength_ - memoryCursor_; 51 | if (remain > 0) { 52 | memcpy(pdata, memoryRead_ + memoryCursor_, remain); 53 | memoryCursor_ += remain; 54 | return remain; 55 | } 56 | else { 57
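// Cursor is already at (or past) the end of the buffer; nothing can be copied.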
return -1; 58 | } 59 | } 60 | } 61 | memcpy(pdata, memoryRead_ + memoryCursor_, length); 62 | memoryCursor_ += length; 63 | return length; 64 | } 65 | else { 66 | return -1; 67 | } 68 | } 69 | 70 | bool BinIO::eof(){ 71 | if (!opened()) return true; 72 | 73 | if (flag_ == MemoryRead){ 74 | return this->memoryCursor_ >= this->memoryLength_; 75 | } 76 | else if (flag_ == MemoryWrite){ 77 | return false; 78 | } 79 | else { 80 | opstate_ = false; 81 | INFO("Unsupported flag: %d", flag_); 82 | return true; 83 | } 84 | } 85 | 86 | int BinIO::write(const void* pdata, size_t length){ 87 | 88 | if (flag_ == MemoryWrite) { 89 | memoryWrite_.append((char*)pdata, (char*)pdata + length); 90 | return length; 91 | } 92 | else { 93 | return -1; 94 | } 95 | } 96 | 97 | int BinIO::writeData(const string& data){ 98 | return write(data.data(), data.size()); 99 | } 100 | 101 | BinIO& BinIO::operator >> (string& value){ 102 | //read 103 | int length = 0; 104 | (*this) >> length; 105 | value = readData(length); 106 | return *this; 107 | } 108 | 109 | int BinIO::readInt(){ 110 | int value = 0; 111 | (*this) >> value; 112 | return value; 113 | } 114 | 115 | float BinIO::readFloat(){ 116 | float value = 0; 117 | (*this) >> value; 118 | return value; 119 | } 120 | 121 | BinIO& BinIO::operator << (const string& value){ 122 | //write 123 | (*this) << (int)value.size(); 124 | writeData(value); 125 | return *this; 126 | } 127 | 128 | BinIO& BinIO::operator << (const char* value){ 129 | 130 | int length = strlen(value); 131 | (*this) << (int)length; 132 | write(value, length); 133 | return *this; 134 | } 135 | 136 | BinIO& BinIO::operator << (const vector<string>& value){ 137 | (*this) << (int)value.size(); 138 | for (int i = 0; i < value.size(); ++i){ 139 | (*this) << value[i]; 140 | } 141 | return *this; 142 | } 143 | 144 | BinIO& BinIO::operator >> (vector<string>& value){ 145 | int num; 146 | (*this) >> num; 147 | 148 | value.resize(num); 149 | for (int i = 0; i < value.size(); ++i) 150 | (*this) >> value[i]; 151 | return *this; 152 | } 153 | 154 | bool BinIO::openMemoryRead(const void* ptr, int memoryLength) { 155 | close(); 156 | 157 | if (!ptr) return false; 158 | memoryRead_ = (const char*)ptr; 159 | memoryCursor_ = 0; 160 | memoryLength_ = memoryLength; 161 | flag_ = MemoryRead; 162 | return true; 163 | } 164 | 165 | void BinIO::openMemoryWrite() { 166 | close(); 167 | 168 | memoryWrite_.clear(); 169 | memoryCursor_ = 0; 170 | memoryLength_ = -1; 171 | flag_ = MemoryWrite; 172 | } 173 | 174 | }; // namespace Plugin -------------------------------------------------------------------------------- /src/tensorRT/onnxplugin/plugin_binary_io.hpp: -------------------------------------------------------------------------------- 1 | #ifndef PLUGIN_BINARY_IO_HPP 2 | #define PLUGIN_BINARY_IO_HPP 3 | 4 | #include <string> 5 | #include <vector> 6 | 7 | namespace Plugin{ 8 | 9 | class BinIO { 10 | public: 11 | enum Head { 12 | MemoryRead = 1, 13 | MemoryWrite = 2 14 | }; 15 | 16 | BinIO() { openMemoryWrite(); } 17 | BinIO(const void* ptr, int memoryLength = -1) { openMemoryRead(ptr, memoryLength); } 18 | virtual ~BinIO(); 19 | bool opened(); 20 | bool openMemoryRead(const void* ptr, int memoryLength = -1); 21 | void openMemoryWrite(); 22 | const std::string& writedMemory() { return memoryWrite_; } 23 | void close(); 24 | int write(const void* pdata, size_t length); 25 | int writeData(const std::string& data); 26 | int read(void* pdata, size_t length); 27 | std::string readData(int numBytes); 28 | int readInt(); 29 | float readFloat(); 30 | bool eof(); 31 | 32 |
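// Round-trip sketch (illustrative values): the default constructor opens a write buffer and the (ptr, length) constructor reopens it for reading:
//   Plugin::BinIO out;                                   // MemoryWrite mode
//   out << 3.14f << std::string("hswish");
//   Plugin::BinIO in(out.writedMemory().data(), (int)out.writedMemory().size());
//   float f = in.readFloat();                            // 3.14f
//   std::string name; in >> name;                        // "hswish"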
BinIO& operator >> (std::string& value); 33 | BinIO& operator << (const std::string& value); 34 | BinIO& operator << (const char* value); 35 | BinIO& operator << (const std::vector<std::string>& value); 36 | BinIO& operator >> (std::vector<std::string>& value); 37 | 38 | template <typename _T> 39 | BinIO& operator >> (std::vector<_T>& value) { 40 | int length = 0; 41 | (*this) >> length; 42 | 43 | value.resize(length); 44 | read(value.data(), length * sizeof(_T)); 45 | return *this; 46 | } 47 | 48 | template <typename _T> 49 | BinIO& operator << (const std::vector<_T>& value) { 50 | (*this) << (int)value.size(); 51 | write(value.data(), sizeof(_T) * value.size()); 52 | return *this; 53 | } 54 | 55 | template <typename _T> 56 | BinIO& operator >> (_T& value) { 57 | read(&value, sizeof(_T)); 58 | return *this; 59 | } 60 | 61 | template <typename _T> 62 | BinIO& operator << (const _T& value) { 63 | write(&value, sizeof(_T)); 64 | return *this; 65 | } 66 | 67 | bool opstate() const { 68 | return opstate_; 69 | } 70 | 71 | private: 72 | size_t readModeEndSEEK_ = 0; 73 | std::string memoryWrite_; 74 | const char* memoryRead_ = nullptr; 75 | int memoryCursor_ = 0; 76 | int memoryLength_ = -1; 77 | Head flag_ = MemoryWrite; 78 | bool opstate_ = true; 79 | }; 80 | }; // namespace Plugin 81 | 82 | #endif //PLUGIN_BINARY_IO_HPP -------------------------------------------------------------------------------- /src/tensorRT/onnxplugin/plugins/HSigmoid.cu: -------------------------------------------------------------------------------- 1 | 2 | #include <onnxplugin/onnxplugin.hpp> 3 | #include <cuda_fp16.h> 4 | 5 | using namespace ONNXPlugin; 6 | 7 | static __global__ void hsigmoid_kernel_fp32(float* input, float* output, int edge) { 8 | 9 | KernelPositionBlock; 10 | float x = input[position]; 11 | float a = x + 3; 12 | a = a < 0 ? 0 : (a >= 6 ? 6 : a); 13 | output[position] = a / 6; 14 | } 15 | 16 | static __global__ void hsigmoid_kernel_fp16(__half* input, __half* output, int edge) { 17 | 18 | KernelPositionBlock; 19 | 20 | __half _six = 6.0f; 21 | __half x = input[position]; 22 | 23 | __half a = x + __half(3.0f); 24 | __half _zero = 0.0f; 25 | a = a < _zero ? _zero : (a >= _six ?
_six : a); 26 | output[position] = a / _six; 27 | } 28 | 29 | class HSigmoid : public TRTPlugin { 30 | public: 31 | SetupPlugin(HSigmoid); 32 | 33 | virtual void config_finish() override{ 34 | 35 | // INFO("init hsigmoid config: %s", config_->info_.c_str()); 36 | // INFO("weights = %d", config_->weights_.size()); 37 | // for(int i = 0; i < config_->weights_.size(); ++i){ 38 | // auto& w = config_->weights_[i]; 39 | // if(w->type() == TRT::DataType::Float16){ 40 | // INFO("Weight[%d] shape is %s, dtype = %s, value[0] = %f", i, w->shape_string(), data_type_string(w->type()), float(w->at<__half>(0))); 41 | // }else{ 42 | // INFO("Weight[%d] shape is %s, dtype = %s, value[0] = %f", i, w->shape_string(), data_type_string(w->type()), w->at<float>(0)); 43 | // } 44 | // } 45 | } 46 | 47 | virtual std::shared_ptr<LayerConfig> new_config() override{ 48 | auto cfg = TRTPlugin::new_config(); 49 | 50 | cfg->support_dtype_set_ = {nvinfer1::DataType::kHALF, nvinfer1::DataType::kFLOAT}; 51 | //cfg->support_dtype_set_ = {nvinfer1::DataType::kFLOAT}; 52 | return cfg; 53 | } 54 | 55 | virtual nvinfer1::DimsExprs getOutputDimensions( 56 | int32_t outputIndex, const nvinfer1::DimsExprs* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override{ 57 | 58 | return inputs[0]; 59 | } 60 | 61 | int enqueue(const std::vector<GTensor>& inputs, std::vector<GTensor>& outputs, const std::vector<GTensor>& weights, void* workspace, cudaStream_t stream) override{ 62 | 63 | int count = inputs[0].count(); 64 | auto grid = CUDATools::grid_dims(count); 65 | auto block = CUDATools::block_dims(count); 66 | 67 | if (config_->usage_dtype_ == TRT::DataType::Float) { 68 | hsigmoid_kernel_fp32 <<<grid, block, 0, stream>>> (inputs[0].ptr<float>(), outputs[0].ptr<float>(), count); 69 | } 70 | else if (config_->usage_dtype_ == TRT::DataType::Float16) { 71 | hsigmoid_kernel_fp16 <<<grid, block, 0, stream>>> (inputs[0].ptr<__half>(), outputs[0].ptr<__half>(), count); 72 | } 73 | else{ 74 | INFOF("not implemented function"); 75 | } 76 | return 0; 77 | } 78 | }; 79 | 80 | RegisterPlugin(HSigmoid); -------------------------------------------------------------------------------- /src/tensorRT/onnxplugin/plugins/HSwish.cu: -------------------------------------------------------------------------------- 1 | 2 | #include <onnxplugin/onnxplugin.hpp> 3 | #include <cuda_fp16.h> 4 | 5 | using namespace ONNXPlugin; 6 | 7 | static __global__ void hswish_kernel_fp32(float* input, float* output, int edge) { 8 | 9 | KernelPositionBlock; 10 | float x = input[position]; 11 | float a = x + 3; 12 | a = a < 0 ? 0 : (a >= 6 ? 6 : a); 13 | output[position] = x * a / 6; 14 | } 15 | 16 | static __global__ void hswish_kernel_fp16(__half* input, __half* output, int edge) { 17 | 18 | KernelPositionBlock; 19 | 20 | __half _six = 6.0f; 21 | __half x = input[position]; 22 | 23 | __half a = x + __half(3.0f); 24 | __half _zero = 0.0f; 25 | a = a < _zero ? _zero : (a >= _six ?
_six : a); 26 | output[position] = x * a / _six; 27 | } 28 | 29 | class HSwish : public TRTPlugin { 30 | public: 31 | SetupPlugin(HSwish); 32 | 33 | virtual void config_finish() override{ 34 | 35 | // INFO("init hswish config: %s", config_->info_.c_str()); 36 | // INFO("weights = %d", config_->weights_.size()); 37 | // for(int i = 0; i < config_->weights_.size(); ++i){ 38 | // auto& w = config_->weights_[i]; 39 | // if(w->type() == TRT::DataType::Float16){ 40 | // INFO("Weight[%d] shape is %s, dtype = %s, value[0] = %f", i, w->shape_string(), data_type_string(w->type()), float(w->at<__half>(0))); 41 | // }else{ 42 | // INFO("Weight[%d] shape is %s, dtype = %s, value[0] = %f", i, w->shape_string(), data_type_string(w->type()), w->at<float>(0)); 43 | // } 44 | // } 45 | } 46 | 47 | virtual std::shared_ptr<LayerConfig> new_config() override{ 48 | auto cfg = TRTPlugin::new_config(); 49 | 50 | //cfg->support_dtype_set_ = {nvinfer1::DataType::kHALF, nvinfer1::DataType::kFLOAT}; 51 | cfg->support_dtype_set_ = {nvinfer1::DataType::kFLOAT}; 52 | return cfg; 53 | } 54 | 55 | virtual nvinfer1::DimsExprs getOutputDimensions( 56 | int32_t outputIndex, const nvinfer1::DimsExprs* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override{ 57 | 58 | return inputs[0]; 59 | } 60 | 61 | int enqueue(const std::vector<GTensor>& inputs, std::vector<GTensor>& outputs, const std::vector<GTensor>& weights, void* workspace, cudaStream_t stream) override{ 62 | 63 | int count = inputs[0].count(); 64 | auto grid = CUDATools::grid_dims(count); 65 | auto block = CUDATools::block_dims(count); 66 | 67 | if (config_->usage_dtype_ == TRT::DataType::Float) { 68 | hswish_kernel_fp32 <<<grid, block, 0, stream>>> (inputs[0].ptr<float>(), outputs[0].ptr<float>(), count); 69 | } 70 | else if (config_->usage_dtype_ == TRT::DataType::Float16) { 71 | hswish_kernel_fp16 <<<grid, block, 0, stream>>> (inputs[0].ptr<__half>(), outputs[0].ptr<__half>(), count); 72 | } 73 | else{ 74 | INFOF("not implemented function"); 75 | } 76 | return 0; 77 | } 78 | }; 79 | 80 | RegisterPlugin(HSwish); -------------------------------------------------------------------------------- /workspace/exp/fall_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shouxieai/hard_decode_trt/443ccd6d85c81742c45ba5733d0d676755fa7248/workspace/exp/fall_video.mp4 -------------------------------------------------------------------------------- /workspace/exp/number100.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shouxieai/hard_decode_trt/443ccd6d85c81742c45ba5733d0d676755fa7248/workspace/exp/number100.mp4 --------------------------------------------------------------------------------