├── README.md
├── YoloLayer_TRT_v7.0
│   ├── CMakeLists.txt
│   ├── macros.h
│   ├── script
│   │   └── add_custom_yolo_op.py
│   ├── yololayer.cu
│   └── yololayer.h
├── assets
│   ├── 1.jpg
│   ├── 2.jpg
│   └── _result.jpg
└── yolov7-pose
    ├── CMakeLists.txt
    ├── logging.h
    ├── macros.h
    ├── person.jpg
    ├── utils.h
    ├── yololayer.cu
    ├── yololayer.h
    └── yolov7_pose.cpp

/README.md:
--------------------------------------------------------------------------------
1 | 
10 | 
11 | # Pose detection based on Yolov7, deployed with TensorRT :two_hearts: :collision:
12 | 
13 | This project is based on https://github.com/WongKinYiu/yolov7
14 | 
15 | 
16 | # System Requirements
17 | 
18 | CUDA 11.4
19 | 
20 | TensorRT 8+
21 | 
22 | OpenCV 4.0+ (built with the opencv-contrib module) [how to build](https://gist.github.com/nanmi/c5cc1753ed98d7e3482031fc379a3f3d#%E6%BA%90%E7%A0%81%E7%BC%96%E8%AF%91gpu%E7%89%88opencv)
23 | 
24 | # Export ONNX model
25 | You need to bypass the reshape and permute operators in the keypoint-related code (`class IKeypoint(nn.Module)`), like this:
26 | ```python
27 | def forward(self, x):
28 |     # x = x.copy()  # for profiling
29 |     z = []  # inference output
30 |     self.training |= self.export
31 |     for i in range(self.nl):
32 |         if self.nkpt is None or self.nkpt==0:
33 |             x[i] = self.im[i](self.m[i](self.ia[i](x[i])))  # conv
34 |         else :
35 |             x[i] = torch.cat((self.im[i](self.m[i](self.ia[i](x[i]))), self.m_kpt[i](x[i])), axis=1)
36 | 
37 |         if not self.training:  # inference <------ new add
38 |             bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
39 |             x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
40 |             x_det = x[i][..., :6]
41 |             x_kpt = x[i][..., 6:]
42 | 
43 |         if not self.training:  # inference
44 |             if self.grid[i].shape[2:4] != x[i].shape[2:4]:
45 |                 self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
46 |             kpt_grid_x = self.grid[i][..., 0:1]
47 |             kpt_grid_y = self.grid[i][..., 1:2]
48 |             ...
49 | ```
50 | 
51 | Use this script to export the ONNX model:
52 | ```python
53 | import sys
54 | sys.path.append('./')  # to run '$ python *.py' files in subdirectories
55 | import torch
56 | import torch.nn as nn
57 | import models
58 | from models.experimental import attempt_load
59 | from utils.activations import Hardswish, SiLU
60 | 
61 | # Load PyTorch model
62 | weights = 'yolov7-w6-pose.pt'
63 | device = torch.device('cuda:0')
64 | model = attempt_load(weights, map_location=device)  # load FP32 model
65 | 
66 | # Update model
67 | for k, m in model.named_modules():
68 |     m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
69 |     if isinstance(m, models.common.Conv):  # assign export-friendly activations
70 |         if isinstance(m.act, nn.Hardswish):
71 |             m.act = Hardswish()
72 |         elif isinstance(m.act, nn.SiLU):
73 |             m.act = SiLU()
74 | model.model[-1].export = True  # set Detect() layer grid export
75 | model.eval()
76 | 
77 | # Input
78 | img = torch.randn(1, 3, 960, 960).to(device)  # image size (1,3,960,960)
79 | torch.onnx.export(model, img, 'yolov7-w6-pose.onnx', verbose=False, opset_version=12, input_names=['images'])
80 | ```
81 | 
82 | You will get:
83 | ![](assets/1.jpg)
84 | 
85 | Use `YoloLayer_TRT_v7.0/script/add_custom_yolo_op.py` to add a new op that looks like this:
86 | 
87 | ![](assets/2.jpg)
88 | 
89 | 
90 | # Build the YOLO layer TensorRT plugin
91 | 
92 | ```shell
93 | cd {this repo}/YoloLayer_TRT_v7.0
94 | mkdir build && cd build
95 | cmake .. && make
96 | ```
97 | 
98 | `libyolo.so` is generated when the build succeeds.
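As a quick sanity check before building the engine (this snippet is an editor's addition, not a script shipped in this repo), you can verify that the freshly built `libyolo.so` actually registers the `YoloLayer_TRT` plugin creator. The sketch below assumes the TensorRT Python bindings are installed and that the library sits in the default `build/` directory:

```python
# Optional check (not part of this repo): confirm the custom plugin is visible to TensorRT.
import ctypes
import tensorrt as trt

# Loading the shared library runs the static registrar created by REGISTER_TENSORRT_PLUGIN.
ctypes.CDLL("./YoloLayer_TRT_v7.0/build/libyolo.so")

logger = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(logger, "")  # also registers TensorRT's built-in plugins

creators = trt.get_plugin_registry().plugin_creator_list
print("YoloLayer_TRT registered:", any(c.name == "YoloLayer_TRT" for c in creators))
```

If this prints `True`, `trtexec` should be able to resolve the custom op in the next step.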
99 | 
100 | # Build TensorRT engine
101 | 
102 | ```shell
103 | cd {this repo}/
104 | 
105 | trtexec --onnx=yolov7-w6-pose-sim-yolo.onnx --fp16 --saveEngine=yolov7-w6-pose-sim-yolo-fp16.engine --plugins={this repo}/YoloLayer_TRT_v7.0/build/libyolo.so
106 | ```
107 | 
108 | This can take a long time :satisfied:
109 | 
110 | The TensorRT engine is then generated successfully.
111 | 
112 | 
113 | # Inference
114 | 
115 | ```shell
116 | cd {this repo}/yolov7-pose/
117 | mkdir build && cd build
118 | cmake .. && make
119 | 
120 | # Inference test
121 | cd {this repo}/yolov7-pose/build/
122 | ./yolov7_pose {your built engine} -i ../person.jpg
123 | ```
124 | 
125 | # Result
126 | ![](assets/_result.jpg)
127 | 
128 | # About License
129 | 
130 | For the third-party modules and TensorRT, you need to follow their licenses.
131 | 
132 | For the part I wrote, you can do anything you want.
133 | 
134 | 

--------------------------------------------------------------------------------
/YoloLayer_TRT_v7.0/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.10)
2 | 
3 | project(plugin_build_example)
4 | 
5 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Ofast -Wfatal-errors -D_MWAITXINTRIN_H_INCLUDED")
6 | option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
7 | set(CMAKE_BUILD_TYPE Release)
8 | 
9 | # cuda/cudnn
10 | find_package(CUDA REQUIRED)
11 | include_directories(${CUDA_INCLUDE_DIRS})
12 | 
13 | # tensorrt
14 | set(TENSORRT_INCLUDE_DIR /usr/local/TensorRT-8.4.1.5/include/)
15 | set(TENSORRT_LIBRARY_DIR /usr/local/TensorRT-8.4.1.5/lib/)
16 | include_directories(${TENSORRT_INCLUDE_DIR})
17 | link_directories(${TENSORRT_LIBRARY_DIR})
18 | 
19 | cuda_add_library(yolo SHARED ${PROJECT_SOURCE_DIR}/yololayer.cu
20 | )
21 | target_link_libraries(yolo nvinfer ${CUDA_LIBRARIES})
22 | 
23 | add_definitions(-O2 -pthread)
24 | 

--------------------------------------------------------------------------------
/YoloLayer_TRT_v7.0/macros.h:
--------------------------------------------------------------------------------
1 | #ifndef __MACROS_H
2 | #define __MACROS_H
3 | 
4 | #ifdef API_EXPORTS
5 | #if defined(_MSC_VER)
6 | #define API __declspec(dllexport)
7 | #else
8 | #define API __attribute__((visibility("default")))
9 | #endif
10 | #else
11 | 
12 | #if defined(_MSC_VER)
13 | #define API __declspec(dllimport)
14 | #else
15 | #define API
16 | #endif
17 | #endif // API_EXPORTS
18 | 
19 | #if NV_TENSORRT_MAJOR >= 8
20 | #define TRT_NOEXCEPT noexcept
21 | #define TRT_CONST_ENQUEUE const
22 | #else
23 | #define TRT_NOEXCEPT
24 | #define TRT_CONST_ENQUEUE
25 | #endif
26 | 
27 | #endif // __MACROS_H

--------------------------------------------------------------------------------
/YoloLayer_TRT_v7.0/script/add_custom_yolo_op.py:
--------------------------------------------------------------------------------
1 | import onnx_graphsurgeon as gs
2 | import numpy as np
3 | import onnx
4 | 
5 | 
6 | # Load the ONNX graph with GraphSurgeon
7 | graph = gs.import_onnx(onnx.load("yolov7-w6-pose-sim.onnx"))
8 | 
9 | # Since we already know the names of the tensors we're interested in, we can
10 | # grab them directly from the tensor map.
11 | #
12 | # NOTE: If you do not know the tensor names you want, you can view the graph in
13 | # Netron to determine them, or use ONNX GraphSurgeon in an interactive shell
14 | # to print the graph.
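# --- Optional helper (editor's addition, not part of the original script) ----
# If you are working from your own export and are unsure which tensors to pick,
# you can dump every node's op type, name and output tensor names here and look
# for the four detection-head outputs. The names "745", "783", "821" and "859"
# used below belong to this particular yolov7-w6-pose export and may differ for
# other models.
#
#     for node in graph.nodes:
#         print(node.op, node.name, [out.name for out in node.outputs])
# ------------------------------------------------------------------------------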
15 | tensors = graph.tensors() 16 | 17 | # If you want to embed shape information, but cannot use ONNX shape inference, 18 | # you can manually modify the tensors at this point: 19 | # 20 | # IMPORTANT: You must include type information for input and output tensors if it is not already 21 | # present in the graph. 22 | # 23 | # NOTE: ONNX GraphSurgeon will also accept dynamic shapes - simply set the corresponding 24 | # dimension(s) to `gs.Tensor.DYNAMIC`, e.g. `shape=(gs.Tensor.DYNAMIC, 3, 224, 224)` 25 | inputs = [tensors["745"].to_variable(dtype=np.float32), 26 | tensors["783"].to_variable(dtype=np.float32), 27 | tensors["821"].to_variable(dtype=np.float32), 28 | tensors["859"].to_variable(dtype=np.float32)] 29 | 30 | # Add a output tensor of new graph 31 | modified_output = gs.Variable(name="output0", dtype=np.float32, shape=(57001, 1, 1)) 32 | 33 | # Add a new node that you want 34 | new_node = gs.Node(op="YoloLayer_TRT", name="YoloLayer_TRT_0", inputs=inputs, outputs=[modified_output]) 35 | 36 | # append into graph 37 | graph.nodes.append(new_node) 38 | graph.outputs = [modified_output] 39 | 40 | graph.cleanup().toposort() 41 | 42 | # gs save a graph 43 | onnx.save(gs.export_onnx(graph), "yolov7-w6-pose-sim-yolo.onnx") 44 | -------------------------------------------------------------------------------- /YoloLayer_TRT_v7.0/yololayer.cu: -------------------------------------------------------------------------------- 1 | #include "yololayer.h" 2 | 3 | namespace Tn 4 | { 5 | template 6 | void write(char*& buffer, const T& val) 7 | { 8 | *reinterpret_cast(buffer) = val; 9 | buffer += sizeof(T); 10 | } 11 | 12 | template 13 | void read(const char*& buffer, T& val) 14 | { 15 | val = *reinterpret_cast(buffer); 16 | buffer += sizeof(T); 17 | } 18 | } 19 | 20 | #define CUDA_CHECK(callstr)\ 21 | {\ 22 | cudaError_t error_code = callstr;\ 23 | if (error_code != cudaSuccess) {\ 24 | std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\ 25 | assert(0);\ 26 | }\ 27 | } 28 | 29 | using namespace Yolo; 30 | 31 | namespace nvinfer1 32 | { 33 | YoloLayerPlugin::YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector& vYoloKernel) 34 | { 35 | mClassCount = classCount; 36 | mYoloV5NetWidth = netWidth; 37 | mYoloV5NetHeight = netHeight; 38 | mMaxOutObject = maxOut; 39 | mYoloKernel = vYoloKernel; 40 | mKernelCount = vYoloKernel.size(); 41 | 42 | CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*))); 43 | size_t AnchorLen = sizeof(float)* CHECK_COUNT * 2; 44 | for (int ii = 0; ii < mKernelCount; ii++) 45 | { 46 | CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen)); 47 | const auto& yolo = mYoloKernel[ii]; 48 | CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice)); 49 | } 50 | } 51 | YoloLayerPlugin::~YoloLayerPlugin() 52 | { 53 | for (int ii = 0; ii < mKernelCount; ii++) 54 | { 55 | CUDA_CHECK(cudaFree(mAnchor[ii])); 56 | } 57 | CUDA_CHECK(cudaFreeHost(mAnchor)); 58 | } 59 | 60 | // create the plugin at runtime from a byte stream 61 | YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length) 62 | { 63 | using namespace Tn; 64 | const char *d = reinterpret_cast(data), *a = d; 65 | read(d, mClassCount); 66 | read(d, mThreadCount); 67 | read(d, mKernelCount); 68 | read(d, mYoloV5NetWidth); 69 | read(d, mYoloV5NetHeight); 70 | read(d, mMaxOutObject); 71 | mYoloKernel.resize(mKernelCount); 72 | auto kernelSize = mKernelCount * sizeof(YoloKernel); 73 | memcpy(mYoloKernel.data(), d, kernelSize); 74 | 
d += kernelSize; 75 | CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*))); 76 | size_t AnchorLen = sizeof(float)* CHECK_COUNT * 2; 77 | for (int ii = 0; ii < mKernelCount; ii++) 78 | { 79 | CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen)); 80 | const auto& yolo = mYoloKernel[ii]; 81 | CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice)); 82 | } 83 | assert(d == a + length); 84 | } 85 | 86 | void YoloLayerPlugin::serialize(void* buffer) const TRT_NOEXCEPT 87 | { 88 | using namespace Tn; 89 | char* d = static_cast(buffer), *a = d; 90 | write(d, mClassCount); 91 | write(d, mThreadCount); 92 | write(d, mKernelCount); 93 | write(d, mYoloV5NetWidth); 94 | write(d, mYoloV5NetHeight); 95 | write(d, mMaxOutObject); 96 | auto kernelSize = mKernelCount * sizeof(YoloKernel); 97 | memcpy(d, mYoloKernel.data(), kernelSize); 98 | d += kernelSize; 99 | 100 | assert(d == a + getSerializationSize()); 101 | } 102 | 103 | size_t YoloLayerPlugin::getSerializationSize() const TRT_NOEXCEPT 104 | { 105 | return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size() + sizeof(mYoloV5NetWidth) + sizeof(mYoloV5NetHeight) + sizeof(mMaxOutObject); 106 | } 107 | 108 | int YoloLayerPlugin::initialize() TRT_NOEXCEPT 109 | { 110 | return 0; 111 | } 112 | 113 | Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT 114 | { 115 | //output the result to channel 116 | int totalsize = mMaxOutObject * sizeof(Detection) / sizeof(float); 117 | 118 | return Dims3(totalsize + 1, 1, 1); 119 | } 120 | 121 | // Set plugin namespace 122 | void YoloLayerPlugin::setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT 123 | { 124 | mPluginNamespace = pluginNamespace; 125 | } 126 | 127 | const char* YoloLayerPlugin::getPluginNamespace() const TRT_NOEXCEPT 128 | { 129 | return mPluginNamespace; 130 | } 131 | 132 | // Return the DataType of the plugin output at the requested index 133 | DataType YoloLayerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT 134 | { 135 | return DataType::kFLOAT; 136 | } 137 | 138 | // Return true if output tensor is broadcast across a batch. 139 | bool YoloLayerPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT 140 | { 141 | return false; 142 | } 143 | 144 | // Return true if plugin can use input that is broadcast across batch without replication. 145 | bool YoloLayerPlugin::canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT 146 | { 147 | return false; 148 | } 149 | 150 | void YoloLayerPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT 151 | { 152 | } 153 | 154 | // Attach the plugin object to an execution context and grant the plugin the access to some context resource. 155 | void YoloLayerPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT 156 | { 157 | } 158 | 159 | // Detach the plugin object from its execution context. 
160 | void YoloLayerPlugin::detachFromContext() TRT_NOEXCEPT {} 161 | 162 | const char* YoloLayerPlugin::getPluginType() const TRT_NOEXCEPT 163 | { 164 | return "YoloLayer_TRT"; 165 | } 166 | 167 | const char* YoloLayerPlugin::getPluginVersion() const TRT_NOEXCEPT 168 | { 169 | return "1"; 170 | } 171 | 172 | void YoloLayerPlugin::destroy() TRT_NOEXCEPT 173 | { 174 | delete this; 175 | } 176 | 177 | // Clone the plugin 178 | IPluginV2IOExt* YoloLayerPlugin::clone() const TRT_NOEXCEPT 179 | { 180 | YoloLayerPlugin* p = new YoloLayerPlugin(mClassCount, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, mYoloKernel); 181 | p->setPluginNamespace(mPluginNamespace); 182 | return p; 183 | } 184 | 185 | __device__ float Logist(float data) { return 1.0f / (1.0f + expf(-data)); }; 186 | 187 | __global__ void CalDetection(const float *input, float *output, int noElements, 188 | const int netwidth, const int netheight, int maxoutobject, int yoloWidth, int yoloHeight, const float anchors[CHECK_COUNT * 2], int classes, int outputElem) 189 | { 190 | 191 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 192 | if (idx >= noElements) return; 193 | 194 | int total_grid = yoloWidth * yoloHeight; 195 | int bnIdx = idx / total_grid; 196 | idx = idx - total_grid * bnIdx; 197 | int info_len_i = 5 + classes; 198 | int info_len_kpt = KEY_POINTS_NUM * 3; 199 | const float* curInput = input + bnIdx * ((info_len_i + info_len_kpt) * total_grid * CHECK_COUNT); 200 | 201 | for (int k = 0; k < CHECK_COUNT; ++k) { 202 | float box_prob = Logist(curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + 4 * total_grid]); 203 | if (box_prob < IGNORE_THRESH) continue; 204 | int class_id = 0; //person class 205 | // float max_cls_prob = 0.0; 206 | // for (int i = 5; i < info_len_i; ++i) { 207 | // float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]); 208 | // if (p > max_cls_prob) { 209 | // max_cls_prob = p; 210 | // class_id = i - 5; 211 | // } 212 | // } 213 | float max_cls_prob = Logist(curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + 5 * total_grid]); 214 | 215 | float *res_count = output + bnIdx * outputElem; 216 | int count = (int)atomicAdd(res_count, 1); 217 | if (count >= maxoutobject) return; 218 | char *data = (char*)res_count + sizeof(float) + count * sizeof(Detection); 219 | Detection *det = (Detection*)(data); 220 | 221 | int row = idx / yoloWidth; 222 | int col = idx % yoloWidth; 223 | 224 | //------------bboxs------------ 225 | //Location 226 | // pytorch: 227 | // y = x[i].sigmoid() 228 | // xy = (y[..., 0:2] * 2. 
- 0.5 + self.grid[i]) * self.stride[i] # xy 229 | // wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh 230 | 231 | det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + 0 * total_grid])) * netwidth / yoloWidth; 232 | det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + 1 * total_grid])) * netheight / yoloHeight; 233 | 234 | // W: (Pw * e^tw) / FeaturemapW * netwidth 235 | // v5: https://github.com/ultralytics/yolov5/issues/471 236 | det->bbox[2] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 2 * total_grid]); 237 | det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2 * k]; 238 | det->bbox[3] = 2.0f * Logist(curInput[idx + k * info_len_i * total_grid + 3 * total_grid]); 239 | det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2 * k + 1]; 240 | det->conf = box_prob * max_cls_prob; 241 | det->class_id = class_id; 242 | 243 | //------------keypoints------------ 244 | // Location 245 | //pytorch: 246 | // x_kpt[..., 0::3] = (x_kpt[..., ::3] * 2. - 0.5 + kpt_grid_x.repeat(1,1,1,1,17)) * self.stride[i] # xy 247 | // x_kpt[..., 1::3] = (x_kpt[..., 1::3] * 2. - 0.5 + kpt_grid_y.repeat(1,1,1,1,17)) * self.stride[i] # xy 248 | // x_kpt[..., 2::3] = x_kpt[..., 2::3].sigmoid() 249 | for (int kpt_idx = 0; kpt_idx < KEY_POINTS_NUM; ++kpt_idx) 250 | { 251 | det->kpts[kpt_idx].x = (col - 0.5f + 2.0f * (curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + (6 + kpt_idx*3) * total_grid]) ) * netwidth / yoloWidth; 252 | det->kpts[kpt_idx].y = (row - 0.5f + 2.0f * (curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + (7 + kpt_idx*3) * total_grid]) ) * netheight / yoloHeight; 253 | det->kpts[kpt_idx].kpt_conf = Logist(curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + (8 + kpt_idx*3) * total_grid]); 254 | } 255 | 256 | } 257 | } 258 | 259 | void YoloLayerPlugin::forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize) 260 | { 261 | int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float); 262 | for (int idx = 0; idx < batchSize; ++idx) { 263 | CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream)); 264 | } 265 | int numElem = 0; 266 | for (unsigned int i = 0; i < mYoloKernel.size(); ++i) { 267 | const auto& yolo = mYoloKernel[i]; 268 | numElem = yolo.width * yolo.height * batchSize; 269 | if (numElem < mThreadCount) mThreadCount = numElem; 270 | 271 | //printf("Net: %d %d \n", mYoloV5NetWidth, mYoloV5NetHeight); 272 | CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> > 273 | (inputs[i], output, numElem, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, yolo.width, yolo.height, (float*)mAnchor[i], mClassCount, outputElem); 274 | } 275 | } 276 | 277 | 278 | int YoloLayerPlugin::enqueue(int batchSize, const void* const* inputs, void* TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT 279 | { 280 | forwardGpu((const float* const*)inputs, (float*)outputs[0], stream, batchSize); 281 | return 0; 282 | } 283 | 284 | PluginFieldCollection YoloPluginCreator::mFC{}; 285 | std::vector YoloPluginCreator::mPluginAttributes; 286 | 287 | YoloPluginCreator::YoloPluginCreator() 288 | { 289 | mPluginAttributes.clear(); 290 | 291 | mFC.nbFields = mPluginAttributes.size(); 292 | mFC.fields = mPluginAttributes.data(); 293 | } 294 | 295 | const char* YoloPluginCreator::getPluginName() const TRT_NOEXCEPT 296 | { 297 | 
return "YoloLayer_TRT"; 298 | } 299 | 300 | const char* YoloPluginCreator::getPluginVersion() const TRT_NOEXCEPT 301 | { 302 | return "1"; 303 | } 304 | 305 | const PluginFieldCollection* YoloPluginCreator::getFieldNames() TRT_NOEXCEPT 306 | { 307 | return &mFC; 308 | } 309 | 310 | IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT 311 | { 312 | // int class_count = 1;//p_netinfo[0]; 313 | // int input_w = 960;//p_netinfo[1]; 314 | // int input_h = 960;//p_netinfo[2]; 315 | // int max_output_object_count = 1000;//p_netinfo[3]; 316 | 317 | std::vector kernels{Yolo::yolo4, Yolo::yolo3, Yolo::yolo2, Yolo::yolo1}; 318 | 319 | YoloLayerPlugin* obj = new YoloLayerPlugin(CLASS_NUM, INPUT_W, INPUT_H, MAX_OUTPUT_BBOX_COUNT, kernels); 320 | obj->setPluginNamespace(mNamespace.c_str()); 321 | return obj; 322 | } 323 | 324 | IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT 325 | { 326 | // This object will be deleted when the network is destroyed, which will 327 | // call YoloLayerPlugin::destroy() 328 | YoloLayerPlugin* obj = new YoloLayerPlugin(serialData, serialLength); 329 | obj->setPluginNamespace(mNamespace.c_str()); 330 | return obj; 331 | } 332 | } -------------------------------------------------------------------------------- /YoloLayer_TRT_v7.0/yololayer.h: -------------------------------------------------------------------------------- 1 | #ifndef _YOLO_LAYER_H 2 | #define _YOLO_LAYER_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include "macros.h" 10 | 11 | namespace Yolo 12 | { 13 | static constexpr int CHECK_COUNT = 3; 14 | static constexpr float IGNORE_THRESH = 0.1f; 15 | static constexpr int KEY_POINTS_NUM = 17; 16 | 17 | struct YoloKernel 18 | { 19 | int width; 20 | int height; 21 | float anchors[CHECK_COUNT * 2]; 22 | }; 23 | 24 | static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000; 25 | static constexpr int CLASS_NUM = 1; 26 | static constexpr int INPUT_H = 960; // yolov5's input height and width must be divisible by 32. 
27 | static constexpr int INPUT_W = 960; 28 | 29 | static constexpr int LOCATIONS = 4; 30 | struct Keypoint { 31 | float x; 32 | float y; 33 | float kpt_conf; 34 | }; 35 | 36 | struct alignas(float) Detection { 37 | //center_x center_y w h 38 | float bbox[LOCATIONS]; 39 | float conf; // bbox_conf * cls_conf 40 | float class_id; //person 0 41 | // 17 keypoints 42 | Keypoint kpts[KEY_POINTS_NUM]; 43 | }; 44 | 45 | static constexpr YoloKernel yolo1 = { 46 | INPUT_W / 64, 47 | INPUT_H / 64, 48 | {436.0f,615.0f, 739.0f,380.0f, 925.0f,792.0f} 49 | }; 50 | static constexpr YoloKernel yolo2 = { 51 | INPUT_W / 32, 52 | INPUT_H / 32, 53 | {140.0f,301.0f, 303.0f,264.0f, 238.0f,542.0f} 54 | }; 55 | static constexpr YoloKernel yolo3 = { 56 | INPUT_W / 16, 57 | INPUT_H / 16, 58 | {96.0f,68.0f, 86.0f,152.0f, 180.0f,137.0f} 59 | }; 60 | static constexpr YoloKernel yolo4 = { 61 | INPUT_W / 8, 62 | INPUT_H / 8, 63 | {19.0f,27.0f, 44.0f,40.0f, 38.0f,94.0f} 64 | }; 65 | 66 | 67 | } 68 | 69 | namespace nvinfer1 70 | { 71 | class API YoloLayerPlugin : public IPluginV2IOExt 72 | { 73 | public: 74 | YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector& vYoloKernel); 75 | YoloLayerPlugin(const void* data, size_t length); 76 | ~YoloLayerPlugin(); 77 | 78 | int getNbOutputs() const TRT_NOEXCEPT override 79 | { 80 | return 1; 81 | } 82 | 83 | Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT override; 84 | 85 | int initialize() TRT_NOEXCEPT override; 86 | 87 | virtual void terminate() TRT_NOEXCEPT override {}; 88 | 89 | virtual size_t getWorkspaceSize(int maxBatchSize) const TRT_NOEXCEPT override { return 0; } 90 | 91 | virtual int enqueue(int batchSize, const void* const* inputs, void*TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override; 92 | 93 | virtual size_t getSerializationSize() const TRT_NOEXCEPT override; 94 | 95 | virtual void serialize(void* buffer) const TRT_NOEXCEPT override; 96 | 97 | bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const TRT_NOEXCEPT override { 98 | return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT; 99 | } 100 | 101 | const char* getPluginType() const TRT_NOEXCEPT override; 102 | 103 | const char* getPluginVersion() const TRT_NOEXCEPT override; 104 | 105 | void destroy() TRT_NOEXCEPT override; 106 | 107 | IPluginV2IOExt* clone() const TRT_NOEXCEPT override; 108 | 109 | void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override; 110 | 111 | const char* getPluginNamespace() const TRT_NOEXCEPT override; 112 | 113 | DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; 114 | 115 | bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT override; 116 | 117 | bool canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT override; 118 | 119 | void attachToContext( 120 | cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override; 121 | 122 | void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT override; 123 | using IPluginV2Ext::configurePlugin; 124 | void detachFromContext() TRT_NOEXCEPT override; 125 | 126 | private: 127 | void forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize = 1); 128 | int 
mThreadCount = 256; 129 | const char* mPluginNamespace; 130 | int mKernelCount; 131 | int mClassCount; 132 | int mYoloV5NetWidth; 133 | int mYoloV5NetHeight; 134 | int mMaxOutObject; 135 | std::vector mYoloKernel; 136 | void** mAnchor; 137 | }; 138 | 139 | class API YoloPluginCreator : public IPluginCreator 140 | { 141 | public: 142 | YoloPluginCreator(); 143 | 144 | ~YoloPluginCreator() override = default; 145 | 146 | const char* getPluginName() const TRT_NOEXCEPT override; 147 | 148 | const char* getPluginVersion() const TRT_NOEXCEPT override; 149 | 150 | const PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override; 151 | 152 | IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT override; 153 | 154 | IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; 155 | 156 | void setPluginNamespace(const char* libNamespace) TRT_NOEXCEPT override 157 | { 158 | mNamespace = libNamespace; 159 | } 160 | 161 | const char* getPluginNamespace() const TRT_NOEXCEPT override 162 | { 163 | return mNamespace.c_str(); 164 | } 165 | 166 | private: 167 | std::string mNamespace; 168 | static PluginFieldCollection mFC; 169 | static std::vector mPluginAttributes; 170 | }; 171 | REGISTER_TENSORRT_PLUGIN(YoloPluginCreator); 172 | }; 173 | 174 | #endif // _YOLO_LAYER_H -------------------------------------------------------------------------------- /assets/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanmi/yolov7-pose/bd8f06fc16b3d0df829e5ee4a23cb8cec59f1a29/assets/1.jpg -------------------------------------------------------------------------------- /assets/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanmi/yolov7-pose/bd8f06fc16b3d0df829e5ee4a23cb8cec59f1a29/assets/2.jpg -------------------------------------------------------------------------------- /assets/_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanmi/yolov7-pose/bd8f06fc16b3d0df829e5ee4a23cb8cec59f1a29/assets/_result.jpg -------------------------------------------------------------------------------- /yolov7-pose/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.10) 2 | 3 | project(yolov7_pose) 4 | 5 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -Wall -Ofast -g -Wfatal-errors -D_MWAITXINTRIN_H_INCLUDED") 6 | add_definitions(-DAPI_EXPORTS) 7 | option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) 8 | set(CMAKE_BUILD_TYPE Debug) 9 | 10 | # cuda 11 | find_package(CUDA REQUIRED) 12 | include_directories(${CUDA_INCLUDE_DIRS}) 13 | 14 | # tensorrt 15 | set(TENSORRT_INCLUDE_DIR /usr/local/TensorRT-8.4.1.5/include/) 16 | set(TENSORRT_LIBRARY_DIR /usr/local/TensorRT-8.4.1.5/lib/) 17 | include_directories(${TENSORRT_INCLUDE_DIR}) 18 | link_directories(${TENSORRT_LIBRARY_DIR}) 19 | 20 | # OpenCV 21 | find_package(OpenCV) 22 | include_directories(${OpenCV_INCLUDE_DIRS}) 23 | 24 | 25 | cuda_add_executable(yolov7_pose yolov7_pose.cpp yololayer.cu) 26 | target_link_libraries(yolov7_pose nvinfer ${CUDA_LIBRARIES} ${OpenCV_LIBS}) 27 | 28 | add_definitions(-O2 -pthread) 29 | 30 | -------------------------------------------------------------------------------- /yolov7-pose/logging.h: 
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef TENSORRT_LOGGING_H 18 | #define TENSORRT_LOGGING_H 19 | 20 | #include "NvInferRuntimeCommon.h" 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include "macros.h" 29 | 30 | using Severity = nvinfer1::ILogger::Severity; 31 | 32 | class LogStreamConsumerBuffer : public std::stringbuf 33 | { 34 | public: 35 | LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) 36 | : mOutput(stream) 37 | , mPrefix(prefix) 38 | , mShouldLog(shouldLog) 39 | { 40 | } 41 | 42 | LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) 43 | : mOutput(other.mOutput) 44 | { 45 | } 46 | 47 | ~LogStreamConsumerBuffer() 48 | { 49 | // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence 50 | // std::streambuf::pptr() gives a pointer to the current position of the output sequence 51 | // if the pointer to the beginning is not equal to the pointer to the current position, 52 | // call putOutput() to log the output to the stream 53 | if (pbase() != pptr()) 54 | { 55 | putOutput(); 56 | } 57 | } 58 | 59 | // synchronizes the stream buffer and returns 0 on success 60 | // synchronizing the stream buffer consists of inserting the buffer contents into the stream, 61 | // resetting the buffer and flushing the stream 62 | virtual int sync() 63 | { 64 | putOutput(); 65 | return 0; 66 | } 67 | 68 | void putOutput() 69 | { 70 | if (mShouldLog) 71 | { 72 | // prepend timestamp 73 | std::time_t timestamp = std::time(nullptr); 74 | tm* tm_local = std::localtime(×tamp); 75 | std::cout << "["; 76 | std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/"; 77 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/"; 78 | std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-"; 79 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":"; 80 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":"; 81 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] "; 82 | // std::stringbuf::str() gets the string contents of the buffer 83 | // insert the buffer contents pre-appended by the appropriate prefix into the stream 84 | mOutput << mPrefix << str(); 85 | // set the buffer to empty 86 | str(""); 87 | // flush the stream 88 | mOutput.flush(); 89 | } 90 | } 91 | 92 | void setShouldLog(bool shouldLog) 93 | { 94 | mShouldLog = shouldLog; 95 | } 96 | 97 | private: 98 | std::ostream& mOutput; 99 | std::string mPrefix; 100 | bool mShouldLog; 101 | }; 102 | 103 | //! 104 | //! \class LogStreamConsumerBase 105 | //! 
\brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer 106 | //! 107 | class LogStreamConsumerBase 108 | { 109 | public: 110 | LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) 111 | : mBuffer(stream, prefix, shouldLog) 112 | { 113 | } 114 | 115 | protected: 116 | LogStreamConsumerBuffer mBuffer; 117 | }; 118 | 119 | //! 120 | //! \class LogStreamConsumer 121 | //! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages. 122 | //! Order of base classes is LogStreamConsumerBase and then std::ostream. 123 | //! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field 124 | //! in LogStreamConsumer and then the address of the buffer is passed to std::ostream. 125 | //! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. 126 | //! Please do not change the order of the parent classes. 127 | //! 128 | class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream 129 | { 130 | public: 131 | //! \brief Creates a LogStreamConsumer which logs messages with level severity. 132 | //! Reportable severity determines if the messages are severe enough to be logged. 133 | LogStreamConsumer(Severity reportableSeverity, Severity severity) 134 | : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) 135 | , std::ostream(&mBuffer) // links the stream buffer with the stream 136 | , mShouldLog(severity <= reportableSeverity) 137 | , mSeverity(severity) 138 | { 139 | } 140 | 141 | LogStreamConsumer(LogStreamConsumer&& other) 142 | : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) 143 | , std::ostream(&mBuffer) // links the stream buffer with the stream 144 | , mShouldLog(other.mShouldLog) 145 | , mSeverity(other.mSeverity) 146 | { 147 | } 148 | 149 | void setReportableSeverity(Severity reportableSeverity) 150 | { 151 | mShouldLog = mSeverity <= reportableSeverity; 152 | mBuffer.setShouldLog(mShouldLog); 153 | } 154 | 155 | private: 156 | static std::ostream& severityOstream(Severity severity) 157 | { 158 | return severity >= Severity::kINFO ? std::cout : std::cerr; 159 | } 160 | 161 | static std::string severityPrefix(Severity severity) 162 | { 163 | switch (severity) 164 | { 165 | case Severity::kINTERNAL_ERROR: return "[F] "; 166 | case Severity::kERROR: return "[E] "; 167 | case Severity::kWARNING: return "[W] "; 168 | case Severity::kINFO: return "[I] "; 169 | case Severity::kVERBOSE: return "[V] "; 170 | default: assert(0); return ""; 171 | } 172 | } 173 | 174 | bool mShouldLog; 175 | Severity mSeverity; 176 | }; 177 | 178 | //! \class Logger 179 | //! 180 | //! \brief Class which manages logging of TensorRT tools and samples 181 | //! 182 | //! \details This class provides a common interface for TensorRT tools and samples to log information to the console, 183 | //! and supports logging two types of messages: 184 | //! 185 | //! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) 186 | //! - Test pass/fail messages 187 | //! 188 | //! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is 189 | //! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. 190 | //! 191 | //! 
In the future, this class could be extended to support dumping test results to a file in some standard format 192 | //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). 193 | //! 194 | //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger 195 | //! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT 196 | //! library and messages coming from the sample. 197 | //! 198 | //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the 199 | //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger 200 | //! object. 201 | 202 | class Logger : public nvinfer1::ILogger 203 | { 204 | public: 205 | Logger(Severity severity = Severity::kWARNING) 206 | : mReportableSeverity(severity) 207 | { 208 | } 209 | 210 | //! 211 | //! \enum TestResult 212 | //! \brief Represents the state of a given test 213 | //! 214 | enum class TestResult 215 | { 216 | kRUNNING, //!< The test is running 217 | kPASSED, //!< The test passed 218 | kFAILED, //!< The test failed 219 | kWAIVED //!< The test was waived 220 | }; 221 | 222 | //! 223 | //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger 224 | //! \return The nvinfer1::ILogger associated with this Logger 225 | //! 226 | //! TODO Once all samples are updated to use this method to register the logger with TensorRT, 227 | //! we can eliminate the inheritance of Logger from ILogger 228 | //! 229 | nvinfer1::ILogger& getTRTLogger() 230 | { 231 | return *this; 232 | } 233 | 234 | //! 235 | //! \brief Implementation of the nvinfer1::ILogger::log() virtual method 236 | //! 237 | //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the 238 | //! inheritance from nvinfer1::ILogger 239 | //! 240 | void log(Severity severity, const char* msg) noexcept override 241 | { 242 | LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; 243 | } 244 | 245 | //! 246 | //! \brief Method for controlling the verbosity of logging output 247 | //! 248 | //! \param severity The logger will only emit messages that have severity of this level or higher. 249 | //! 250 | void setReportableSeverity(Severity severity) 251 | { 252 | mReportableSeverity = severity; 253 | } 254 | 255 | //! 256 | //! \brief Opaque handle that holds logging information for a particular test 257 | //! 258 | //! This object is an opaque handle to information used by the Logger to print test results. 259 | //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used 260 | //! with Logger::reportTest{Start,End}(). 261 | //! 262 | class TestAtom 263 | { 264 | public: 265 | TestAtom(TestAtom&&) = default; 266 | 267 | private: 268 | friend class Logger; 269 | 270 | TestAtom(bool started, const std::string& name, const std::string& cmdline) 271 | : mStarted(started) 272 | , mName(name) 273 | , mCmdline(cmdline) 274 | { 275 | } 276 | 277 | bool mStarted; 278 | std::string mName; 279 | std::string mCmdline; 280 | }; 281 | 282 | //! 283 | //! \brief Define a test for logging 284 | //! 285 | //! \param[in] name The name of the test. This should be a string starting with 286 | //! "TensorRT" and containing dot-separated strings containing 287 | //! the characters [A-Za-z0-9_]. 
288 | //! For example, "TensorRT.sample_googlenet" 289 | //! \param[in] cmdline The command line used to reproduce the test 290 | // 291 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 292 | //! 293 | static TestAtom defineTest(const std::string& name, const std::string& cmdline) 294 | { 295 | return TestAtom(false, name, cmdline); 296 | } 297 | 298 | //! 299 | //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments 300 | //! as input 301 | //! 302 | //! \param[in] name The name of the test 303 | //! \param[in] argc The number of command-line arguments 304 | //! \param[in] argv The array of command-line arguments (given as C strings) 305 | //! 306 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 307 | static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) 308 | { 309 | auto cmdline = genCmdlineString(argc, argv); 310 | return defineTest(name, cmdline); 311 | } 312 | 313 | //! 314 | //! \brief Report that a test has started. 315 | //! 316 | //! \pre reportTestStart() has not been called yet for the given testAtom 317 | //! 318 | //! \param[in] testAtom The handle to the test that has started 319 | //! 320 | static void reportTestStart(TestAtom& testAtom) 321 | { 322 | reportTestResult(testAtom, TestResult::kRUNNING); 323 | assert(!testAtom.mStarted); 324 | testAtom.mStarted = true; 325 | } 326 | 327 | //! 328 | //! \brief Report that a test has ended. 329 | //! 330 | //! \pre reportTestStart() has been called for the given testAtom 331 | //! 332 | //! \param[in] testAtom The handle to the test that has ended 333 | //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, 334 | //! TestResult::kFAILED, TestResult::kWAIVED 335 | //! 336 | static void reportTestEnd(const TestAtom& testAtom, TestResult result) 337 | { 338 | assert(result != TestResult::kRUNNING); 339 | assert(testAtom.mStarted); 340 | reportTestResult(testAtom, result); 341 | } 342 | 343 | static int reportPass(const TestAtom& testAtom) 344 | { 345 | reportTestEnd(testAtom, TestResult::kPASSED); 346 | return EXIT_SUCCESS; 347 | } 348 | 349 | static int reportFail(const TestAtom& testAtom) 350 | { 351 | reportTestEnd(testAtom, TestResult::kFAILED); 352 | return EXIT_FAILURE; 353 | } 354 | 355 | static int reportWaive(const TestAtom& testAtom) 356 | { 357 | reportTestEnd(testAtom, TestResult::kWAIVED); 358 | return EXIT_SUCCESS; 359 | } 360 | 361 | static int reportTest(const TestAtom& testAtom, bool pass) 362 | { 363 | return pass ? reportPass(testAtom) : reportFail(testAtom); 364 | } 365 | 366 | Severity getReportableSeverity() const 367 | { 368 | return mReportableSeverity; 369 | } 370 | 371 | private: 372 | //! 373 | //! \brief returns an appropriate string for prefixing a log message with the given severity 374 | //! 375 | static const char* severityPrefix(Severity severity) 376 | { 377 | switch (severity) 378 | { 379 | case Severity::kINTERNAL_ERROR: return "[F] "; 380 | case Severity::kERROR: return "[E] "; 381 | case Severity::kWARNING: return "[W] "; 382 | case Severity::kINFO: return "[I] "; 383 | case Severity::kVERBOSE: return "[V] "; 384 | default: assert(0); return ""; 385 | } 386 | } 387 | 388 | //! 389 | //! \brief returns an appropriate string for prefixing a test result message with the given result 390 | //! 
391 | static const char* testResultString(TestResult result) 392 | { 393 | switch (result) 394 | { 395 | case TestResult::kRUNNING: return "RUNNING"; 396 | case TestResult::kPASSED: return "PASSED"; 397 | case TestResult::kFAILED: return "FAILED"; 398 | case TestResult::kWAIVED: return "WAIVED"; 399 | default: assert(0); return ""; 400 | } 401 | } 402 | 403 | //! 404 | //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity 405 | //! 406 | static std::ostream& severityOstream(Severity severity) 407 | { 408 | return severity >= Severity::kINFO ? std::cout : std::cerr; 409 | } 410 | 411 | //! 412 | //! \brief method that implements logging test results 413 | //! 414 | static void reportTestResult(const TestAtom& testAtom, TestResult result) 415 | { 416 | severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " 417 | << testAtom.mCmdline << std::endl; 418 | } 419 | 420 | //! 421 | //! \brief generate a command line string from the given (argc, argv) values 422 | //! 423 | static std::string genCmdlineString(int argc, char const* const* argv) 424 | { 425 | std::stringstream ss; 426 | for (int i = 0; i < argc; i++) 427 | { 428 | if (i > 0) 429 | ss << " "; 430 | ss << argv[i]; 431 | } 432 | return ss.str(); 433 | } 434 | 435 | Severity mReportableSeverity; 436 | }; 437 | 438 | namespace 439 | { 440 | 441 | //! 442 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE 443 | //! 444 | //! Example usage: 445 | //! 446 | //! LOG_VERBOSE(logger) << "hello world" << std::endl; 447 | //! 448 | inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) 449 | { 450 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); 451 | } 452 | 453 | //! 454 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO 455 | //! 456 | //! Example usage: 457 | //! 458 | //! LOG_INFO(logger) << "hello world" << std::endl; 459 | //! 460 | inline LogStreamConsumer LOG_INFO(const Logger& logger) 461 | { 462 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); 463 | } 464 | 465 | //! 466 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING 467 | //! 468 | //! Example usage: 469 | //! 470 | //! LOG_WARN(logger) << "hello world" << std::endl; 471 | //! 472 | inline LogStreamConsumer LOG_WARN(const Logger& logger) 473 | { 474 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); 475 | } 476 | 477 | //! 478 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR 479 | //! 480 | //! Example usage: 481 | //! 482 | //! LOG_ERROR(logger) << "hello world" << std::endl; 483 | //! 484 | inline LogStreamConsumer LOG_ERROR(const Logger& logger) 485 | { 486 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); 487 | } 488 | 489 | //! 490 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR 491 | // ("fatal" severity) 492 | //! 493 | //! Example usage: 494 | //! 495 | //! LOG_FATAL(logger) << "hello world" << std::endl; 496 | //! 
497 | inline LogStreamConsumer LOG_FATAL(const Logger& logger) 498 | { 499 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); 500 | } 501 | 502 | } // anonymous namespace 503 | 504 | #endif // TENSORRT_LOGGING_H 505 | -------------------------------------------------------------------------------- /yolov7-pose/macros.h: -------------------------------------------------------------------------------- 1 | #ifndef __MACROS_H 2 | #define __MACROS_H 3 | 4 | #ifdef API_EXPORTS 5 | #if defined(_MSC_VER) 6 | #define API __declspec(dllexport) 7 | #else 8 | #define API __attribute__((visibility("default"))) 9 | #endif 10 | #else 11 | 12 | #if defined(_MSC_VER) 13 | #define API __declspec(dllimport) 14 | #else 15 | #define API 16 | #endif 17 | #endif // API_EXPORTS 18 | 19 | #if NV_TENSORRT_MAJOR >= 8 20 | #define TRT_NOEXCEPT noexcept 21 | #define TRT_CONST_ENQUEUE const 22 | #else 23 | #define TRT_NOEXCEPT 24 | #define TRT_CONST_ENQUEUE 25 | #endif 26 | 27 | #endif // __MACROS_H 28 | -------------------------------------------------------------------------------- /yolov7-pose/person.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nanmi/yolov7-pose/bd8f06fc16b3d0df829e5ee4a23cb8cec59f1a29/yolov7-pose/person.jpg -------------------------------------------------------------------------------- /yolov7-pose/utils.h: -------------------------------------------------------------------------------- 1 | #ifndef TRTX_YOLOV5_UTILS_H_ 2 | #define TRTX_YOLOV5_UTILS_H_ 3 | 4 | #include 5 | #include 6 | 7 | static inline cv::Mat preprocess_img(cv::Mat& img, int input_w, int input_h) { 8 | int w, h, x, y; 9 | float r_w = input_w / (img.cols*1.0); 10 | float r_h = input_h / (img.rows*1.0); 11 | if (r_h > r_w) { 12 | w = input_w; 13 | h = r_w * img.rows; 14 | x = 0; 15 | y = (input_h - h) / 2; 16 | } else { 17 | w = r_h * img.cols; 18 | h = input_h; 19 | x = (input_w - w) / 2; 20 | y = 0; 21 | } 22 | cv::Mat re(h, w, CV_8UC3); 23 | cv::resize(img, re, re.size(), 0, 0, cv::INTER_LINEAR); 24 | cv::Mat out(input_h, input_w, CV_8UC3, cv::Scalar(128, 128, 128)); 25 | re.copyTo(out(cv::Rect(x, y, re.cols, re.rows))); 26 | return out; 27 | } 28 | 29 | static inline int read_files_in_dir(const char *p_dir_name, std::vector &file_names) { 30 | DIR *p_dir = opendir(p_dir_name); 31 | if (p_dir == nullptr) { 32 | return -1; 33 | } 34 | 35 | struct dirent* p_file = nullptr; 36 | while ((p_file = readdir(p_dir)) != nullptr) { 37 | if (strcmp(p_file->d_name, ".") != 0 && 38 | strcmp(p_file->d_name, "..") != 0) { 39 | //std::string cur_file_name(p_dir_name); 40 | //cur_file_name += "/"; 41 | //cur_file_name += p_file->d_name; 42 | std::string cur_file_name(p_file->d_name); 43 | file_names.push_back(cur_file_name); 44 | } 45 | } 46 | 47 | closedir(p_dir); 48 | return 0; 49 | } 50 | 51 | #endif // TRTX_YOLOV5_UTILS_H_ 52 | 53 | -------------------------------------------------------------------------------- /yolov7-pose/yololayer.cu: -------------------------------------------------------------------------------- 1 | #include "yololayer.h" 2 | 3 | namespace Tn 4 | { 5 | template 6 | void write(char*& buffer, const T& val) 7 | { 8 | *reinterpret_cast(buffer) = val; 9 | buffer += sizeof(T); 10 | } 11 | 12 | template 13 | void read(const char*& buffer, T& val) 14 | { 15 | val = *reinterpret_cast(buffer); 16 | buffer += sizeof(T); 17 | } 18 | } 19 | 20 | #define CUDA_CHECK(callstr)\ 21 | {\ 22 | cudaError_t error_code = callstr;\ 
23 | if (error_code != cudaSuccess) {\ 24 | std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\ 25 | assert(0);\ 26 | }\ 27 | } 28 | 29 | using namespace Yolo; 30 | 31 | namespace nvinfer1 32 | { 33 | YoloLayerPlugin::YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector& vYoloKernel) 34 | { 35 | mClassCount = classCount; 36 | mYoloV5NetWidth = netWidth; 37 | mYoloV5NetHeight = netHeight; 38 | mMaxOutObject = maxOut; 39 | mYoloKernel = vYoloKernel; 40 | mKernelCount = vYoloKernel.size(); 41 | 42 | CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*))); 43 | size_t AnchorLen = sizeof(float)* CHECK_COUNT * 2; 44 | for (int ii = 0; ii < mKernelCount; ii++) 45 | { 46 | CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen)); 47 | const auto& yolo = mYoloKernel[ii]; 48 | CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice)); 49 | } 50 | } 51 | YoloLayerPlugin::~YoloLayerPlugin() 52 | { 53 | for (int ii = 0; ii < mKernelCount; ii++) 54 | { 55 | CUDA_CHECK(cudaFree(mAnchor[ii])); 56 | } 57 | CUDA_CHECK(cudaFreeHost(mAnchor)); 58 | } 59 | 60 | // create the plugin at runtime from a byte stream 61 | YoloLayerPlugin::YoloLayerPlugin(const void* data, size_t length) 62 | { 63 | using namespace Tn; 64 | const char *d = reinterpret_cast(data), *a = d; 65 | read(d, mClassCount); 66 | read(d, mThreadCount); 67 | read(d, mKernelCount); 68 | read(d, mYoloV5NetWidth); 69 | read(d, mYoloV5NetHeight); 70 | read(d, mMaxOutObject); 71 | mYoloKernel.resize(mKernelCount); 72 | auto kernelSize = mKernelCount * sizeof(YoloKernel); 73 | memcpy(mYoloKernel.data(), d, kernelSize); 74 | d += kernelSize; 75 | CUDA_CHECK(cudaMallocHost(&mAnchor, mKernelCount * sizeof(void*))); 76 | size_t AnchorLen = sizeof(float)* CHECK_COUNT * 2; 77 | for (int ii = 0; ii < mKernelCount; ii++) 78 | { 79 | CUDA_CHECK(cudaMalloc(&mAnchor[ii], AnchorLen)); 80 | const auto& yolo = mYoloKernel[ii]; 81 | CUDA_CHECK(cudaMemcpy(mAnchor[ii], yolo.anchors, AnchorLen, cudaMemcpyHostToDevice)); 82 | } 83 | assert(d == a + length); 84 | } 85 | 86 | void YoloLayerPlugin::serialize(void* buffer) const TRT_NOEXCEPT 87 | { 88 | using namespace Tn; 89 | char* d = static_cast(buffer), *a = d; 90 | write(d, mClassCount); 91 | write(d, mThreadCount); 92 | write(d, mKernelCount); 93 | write(d, mYoloV5NetWidth); 94 | write(d, mYoloV5NetHeight); 95 | write(d, mMaxOutObject); 96 | auto kernelSize = mKernelCount * sizeof(YoloKernel); 97 | memcpy(d, mYoloKernel.data(), kernelSize); 98 | d += kernelSize; 99 | 100 | assert(d == a + getSerializationSize()); 101 | } 102 | 103 | size_t YoloLayerPlugin::getSerializationSize() const TRT_NOEXCEPT 104 | { 105 | return sizeof(mClassCount) + sizeof(mThreadCount) + sizeof(mKernelCount) + sizeof(Yolo::YoloKernel) * mYoloKernel.size() + sizeof(mYoloV5NetWidth) + sizeof(mYoloV5NetHeight) + sizeof(mMaxOutObject); 106 | } 107 | 108 | int YoloLayerPlugin::initialize() TRT_NOEXCEPT 109 | { 110 | return 0; 111 | } 112 | 113 | Dims YoloLayerPlugin::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT 114 | { 115 | //output the result to channel 116 | int totalsize = mMaxOutObject * sizeof(Detection) / sizeof(float); 117 | 118 | return Dims3(totalsize + 1, 1, 1); 119 | } 120 | 121 | // Set plugin namespace 122 | void YoloLayerPlugin::setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT 123 | { 124 | mPluginNamespace = pluginNamespace; 125 | } 126 | 127 | const char* 
YoloLayerPlugin::getPluginNamespace() const TRT_NOEXCEPT 128 | { 129 | return mPluginNamespace; 130 | } 131 | 132 | // Return the DataType of the plugin output at the requested index 133 | DataType YoloLayerPlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT 134 | { 135 | return DataType::kFLOAT; 136 | } 137 | 138 | // Return true if output tensor is broadcast across a batch. 139 | bool YoloLayerPlugin::isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT 140 | { 141 | return false; 142 | } 143 | 144 | // Return true if plugin can use input that is broadcast across batch without replication. 145 | bool YoloLayerPlugin::canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT 146 | { 147 | return false; 148 | } 149 | 150 | void YoloLayerPlugin::configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT 151 | { 152 | } 153 | 154 | // Attach the plugin object to an execution context and grant the plugin the access to some context resource. 155 | void YoloLayerPlugin::attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT 156 | { 157 | } 158 | 159 | // Detach the plugin object from its execution context. 160 | void YoloLayerPlugin::detachFromContext() TRT_NOEXCEPT {} 161 | 162 | const char* YoloLayerPlugin::getPluginType() const TRT_NOEXCEPT 163 | { 164 | return "YoloLayer_TRT"; 165 | } 166 | 167 | const char* YoloLayerPlugin::getPluginVersion() const TRT_NOEXCEPT 168 | { 169 | return "1"; 170 | } 171 | 172 | void YoloLayerPlugin::destroy() TRT_NOEXCEPT 173 | { 174 | delete this; 175 | } 176 | 177 | // Clone the plugin 178 | IPluginV2IOExt* YoloLayerPlugin::clone() const TRT_NOEXCEPT 179 | { 180 | YoloLayerPlugin* p = new YoloLayerPlugin(mClassCount, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, mYoloKernel); 181 | p->setPluginNamespace(mPluginNamespace); 182 | return p; 183 | } 184 | 185 | __device__ float Logist(float data) { return 1.0f / (1.0f + expf(-data)); }; 186 | 187 | __global__ void CalDetection(const float *input, float *output, int noElements, 188 | const int netwidth, const int netheight, int maxoutobject, int yoloWidth, int yoloHeight, const float anchors[CHECK_COUNT * 2], int classes, int outputElem) 189 | { 190 | 191 | int idx = threadIdx.x + blockDim.x * blockIdx.x; 192 | if (idx >= noElements) return; 193 | 194 | int total_grid = yoloWidth * yoloHeight; 195 | int bnIdx = idx / total_grid; 196 | idx = idx - total_grid * bnIdx; 197 | int info_len_i = 5 + classes; 198 | int info_len_kpt = KEY_POINTS_NUM * 3; 199 | const float* curInput = input + bnIdx * ((info_len_i + info_len_kpt) * total_grid * CHECK_COUNT); 200 | 201 | for (int k = 0; k < CHECK_COUNT; ++k) { 202 | float box_prob = Logist(curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + 4 * total_grid]); 203 | if (box_prob < IGNORE_THRESH) continue; 204 | int class_id = 0; //person class 205 | // float max_cls_prob = 0.0; 206 | // for (int i = 5; i < info_len_i; ++i) { 207 | // float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]); 208 | // if (p > max_cls_prob) { 209 | // max_cls_prob = p; 210 | // class_id = i - 5; 211 | // } 212 | // } 213 | float max_cls_prob = Logist(curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + 5 * total_grid]); 214 | 215 | float *res_count = output + bnIdx * outputElem; 216 | int count = (int)atomicAdd(res_count, 
1); 217 | if (count >= maxoutobject) return; 218 | char *data = (char*)res_count + sizeof(float) + count * sizeof(Detection); 219 | Detection *det = (Detection*)(data); 220 | 221 | int row = idx / yoloWidth; 222 | int col = idx % yoloWidth; 223 | 224 | //------------bboxs------------ 225 | //Location 226 | // pytorch: 227 | // y = x[i].sigmoid() 228 | // xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i] # xy 229 | // wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2) # wh 230 | 231 | det->bbox[0] = (col - 0.5f + 2.0f * Logist(curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + 0 * total_grid])) * netwidth / yoloWidth; 232 | det->bbox[1] = (row - 0.5f + 2.0f * Logist(curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + 1 * total_grid])) * netheight / yoloHeight; 233 | 234 | // W: (Pw * e^tw) / FeaturemapW * netwidth 235 | // v5: https://github.com/ultralytics/yolov5/issues/471 236 | det->bbox[2] = 2.0f * Logist(curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + 2 * total_grid]); 237 | det->bbox[2] = det->bbox[2] * det->bbox[2] * anchors[2 * k]; 238 | det->bbox[3] = 2.0f * Logist(curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + 3 * total_grid]); 239 | det->bbox[3] = det->bbox[3] * det->bbox[3] * anchors[2 * k + 1]; 240 | det->conf = box_prob * max_cls_prob; 241 | det->class_id = class_id; 242 | 243 | //------------keypoints------------ 244 | // Location 245 | //pytorch: 246 | // x_kpt[..., 0::3] = (x_kpt[..., ::3] * 2. - 0.5 + kpt_grid_x.repeat(1,1,1,1,17)) * self.stride[i] # xy 247 | // x_kpt[..., 1::3] = (x_kpt[..., 1::3] * 2. - 0.5 + kpt_grid_y.repeat(1,1,1,1,17)) * self.stride[i] # xy 248 | // x_kpt[..., 2::3] = x_kpt[..., 2::3].sigmoid() 249 | for (int kpt_idx = 0; kpt_idx < KEY_POINTS_NUM; ++kpt_idx) 250 | { 251 | det->kpts[kpt_idx].x = (col - 0.5f + 2.0f * (curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + (6 + kpt_idx*3) * total_grid]) ) * netwidth / yoloWidth; 252 | det->kpts[kpt_idx].y = (row - 0.5f + 2.0f * (curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + (7 + kpt_idx*3) * total_grid]) ) * netheight / yoloHeight; 253 | det->kpts[kpt_idx].kpt_conf = Logist(curInput[idx + k * (info_len_i + info_len_kpt) * total_grid + (8 + kpt_idx*3) * total_grid]); 254 | } 255 | 256 | } 257 | } 258 | 259 | void YoloLayerPlugin::forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize) 260 | { 261 | int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float); 262 | for (int idx = 0; idx < batchSize; ++idx) { 263 | CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream)); 264 | } 265 | int numElem = 0; 266 | for (unsigned int i = 0; i < mYoloKernel.size(); ++i) { 267 | const auto& yolo = mYoloKernel[i]; 268 | numElem = yolo.width * yolo.height * batchSize; 269 | if (numElem < mThreadCount) mThreadCount = numElem; 270 | 271 | //printf("Net: %d %d \n", mYoloV5NetWidth, mYoloV5NetHeight); 272 | CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> > 273 | (inputs[i], output, numElem, mYoloV5NetWidth, mYoloV5NetHeight, mMaxOutObject, yolo.width, yolo.height, (float*)mAnchor[i], mClassCount, outputElem); 274 | } 275 | } 276 | 277 | 278 | int YoloLayerPlugin::enqueue(int batchSize, const void* const* inputs, void* TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT 279 | { 280 | forwardGpu((const float* const*)inputs, (float*)outputs[0], stream, batchSize); 281 | return 0; 
282 | }
283 | 
284 | PluginFieldCollection YoloPluginCreator::mFC{};
285 | std::vector<PluginField> YoloPluginCreator::mPluginAttributes;
286 | 
287 | YoloPluginCreator::YoloPluginCreator()
288 | {
289 |     mPluginAttributes.clear();
290 | 
291 |     mFC.nbFields = mPluginAttributes.size();
292 |     mFC.fields = mPluginAttributes.data();
293 | }
294 | 
295 | const char* YoloPluginCreator::getPluginName() const TRT_NOEXCEPT
296 | {
297 |     return "YoloLayer_TRT";
298 | }
299 | 
300 | const char* YoloPluginCreator::getPluginVersion() const TRT_NOEXCEPT
301 | {
302 |     return "1";
303 | }
304 | 
305 | const PluginFieldCollection* YoloPluginCreator::getFieldNames() TRT_NOEXCEPT
306 | {
307 |     return &mFC;
308 | }
309 | 
310 | IPluginV2IOExt* YoloPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT
311 | {
312 |     // int class_count = 1;//p_netinfo[0];
313 |     // int input_w = 960;//p_netinfo[1];
314 |     // int input_h = 960;//p_netinfo[2];
315 |     // int max_output_object_count = 1000;//p_netinfo[3];
316 | 
317 |     std::vector<Yolo::YoloKernel> kernels{Yolo::yolo4, Yolo::yolo3, Yolo::yolo2, Yolo::yolo1};
318 | 
319 |     YoloLayerPlugin* obj = new YoloLayerPlugin(CLASS_NUM, INPUT_W, INPUT_H, MAX_OUTPUT_BBOX_COUNT, kernels);
320 |     obj->setPluginNamespace(mNamespace.c_str());
321 |     return obj;
322 | }
323 | 
324 | IPluginV2IOExt* YoloPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT
325 | {
326 |     // This object will be deleted when the network is destroyed, which will
327 |     // call YoloLayerPlugin::destroy()
328 |     YoloLayerPlugin* obj = new YoloLayerPlugin(serialData, serialLength);
329 |     obj->setPluginNamespace(mNamespace.c_str());
330 |     return obj;
331 | }
332 | }
333 | 
--------------------------------------------------------------------------------
/yolov7-pose/yololayer.h:
--------------------------------------------------------------------------------
1 | #ifndef _YOLO_LAYER_H
2 | #define _YOLO_LAYER_H
3 | 
4 | #include <vector>        // NOTE: the original include targets were lost; these five headers cover what this file uses
5 | #include <string>
6 | #include <iostream>
7 | #include <cuda_runtime.h>
8 | #include <NvInfer.h>
9 | #include "macros.h"
10 | 
11 | namespace Yolo
12 | {
13 |     static constexpr int CHECK_COUNT = 3;
14 |     static constexpr float IGNORE_THRESH = 0.1f;
15 |     static constexpr int KEY_POINTS_NUM = 17;
16 | 
17 |     struct YoloKernel
18 |     {
19 |         int width;
20 |         int height;
21 |         float anchors[CHECK_COUNT * 2];
22 |     };
23 | 
24 |     static constexpr int MAX_OUTPUT_BBOX_COUNT = 1000;
25 |     static constexpr int CLASS_NUM = 1;
26 |     static constexpr int INPUT_H = 960; // the network input height and width must be divisible by 32.
27 |     static constexpr int INPUT_W = 960;
28 | 
29 |     static constexpr int LOCATIONS = 4;
30 |     struct Keypoint {
31 |         float x;
32 |         float y;
33 |         float kpt_conf;
34 |     };
35 | 
36 |     struct alignas(float) Detection {
37 |         //center_x center_y w h
38 |         float bbox[LOCATIONS];
39 |         float conf;  // bbox_conf * cls_conf
40 |         float class_id; //person 0
41 |         // 17 keypoints
42 |         Keypoint kpts[KEY_POINTS_NUM];
43 |     };
44 | 
45 |     static constexpr YoloKernel yolo1 = {
46 |         INPUT_W / 64,
47 |         INPUT_H / 64,
48 |         {436.0f,615.0f, 739.0f,380.0f, 925.0f,792.0f}
49 |     };
50 |     static constexpr YoloKernel yolo2 = {
51 |         INPUT_W / 32,
52 |         INPUT_H / 32,
53 |         {140.0f,301.0f, 303.0f,264.0f, 238.0f,542.0f}
54 |     };
55 |     static constexpr YoloKernel yolo3 = {
56 |         INPUT_W / 16,
57 |         INPUT_H / 16,
58 |         {96.0f,68.0f, 86.0f,152.0f, 180.0f,137.0f}
59 |     };
60 |     static constexpr YoloKernel yolo4 = {
61 |         INPUT_W / 8,
62 |         INPUT_H / 8,
63 |         {19.0f,27.0f, 44.0f,40.0f, 38.0f,94.0f}
64 |     };
65 | 
66 | 
67 | }
68 | 
69 | namespace nvinfer1
70 | {
71 |     class API YoloLayerPlugin : public IPluginV2IOExt
72 |     {
73 |     public:
74 |         YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, const std::vector<Yolo::YoloKernel>& vYoloKernel);
75 |         YoloLayerPlugin(const void* data, size_t length);
76 |         ~YoloLayerPlugin();
77 | 
78 |         int getNbOutputs() const TRT_NOEXCEPT override
79 |         {
80 |             return 1;
81 |         }
82 | 
83 |         Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) TRT_NOEXCEPT override;
84 | 
85 |         int initialize() TRT_NOEXCEPT override;
86 | 
87 |         virtual void terminate() TRT_NOEXCEPT override {};
88 | 
89 |         virtual size_t getWorkspaceSize(int maxBatchSize) const TRT_NOEXCEPT override { return 0; }
90 | 
91 |         virtual int enqueue(int batchSize, const void* const* inputs, void* TRT_CONST_ENQUEUE* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override;
92 | 
93 |         virtual size_t getSerializationSize() const TRT_NOEXCEPT override;
94 | 
95 |         virtual void serialize(void* buffer) const TRT_NOEXCEPT override;
96 | 
97 |         bool supportsFormatCombination(int pos, const PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const TRT_NOEXCEPT override {
98 |             return inOut[pos].format == TensorFormat::kLINEAR && inOut[pos].type == DataType::kFLOAT;
99 |         }
100 | 
101 |         const char* getPluginType() const TRT_NOEXCEPT override;
102 | 
103 |         const char* getPluginVersion() const TRT_NOEXCEPT override;
104 | 
105 |         void destroy() TRT_NOEXCEPT override;
106 | 
107 |         IPluginV2IOExt* clone() const TRT_NOEXCEPT override;
108 | 
109 |         void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override;
110 | 
111 |         const char* getPluginNamespace() const TRT_NOEXCEPT override;
112 | 
113 |         DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override;
114 | 
115 |         bool isOutputBroadcastAcrossBatch(int outputIndex, const bool* inputIsBroadcasted, int nbInputs) const TRT_NOEXCEPT override;
116 | 
117 |         bool canBroadcastInputAcrossBatch(int inputIndex) const TRT_NOEXCEPT override;
118 | 
119 |         void attachToContext(
120 |             cudnnContext* cudnnContext, cublasContext* cublasContext, IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override;
121 | 
122 |         void configurePlugin(const PluginTensorDesc* in, int nbInput, const PluginTensorDesc* out, int nbOutput) TRT_NOEXCEPT override;
123 |         using IPluginV2Ext::configurePlugin;
124 |         void detachFromContext() TRT_NOEXCEPT override;
125 | 
126 |     private:
127 |         void forwardGpu(const float* const* inputs, float *output, cudaStream_t stream, int batchSize = 1);
128 |         int mThreadCount = 256;
129 |         const char* mPluginNamespace;
130 |         int mKernelCount;
131 |         int mClassCount;
132 |         int mYoloV5NetWidth;
133 |         int mYoloV5NetHeight;
134 |         int mMaxOutObject;
135 |         std::vector<Yolo::YoloKernel> mYoloKernel;
136 |         void** mAnchor;
137 |     };
138 | 
139 |     class API YoloPluginCreator : public IPluginCreator
140 |     {
141 |     public:
142 |         YoloPluginCreator();
143 | 
144 |         ~YoloPluginCreator() override = default;
145 | 
146 |         const char* getPluginName() const TRT_NOEXCEPT override;
147 | 
148 |         const char* getPluginVersion() const TRT_NOEXCEPT override;
149 | 
150 |         const PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override;
151 | 
152 |         IPluginV2IOExt* createPlugin(const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT override;
153 | 
154 |         IPluginV2IOExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override;
155 | 
156 |         void setPluginNamespace(const char* libNamespace) TRT_NOEXCEPT override
157 |         {
158 |             mNamespace = libNamespace;
159 |         }
160 | 
161 |         const char* getPluginNamespace() const TRT_NOEXCEPT override
162 |         {
163 |             return mNamespace.c_str();
164 |         }
165 | 
166 |     private:
167 |         std::string mNamespace;
168 |         static PluginFieldCollection mFC;
169 |         static std::vector<PluginField> mPluginAttributes;
170 |     };
171 |     REGISTER_TENSORRT_PLUGIN(YoloPluginCreator);
172 | };
173 | 
174 | #endif // _YOLO_LAYER_H
--------------------------------------------------------------------------------
/yolov7-pose/yolov7_pose.cpp:
--------------------------------------------------------------------------------
1 | #include <fstream>            // NOTE: the original include targets 1-10 were lost; reconstructed from what this file uses
2 | #include <iostream>
3 | #include <chrono>
4 | #include <vector>
5 | #include <map>
6 | #include <algorithm>
7 | #include <cmath>
8 | #include <cstring>
9 | #include <cassert>
10 | #include <opencv2/opencv.hpp>
11 | #include "NvInfer.h"
12 | #include "cuda_runtime_api.h"
13 | #include "logging.h"
14 | #include "utils.h"
15 | #include "yololayer.h"
16 | 
17 | using namespace Yolo;
18 | #define CHECK(status) \
19 | do\
20 | {\
21 |     auto ret = (status);\
22 |     if (ret != 0)\
23 |     {\
24 |         std::cerr << "Cuda failure: " << ret << std::endl;\
25 |         abort();\
26 |     }\
27 | } while (0)
28 | 
29 | #define DEVICE 0  // GPU id
30 | #define NMS_THRESH 0.65
31 | #define CONF_THRESH 0.75
32 | 
33 | using namespace nvinfer1;
34 | std::string categories[] = {"person"};
35 | static const unsigned char pose_kpt_color[17][3] =
36 | {
37 |     {  0, 255,   0},
38 |     {  0, 255,   0},
39 |     {  0, 255,   0},
40 |     {  0, 255,   0},
41 |     {  0, 255,   0},
42 |     {255, 128,   0},
43 |     {255, 128,   0},
44 |     {255, 128,   0},
45 |     {255, 128,   0},
46 |     {255, 128,   0},
47 |     {255, 128,   0},
48 |     { 51, 153, 255},
49 |     { 51, 153, 255},
50 |     { 51, 153, 255},
51 |     { 51, 153, 255},
52 |     { 51, 153, 255},
53 |     { 51, 153, 255}
54 | };
55 | static const unsigned char pose_limb_color[19][3] =
56 | {
57 |     { 51, 153, 255},
58 |     { 51, 153, 255},
59 |     { 51, 153, 255},
60 |     { 51, 153, 255},
61 |     {255,  51, 255},
62 |     {255,  51, 255},
63 |     {255,  51, 255},
64 |     {255, 128,   0},
65 |     {255, 128,   0},
66 |     {255, 128,   0},
67 |     {255, 128,   0},
68 |     {255, 128,   0},
69 |     {  0, 255,   0},
70 |     {  0, 255,   0},
71 |     {  0, 255,   0},
72 |     {  0, 255,   0},
73 |     {  0, 255,   0},
74 |     {  0, 255,   0},
75 |     {  0, 255,   0}
76 | };
77 | static const unsigned char skeleton[19][2] =
78 | {
79 |     {16, 14},
80 |     {14, 12},
81 |     {17, 15},
82 |     {15, 13},
83 |     {12, 13},
84 |     {6, 12},
85 |     {7, 13},
86 |     {6, 7},
87 |     {6, 8},
88 |     {7, 9},
89 |     {8, 10},
90 |     {9, 11},
91 |     {2, 3},
92 |     {1, 2},
93 |     {1, 3},
94 |     {2, 4},
95 |     {3, 5},
96 |     {4, 6},
97 |     {5, 7}
98 | };
99 | // stuff we know about the network and the input/output blobs
100 | const char* INPUT_BLOB_NAME = "images";
101 | const char* OUTPUT_BLOB_NAME = "output0";
102 | static Logger gLogger;
103 | 
104 | 
105 | float* blobFromImage(cv::Mat& img){
106 |     cv::cvtColor(img, img, cv::COLOR_BGR2RGB);
107 | 
108 |     float* blob = new float[img.total()*3];
109 |     int channels = 3;
110 |     int img_h = img.rows;
111 |     int img_w = img.cols;
112 |     for (int c = 0; c < channels; c++)
113 |     {
114 |         for (int h = 0; h < img_h; h++)
115 |         {
116 |             for (int w = 0; w < img_w; w++)
117 |             {
118 |                 blob[c * img_w * img_h + h * img_w + w] =
119 |                     (((float)img.at<cv::Vec3b>(h, w)[c]) / 255.0f);
120 |             }
121 |         }
122 |     }
123 |     return blob;
124 | }
125 | 
126 | cv::Mat static_resize(cv::Mat& img) {
127 |     float r = std::min(INPUT_W / (img.cols*1.0), INPUT_H / (img.rows*1.0));
128 |     int unpad_w = r * img.cols;
129 |     int unpad_h = r * img.rows;
130 |     cv::Mat re(unpad_h, unpad_w, CV_8UC3);
131 |     cv::resize(img, re, re.size());
132 |     cv::Mat out(INPUT_W, INPUT_H, CV_8UC3, cv::Scalar(114, 114, 114));
133 |     re.copyTo(out(cv::Rect(0, 0, re.cols, re.rows)));
134 |     return out;
135 | }
136 | 
137 | 
138 | cv::Rect get_bbox(cv::Mat& img, float bbox[4]) {
139 |     float l, r, t, b;
140 |     float r_w = INPUT_W / (img.cols * 1.0);
141 |     float r_h = INPUT_H / (img.rows * 1.0);
142 |     if (r_h > r_w) {
143 |         l = bbox[0] - bbox[2] / 2.f;
144 |         r = bbox[0] + bbox[2] / 2.f;
145 |         t = bbox[1] - bbox[3] / 2.f;// - (INPUT_H - r_w * img.rows) / 2;
146 |         b = bbox[1] + bbox[3] / 2.f;// - (INPUT_H - r_w * img.rows) / 2;
147 |         l = l / r_w;
148 |         r = r / r_w;
149 |         t = t / r_w;
150 |         b = b / r_w;
151 |     } else {
152 |         l = bbox[0] - bbox[2] / 2.f;// - (INPUT_W - r_h * img.cols) / 2;
153 |         r = bbox[0] + bbox[2] / 2.f;// - (INPUT_W - r_h * img.cols) / 2;
154 |         t = bbox[1] - bbox[3] / 2.f;
155 |         b = bbox[1] + bbox[3] / 2.f;
156 |         l = l / r_h;
157 |         r = r / r_h;
158 |         t = t / r_h;
159 |         b = b / r_h;
160 |     }
161 |     return cv::Rect(round(l), round(t), round(r - l), round(b - t));
162 | }
163 | 
164 | int get_kpts(cv::Mat& img, Keypoint kpts[17]) {
165 |     float r_w = INPUT_W / (img.cols * 1.0);
166 |     float r_h = INPUT_H / (img.rows * 1.0);
167 |     for (int i = 0; i < 17; i++)
168 |     {
169 |         if (r_h > r_w) {
170 |             kpts[i].x /= r_w;
171 |             kpts[i].y /= r_w;
172 |         } else {
173 |             kpts[i].x /= r_h;
174 |             kpts[i].y /= r_h;
175 |         }
176 |     }
177 | 
178 |     return 0;
179 | }
180 | 
181 | float iou(float lbox[4], float rbox[4]) {
182 |     float interBox[] = {
183 |         (std::max)(lbox[0] - lbox[2] / 2.f , rbox[0] - rbox[2] / 2.f), //left
184 |         (std::min)(lbox[0] + lbox[2] / 2.f , rbox[0] + rbox[2] / 2.f), //right
185 |         (std::max)(lbox[1] - lbox[3] / 2.f , rbox[1] - rbox[3] / 2.f), //top
186 |         (std::min)(lbox[1] + lbox[3] / 2.f , rbox[1] + rbox[3] / 2.f), //bottom
187 |     };
188 | 
189 |     if (interBox[2] > interBox[3] || interBox[0] > interBox[1])
190 |         return 0.0f;
191 | 
192 |     float interBoxS = (interBox[1] - interBox[0])*(interBox[3] - interBox[2]);
193 |     return interBoxS / (lbox[2] * lbox[3] + rbox[2] * rbox[3] - interBoxS);
194 | }
195 | 
196 | bool cmp(const Detection& a, const Detection& b) {
197 |     return a.conf > b.conf;
198 | }
199 | 
200 | void nms(std::map<float, std::vector<Detection>>& objects_map, std::vector<Detection>& res, float nms_thresh = 0.5) {
201 |     for (auto it = objects_map.begin(); it != objects_map.end(); it++)
202 |     {
203 |         auto& dets = it->second;
204 |         std::sort(dets.begin(), dets.end(), cmp);
205 |         for (unsigned int det_map_i = 0; det_map_i < dets.size(); ++det_map_i) {
206 |             auto& item = dets[det_map_i];
207 |             res.push_back(item);
208 |             for (unsigned int n = det_map_i + 1; n < dets.size(); ++n) {
209 |                 if (iou(item.bbox, dets[n].bbox) > nms_thresh) {
210 |                     dets.erase(dets.begin() + n);
211 |                     --n;
212 |                 }
213 |             }
214 |         }
215 |     }
216 | }
217 | 
218 | static void postprocess_decode(float* feat_blob, float prob_threshold, std::map<float, std::vector<Detection>>& objects_map)
219 | {
220 |     int det_size = sizeof(Detection) / sizeof(float);
221 |     for (int i = 0; i < feat_blob[0] && i < MAX_OUTPUT_BBOX_COUNT; i++) {
222 |         if (feat_blob[1 + det_size * i + 4] <= prob_threshold) continue;
223 |         Detection det;
224 |         memcpy(&det, &feat_blob[1 + det_size * i], det_size * sizeof(float));
225 |         if (objects_map.count(det.class_id) == 0) objects_map.emplace(det.class_id, std::vector<Detection>());
226 |         objects_map[det.class_id].push_back(det);
227 |     }
228 | }
229 | 
230 | 
231 | 
232 | 
233 | 
234 | void doInference(IExecutionContext& context, float* input, float* output, const int output_size, const int input_shape) {
235 |     const ICudaEngine& engine = context.getEngine();
236 | 
237 |     // Pointers to input and output device buffers to pass to engine.
238 |     // Engine requires exactly IEngine::getNbBindings() number of buffers.
239 |     assert(engine.getNbBindings() == 2);
240 |     void* buffers[2];
241 | 
242 |     // In order to bind the buffers, we need to know the names of the input and output tensors.
243 |     // Note that indices are guaranteed to be less than IEngine::getNbBindings()
244 |     const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME);
245 | 
246 |     assert(engine.getBindingDataType(inputIndex) == nvinfer1::DataType::kFLOAT);
247 |     const int outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME);
248 |     assert(engine.getBindingDataType(outputIndex) == nvinfer1::DataType::kFLOAT);
249 |     // int mBatchSize = engine.getMaxBatchSize();
250 | 
251 |     // Create GPU buffers on device
252 |     CHECK(cudaMalloc(&buffers[inputIndex], input_shape * sizeof(float)));
253 |     CHECK(cudaMalloc(&buffers[outputIndex], output_size * sizeof(float)));
254 | 
255 |     // Create stream
256 |     cudaStream_t stream;
257 |     CHECK(cudaStreamCreate(&stream));
258 | 
259 |     // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host
260 |     CHECK(cudaMemcpyAsync(buffers[inputIndex], input, input_shape * sizeof(float), cudaMemcpyHostToDevice, stream));
261 |     context.enqueueV2(buffers, stream, nullptr);
262 |     CHECK(cudaMemcpyAsync(output, buffers[outputIndex], output_size * sizeof(float), cudaMemcpyDeviceToHost, stream));
263 |     cudaStreamSynchronize(stream);
264 | 
265 |     // Release stream and buffers
266 |     cudaStreamDestroy(stream);
267 |     CHECK(cudaFree(buffers[inputIndex]));
268 |     CHECK(cudaFree(buffers[outputIndex]));
269 | }
270 | 
271 | int main(int argc, char** argv) {
272 |     cudaSetDevice(DEVICE);
273 |     // deserialize a TensorRT engine from a file stream
274 |     char *trtModelStream{nullptr};
275 |     size_t size{0};
276 | 
277 |     if (argc == 4 && std::string(argv[2]) == "-i") {
278 |         const std::string engine_file_path {argv[1]};
279 |         std::ifstream file(engine_file_path, std::ios::binary);
280 |         if (file.good()) {
281 |             file.seekg(0, file.end);
282 |             size = file.tellg();
283 |             file.seekg(0, file.beg);
284 |             trtModelStream = new char[size];
285 |             assert(trtModelStream);
286 |             file.read(trtModelStream, size);
287 |             file.close();
288 |         }
289 |     } else {
290 |         std::cerr << "arguments not right!" << std::endl;
291 |         std::cerr << "./yolov7_pose ../model_trt.engine -i ../*.jpg  // deserialize engine file and run inference" << std::endl;
292 |         return -1;
293 |     }
294 |     const std::string input_image_path {argv[3]};
295 | 
296 |     //std::vector<std::string> file_names;
297 |     //if (read_files_in_dir(argv[2], file_names) < 0) {
298 |     //    std::cout << "read_files_in_dir failed." << std::endl;
299 |     //    return -1;
300 |     //}
301 | 
302 |     IRuntime* runtime = createInferRuntime(gLogger);
303 |     assert(runtime != nullptr);
304 |     ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size);
305 |     assert(engine != nullptr);
306 |     IExecutionContext* context = engine->createExecutionContext();
307 |     assert(context != nullptr);
308 |     delete[] trtModelStream;
309 |     // auto out_dims = engine->getBindingDimensions(1);
310 | 
311 |     int input_size = 1*3*960*960;
312 |     int output_size = 57001 * 1*1;
313 |     static float* prob = new float[output_size];
314 | 
315 |     cv::Mat img = cv::imread(input_image_path);
316 |     // int img_w = img.cols;
317 |     // int img_h = img.rows;
318 |     cv::Mat pr_img = static_resize(img);
319 | 
320 |     float* blob;
321 |     blob = blobFromImage(pr_img);
322 | 
323 |     // run inference
324 |     auto start = std::chrono::system_clock::now();
325 |     doInference(*context, blob, prob, output_size, input_size);
326 |     auto end = std::chrono::system_clock::now();
327 |     std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
328 | 
329 |     // decode
330 |     std::map<float, std::vector<Detection>> objects_map;
331 |     postprocess_decode(prob, CONF_THRESH, objects_map);
332 | 
333 |     // NMS
334 |     std::vector<Detection> objects;
335 |     nms(objects_map, objects, NMS_THRESH);
336 | 
337 |     std::cout << "Number of objects after NMS: " << objects.size() << std::endl;
338 | 
339 |     for (size_t j = 0; j < objects.size(); j++) {
340 |         cv::Rect obj_bbox = get_bbox(img, objects[j].bbox);
341 | 
342 |         std::cout << "x: " << objects[j].bbox[0] << " y:" << objects[j].bbox[1] << " w:" << objects[j].bbox[2] << " h:" << objects[j].bbox[3] << std::endl;
343 |         cv::rectangle(img, obj_bbox, cv::Scalar(0x27, 0xC1, 0x36), 1);
344 |         cv::putText(img, categories[(int)(objects[j].class_id)] + std::to_string(objects[j].conf), cv::Point(obj_bbox.x, obj_bbox.y - 1), cv::FONT_HERSHEY_PLAIN, 1.0, cv::Scalar(0, 0, 255), 1);
345 | 
346 |         get_kpts(img, objects[j].kpts);
347 |         for (size_t i = 0; i < KEY_POINTS_NUM; i++)
348 |         {
349 |             const unsigned char* kpt_color = pose_kpt_color[i];
350 |             cv::circle(img, cv::Point((int)(objects[j].kpts[i].x), (int)(objects[j].kpts[i].y)), 2, cv::Scalar(kpt_color[0], kpt_color[1], kpt_color[2]), -1);
351 |         }
352 |         for (size_t i = 0; i < 19; i++)
353 |         {
354 |             int x1 = (int)(objects[j].kpts[skeleton[i][0]-1].x);
355 |             int y1 = (int)(objects[j].kpts[skeleton[i][0]-1].y);
356 |             int x2 = (int)(objects[j].kpts[skeleton[i][1]-1].x);
357 |             int y2 = (int)(objects[j].kpts[skeleton[i][1]-1].y);
358 | 
359 |             const unsigned char* limb_color = pose_limb_color[i];
360 |             cv::line(img, cv::Point(x1,y1), cv::Point(x2,y2), cv::Scalar(limb_color[0], limb_color[1], limb_color[2]), 1);
361 |         }
362 | 
363 |     }
364 |     cv::imwrite("../_result.jpg", img);
365 | 
366 | 
367 |     // delete the input blob allocated with new[]
368 |     delete[] blob;
369 |     // destroy the TensorRT objects (they are not page-locked host memory, so cudaFreeHost does not apply)
370 |     context->destroy();
371 |     engine->destroy();
372 |     runtime->destroy();
373 | 
374 |     return 0;
375 | }
376 | 
--------------------------------------------------------------------------------