├── README.md ├── common.hpp ├── gpu_func.cu ├── gpu_func.cuh ├── main.cpp ├── mat_transform.hpp ├── trt_bisenet.cpp ├── trt_bisenet.h ├── trt_onnx_base.cpp └── trt_onnx_base.h /README.md: -------------------------------------------------------------------------------- 1 | # TensorRT c++ BiSeNetV1/2 2 | 3 | 使用TensorRT c++实现 [BiSeNetV1](https://arxiv.org/abs/1808.00897) 和 [BiSeNetV2](https://arxiv.org/abs/2004.02147) 部署,640x640输入,帧率可以达到110帧左右 4 | 5 | ## 环境依赖 6 | 7 | 1. OpenCV 3.1 8 | 2. TensorRT 7.2 9 | 3. CUDA 10.2 10 | 11 | ## 代码介绍 12 | 13 | 1. trt_bisenet.h与trt_bisenet.cpp为主要实现代码,主要包括onnx->tensorrt生成.engine模型、预处理、前向传播、后处理等步骤 14 | 2. mat_transform.hpp中为图像预处理相关算法。预处理算法根据项目的不同会稍有区别,这里是我的预处理。 15 | 3. gpu_func.cuh与gpu_func.cu为预处理与后处理的gpu代码。我实现了cpu版与gpu版的预处理和后处理,这里是gpu版的实现 16 | 17 | ## 主要接口 18 | 1. PreProcessCpu : 数据预处理cpu代码,在cpu上预处理数据 19 | 2. ProProcessGPU : 数据预处理cuda代码,在gpu设备上预处理数据 20 | 3. PostProcessCpu : 算法后处理cpu代码 21 | 4. PostProcessGpu : 算法后处理cuda代码 22 | 5. Extract : 算法执行接口,对外接口 23 | 24 | ## 使用方法 25 | 26 | 具体使用方法可以参照main.cpp中的main()函数。如果初次调用,需要指定onnx模型的路径、生成的trt模型的保存路径以及保存的模型名。 27 | 初次调用之后会生成.engine的trt模型,并保存到指定位置,之后再调用,则直接加载.engine模型。 28 | ```cpp 29 | // onnx->tensorrt所需要的主要参数 30 | OnnxInitParam params; 31 | // onnx模型路径 32 | params.onnx_model_path = "./BiSeNet/checkpoints/onnx/bisenet.onnx"; 33 | // 保存生成.engine的模型路径 34 | params.rt_stream_path = "./"; 35 | params.rt_model_name = "bisenet.engine"; 36 | // 是否使用半精度,如果为false,则使用fp32精度 37 | params.use_fp16 = true; 38 | // 使用的显卡设备 39 | params.gpu_id = 0; 40 | // 模型分割的类别数 41 | params.num_classes = 4; 42 | // 设置最大网络输入大小,用于分配内存(显存), 这里需要根据自己的需求设置 43 | params.max_shape = Shape(1, 3, 640, 640); 44 | 45 | // 实例化BiSeNet类,其中会进行模型转化和一些初始化的操作 46 | BiSeNet model(params); 47 | 48 | // 模型前向推理,得到分割的输出,输出为uint8单通道Mat型图像数据,像素值从0~num_classes-1,代表像素的类别 49 | cv::Mat res = model.Extract(img); 50 | ``` -------------------------------------------------------------------------------- /common.hpp:
--------------------------------------------------------------------------------
#ifndef COMMON_H_
#define COMMON_H_

#include <iostream>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>

using namespace std;

// 4-D tensor shape stored as (num, channels, height, width).
class Shape
{
public:
    Shape() : num_(0), channels_(0), height_(0), width_(0) {}
    Shape(int num, int channels, int height, int width) :
        num_(num), channels_(channels), height_(height), width_(width) {}
    ~Shape() {}

public:
    // Overwrite all four dimensions at once.
    void Reshape(int num, int channels, int height, int width)
    {
        num_ = num;
        channels_ = channels;
        height_ = height;
        width_ = width;
    }

public:
    inline int num() const { return num_; }
    inline int channels() const { return channels_; }
    inline int height() const { return height_; }
    inline int width() const { return width_; }
    // Total number of elements: num * channels * height * width.
    inline int count() const { return num_ * channels_ * height_ * width_; }

private:
    int num_;
    int channels_;
    int height_;
    int width_;
};

// Views a planar float buffer as one single-channel CV_32FC1 Mat per channel.
// The Mats do not own their data: Mat i aliases the i-th height*width slice of
// h_src, so writing into them writes into h_src. h_src must hold at least
// channels * height * width floats and must outlive the returned Mats.
class Tensor2VecMat
{
public:
    Tensor2VecMat() {}
    vector<cv::Mat> operator()(float* h_src, const Shape& input_shape)
    {
        const int channels = input_shape.channels();
        const int height = input_shape.height();
        const int width = input_shape.width();

        vector<cv::Mat> input_channels;
        if (channels > 0)
            input_channels.reserve(channels);
        for (int i = 0; i < channels; i++)
        {
            input_channels.push_back(cv::Mat(height, width, CV_32FC1, h_src));
            h_src += height * width;
        }
        // Plain return: std::move here would only block copy elision.
        return input_channels;
    }
};

#endif
--------------------------------------------------------------------------------
/gpu_func.cu:
--------------------------------------------------------------------------------
#include "gpu_func.cuh"
#include <cfloat>
#include <cstdio>

// Per-channel normalisation constants, uploaded by biliresize_normalize().
// Only 3 entries exist, so at most 3 channels are supported.
__constant__ float const_mean[3];
__constant__ float const_std[3];

namespace {

// Edge length of the square thread blocks used by every 2-D launch below.
const unsigned int kBlockDim = 16;

// Print a failed CUDA call to stderr; returns true on success.
// Errors are reported, not fatal, matching the original best-effort style.
bool check_cuda(cudaError_t err, const char* what)
{
    if (err != cudaSuccess)
    {
        fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(err));
        return false;
    }
    return true;
}

// Ceil-div grid so every pixel of a width x height image gets a thread, even
// when the size is not a multiple of the block size.
dim3 grid_for(int width, int height, const dim3& block)
{
    return dim3((unsigned(width) + block.x - 1) / block.x,
                (unsigned(height) + block.y - 1) / block.y);
}

}  // namespace

// In-place softmax over `channels` contiguous floats. The maximum is
// subtracted before exponentiating so large logits cannot overflow to inf.
__device__ void softmax(float* src, int channels)
{
    if (channels <= 0)
        return;

    float max_val = src[0];
    for (int i = 1; i < channels; i++)
        max_val = fmaxf(max_val, src[i]);

    float total = 0.0f;
    for (int i = 0; i < channels; i++)
    {
        src[i] = expf(src[i] - max_val);
        total += src[i];
    }
    for (int i = 0; i < channels; i++)
        src[i] = src[i] / total;
}

// Per-pixel softmax across `channels` planes of height*width floats, in place.
__global__ void kernel_segmentation_get_logits(float* ptr, int channels, int height, int width)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    if (x >= width || y >= height)
        return;  // the ceil-div grid overshoots the image

    const int plane = height * width;
    float* px = ptr + y * width + x;  // channel c of this pixel is px[c * plane]

    float max_val = -FLT_MAX;
    for (int c = 0; c < channels; c++)
        max_val = fmaxf(max_val, px[c * plane]);

    float total = 0.0f;
    for (int c = 0; c < channels; c++)
    {
        const float e = expf(px[c * plane] - max_val);
        px[c * plane] = e;
        total += e;
    }
    for (int c = 0; c < channels; c++)
        px[c * plane] /= total;
}

// Per-pixel argmax across `channels` planes; writes the winning channel index
// to dev_dst[y * width + x]. Ties keep the lowest index.
__global__ void kernel_segmentation_get_cls(float* ptr, int channels,
                                            int height, int width, unsigned char* dev_dst)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    if (x >= width || y >= height)
        return;

    const int plane = height * width;
    const float* px = ptr + y * width + x;

    int idx = 0;
    // FLT_MIN is the smallest *positive* float; logits may be negative, so
    // start from the most negative finite value instead.
    float max_val = -FLT_MAX;
    for (int c = 0; c < channels; c++)
    {
        const float v = px[c * plane];
        if (v > max_val)
        {
            max_val = v;
            idx = c;
        }
    }
    dev_dst[y * width + x] = (unsigned char)idx;
}

void segmentation(float* src_ptr, int channels, int height, int width, float* cpu_dst)
{
    if (channels <= 0 || height <= 0 || width <= 0)
        return;

    const dim3 block(kBlockDim, kBlockDim);
    const dim3 grid = grid_for(width, height, block);
    kernel_segmentation_get_logits<<<grid, block>>>(src_ptr, channels, height, width);
    if (!check_cuda(cudaGetLastError(), "kernel_segmentation_get_logits"))
        return;

    const size_t bytes = size_t(channels) * size_t(height) * size_t(width) * sizeof(float);
    check_cuda(cudaMemcpy(cpu_dst, src_ptr, bytes, cudaMemcpyDeviceToHost),
               "cudaMemcpy(segmentation probabilities)");
}

void segmentation(float* src_ptr, int channels, int height, int width, unsigned char* cpu_dst)
{
    if (height <= 0 || width <= 0)
        return;

    const size_t bytes = size_t(height) * size_t(width) * sizeof(unsigned char);
    unsigned char* dev_dst = nullptr;
    if (!check_cuda(cudaMalloc((void**)&dev_dst, bytes), "cudaMalloc(dev_dst)"))
        return;

    const dim3 block(kBlockDim, kBlockDim);
    const dim3 grid = grid_for(width, height, block);
    kernel_segmentation_get_cls<<<grid, block>>>(src_ptr, channels, height, width, dev_dst);
    if (check_cuda(cudaGetLastError(), "kernel_segmentation_get_cls"))
        check_cuda(cudaMemcpy(cpu_dst, dev_dst, bytes, cudaMemcpyDeviceToHost),
                   "cudaMemcpy(segmentation classes)");

    // The scratch buffer used to leak on every call.
    cudaFree(dev_dst);
}

__global__ void kernel_bilinear_resize(float* dst_ptr, int channels, int src_h, int src_w,
                                       int dst_h, int dst_w, int pad_h, int pad_w, float r, uchar* src_ptr)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    if (x >= dst_w || y >= dst_h)
        return;

    const int plane = dst_h * dst_w;
    const int offset = y * dst_w + x;

    // Bottom/right padding. The value is in raw pixel units because
    // kernel_normalize divides everything by 255 afterwards (the old 114/255
    // was scaled twice).
    if (x >= dst_w - pad_w || y >= dst_h - pad_h)
    {
        for (int c = 0; c < channels; c++)
            dst_ptr[offset + (channels - c - 1) * plane] = 114.0f;
        return;
    }

    // Half-pixel-centre mapping back into the source, clamped at the top/left
    // edge so the weights never extrapolate.
    const float src_x = fmaxf((x + 0.5f) / r - 0.5f, 0.0f);
    const float src_y = fmaxf((y + 0.5f) / r - 0.5f, 0.0f);

    const int src_x_0 = min(int(src_x), src_w - 1);
    const int src_y_0 = min(int(src_y), src_h - 1);
    const int src_x_1 = min(src_x_0 + 1, src_w - 1);
    const int src_y_1 = min(src_y_0 + 1, src_h - 1);

    // Fractional weights stay correct when the neighbour is clamped onto the
    // same pixel. The old (x1 - x) / (x - x0) weights summed to 0 there.
    const float fx = src_x - src_x_0;
    const float fy = src_y - src_y_0;

    for (int c = 0; c < channels; c++)
    {
        const float v00 = src_ptr[(src_y_0 * src_w + src_x_0) * channels + c];
        const float v01 = src_ptr[(src_y_0 * src_w + src_x_1) * channels + c];
        const float v10 = src_ptr[(src_y_1 * src_w + src_x_0) * channels + c];
        const float v11 = src_ptr[(src_y_1 * src_w + src_x_1) * channels + c];

        const float top = v00 + fx * (v01 - v00);
        const float bottom = v10 + fx * (v11 - v10);

        // Channel order is reversed (e.g. interleaved BGR -> planar RGB).
        dst_ptr[offset + (channels - c - 1) * plane] = top + fy * (bottom - top);
    }
}

__global__ void kernel_normalize(float* dst_ptr, int channels, int h, int w)
{
    const int x = blockIdx.x * blockDim.x + threadIdx.x;
    const int y = blockIdx.y * blockDim.y + threadIdx.y;
    if (x >= w || y >= h)
        return;

    const int plane = h * w;
    const int offset = y * w + x;
    for (int c = 0; c < channels; c++)
    {
        const float v = dst_ptr[c * plane + offset] / 255.0f;
        dst_ptr[c * plane + offset] = (v - const_mean[c]) / const_std[c];
    }
}

void biliresize_normalize(float* dst_ptr, int channels, int src_h, int src_w,
                          int dst_h, int dst_w, int pad_h, int pad_w, float r,
                          uchar* src_ptr, float* mean, float* std)
{
    // const_mean / const_std hold only 3 entries.
    if (channels <= 0 || channels > 3)
    {
        fprintf(stderr, "biliresize_normalize: unsupported channel count %d\n", channels);
        return;
    }
    if (src_h <= 0 || src_w <= 0 || dst_h <= 0 || dst_w <= 0 || !(r > 0.0f))
        return;

    const size_t src_bytes = size_t(src_h) * size_t(src_w) * size_t(channels) * sizeof(uchar);
    uchar* src_dev = nullptr;
    if (!check_cuda(cudaMalloc((void**)&src_dev, src_bytes), "cudaMalloc(src_dev)"))
        return;

    const bool ok =
        check_cuda(cudaMemcpy(src_dev, src_ptr, src_bytes, cudaMemcpyHostToDevice), "cudaMemcpy(src_dev)") &&
        check_cuda(cudaMemcpyToSymbol(const_mean, mean, channels * sizeof(float)), "cudaMemcpyToSymbol(const_mean)") &&
        check_cuda(cudaMemcpyToSymbol(const_std, std, channels * sizeof(float)), "cudaMemcpyToSymbol(const_std)");

    if (ok)
    {
        const dim3 block(kBlockDim, kBlockDim);
        const dim3 grid = grid_for(dst_w, dst_h, block);

        kernel_bilinear_resize<<<grid, block>>>(dst_ptr, channels, src_h, src_w,
                                                dst_h, dst_w, pad_h, pad_w, r, src_dev);
        if (check_cuda(cudaGetLastError(), "kernel_bilinear_resize"))
        {
            kernel_normalize<<<grid, block>>>(dst_ptr, channels, dst_h, dst_w);
            check_cuda(cudaGetLastError(), "kernel_normalize");
        }
    }

    // As before, relies on cudaFree's implicit synchronisation so the kernels
    // reading src_dev have finished.
    cudaFree(src_dev);
}
--------------------------------------------------------------------------------
/gpu_func.cuh:
--------------------------------------------------------------------------------
#ifndef GPU_FUNC_H_
#define GPU_FUNC_H_

#include <cuda_runtime.h>
#include <iostream>
#include <vector>

using namespace std;

typedef unsigned char uchar;

// In-place softmax over `channels` contiguous floats (max-subtracted).
__device__ void softmax(float *src, int channels);

/* Segmentation kernel: per-pixel softmax over the class axis, in place.
   src_ptr : device buffer with the model output, laid out as `channels`
             consecutive height*width planes
   channels: number of output channels (classes)
   height  : plane height
   width   : plane width
   Expects a 2-D launch covering width x height; surplus threads return. */
__global__ void kernel_segmentation_get_logits(float* src_ptr, int channels, int height, int width);

/* Segmentation kernel: per-pixel argmax over the class axis.
   src_ptr : device buffer with the model output (same layout as above)
   channels: number of output channels (classes)
   height  : plane height
   width   : plane width
   dev_dst : device buffer receiving height*width class indices */
__global__ void kernel_segmentation_get_cls(float* src_ptr, int channels, int height, int width, unsigned char *dev_dst);

/* Bilinear resize of an interleaved uchar image (src_h x src_w x channels,
   dense rows) into planar floats (channels planes of dst_h x dst_w), with the
   channel order reversed. The scaled image fills the top-left
   (dst_h - pad_h) x (dst_w - pad_w) region; the bottom/right padding is set
   to 114 in raw pixel units. r is the dst/src scale factor. */
__global__ void kernel_bilinear_resize(float* dst_ptr, int channels, int src_h, int src_w, int dst_h, int dst_w,
    int pad_h, int pad_w, float r, unsigned char *src_ptr);

/* In-place normalisation of planar dst_ptr:
   v = (v / 255 - mean[c]) / std[c], with mean/std taken from constant memory
   (at most 3 channels). */
__global__ void kernel_normalize(float* dst_ptr, int channels, int src_h, int src_w);

/* Host wrapper for kernel_segmentation_get_logits: softmaxes src_ptr (device)
   in place, then copies channels*height*width probabilities to cpu_dst (host). */
void segmentation(float* src_ptr, int channels, int height, int width, float* cpu_dst);

/* Host wrapper for kernel_segmentation_get_cls: copies height*width class
   indices to cpu_dst (host). */
void segmentation(float* src_ptr, int channels, int height, int width, unsigned char *cpu_dst);

/* Public entry point for resize + normalisation: uploads the host image
   src_ptr, writes the resized, normalised planar tensor to dst_ptr (device).
   mean/std are host arrays of `channels` floats (channels <= 3). */
void biliresize_normalize(float* dst_ptr, int channels, int src_h, int src_w, int dst_h, int dst_w,
    int pad_h, int pad_w, float r, unsigned char *src_ptr, float *mean, float* std);

#endif
-------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include "trt_bisenet.h" 2 | 3 | int main(int argc, char** argv) 4 | { 5 | OnnxInitParam params; 6 | params.onnx_model_path = "E:/BaiduNetdiskDownload/BiSeNetv3/checkpoints/onnx/bisenetv3.onnx"; 7 | params.use_fp16 = true; 8 | params.gpu_id = 0; 9 | params.num_classes = 4; 10 | params.max_shape = Shape(1, 3, 640, 640); // 设置最大网络输入大小,用于分配内存(显存),根据自己项目需要设置 11 | 12 | BiSeNet model(params); 13 | 14 | cv::Mat img = cv::imread("E:/BaiduNetdiskDownload/BiSeNetv3/datas/tupian.jpg"); 15 | 16 | cv::Mat res = model.Extract(img); 17 | cv::imshow("res", res); 18 | 19 | res = model.Extract(img); 20 | cv::imshow("res1", res); 21 | cv::waitKey(); 22 | } -------------------------------------------------------------------------------- /mat_transform.hpp: -------------------------------------------------------------------------------- 1 | #ifndef MAT_TRANSFORM_H_ 2 | #define MAT_TRANSFORM_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | using namespace std; 10 | 11 | class ComposeMatLambda 12 | { 13 | public: 14 | using FuncionType = std::function; 15 | 16 | ComposeMatLambda() = default; 17 | ComposeMatLambda(const vector& lambda) :lambda_(lambda) 18 | { 19 | ; 20 | } 21 | cv::Mat operator()(cv::Mat& img) 22 | { 23 | for (auto func : lambda_) 24 | img = func(img); 25 | return img; 26 | } 27 | private: 28 | vector lambda_; 29 | }; 30 | 31 | class MatDivConstant 32 | { 33 | public: 34 | MatDivConstant(int constant) :constant_(constant) {} 35 | cv::Mat operator()(const cv::Mat& img) 36 | { 37 | cv::Mat tmp; 38 | img.convertTo(tmp, CV_32FC3, 1, 0); 39 | tmp = tmp / constant_; 40 | return move(tmp); 41 | } 42 | 43 | private: 44 | float constant_; 45 | }; 46 | 47 | class MatNormalize 48 | { 49 | public: 50 | MatNormalize(vector& mean, vector std) : mean_(mean), std_(std) {} 51 | cv::Mat 
operator()(const cv::Mat& img) 52 | { 53 | cv::Mat img_float; 54 | if (img.type() == CV_32FC3) 55 | img_float = img; 56 | else if (img_float.type() == CV_8UC3) 57 | img.convertTo(img_float, CV_32FC3); 58 | else 59 | { 60 | assert(0), "img type is error"; 61 | } 62 | 63 | int width = img_float.cols; 64 | int height = img_float.rows; 65 | 66 | cv::Mat mean = cv::Mat(cv::Size(width, height), 67 | CV_32FC3, cv::Scalar(mean_[0], mean_[1], mean_[2])); 68 | cv::Mat std = cv::Mat(cv::Size(width, height), 69 | CV_32FC3, cv::Scalar(std_[0], std_[1], std_[2])); 70 | 71 | cv::Mat sample_sub; 72 | cv::subtract(img_float, mean, sample_sub); 73 | cv::Mat sample_normalized = sample_sub / std; 74 | return move(sample_normalized); 75 | } 76 | private: 77 | vector mean_; 78 | vector std_; 79 | }; 80 | 81 | class LetterResize 82 | { 83 | public: 84 | LetterResize(cv::Size& new_shape = cv::Size(640, 640), 85 | cv::Scalar& color = cv::Scalar(114, 114, 114), 86 | int stride = 32) :new_shape_(new_shape), color_(color), stride_(stride) {} 87 | 88 | cv::Mat operator()(const cv::Mat& img) 89 | { 90 | int img_h = img.rows; 91 | int img_w = img.cols; 92 | 93 | int shape_h = new_shape_.height; 94 | int shape_w = new_shape_.width; 95 | 96 | double r = std::min(double(shape_h) / double(img_h), double(shape_w) / double(img_w)); 97 | cout << "r: " << r << endl; 98 | 99 | 100 | cv::Size new_unpad = cv::Size(int(round(r * img_w)), int(round(r * img_h))); 101 | 102 | int dw = new_shape_.width - new_unpad.width; 103 | int dh = new_shape_.height - new_unpad.height; 104 | dw = dw % stride_; 105 | dh = dh % stride_; 106 | 107 | // dw /= 2; 108 | // dh /= 2; 109 | 110 | cv::Mat resize_mat; 111 | if (img.rows != new_unpad.height || img.cols != new_unpad.width) 112 | cv::resize(img, resize_mat, new_unpad, 0, 0, cv::INTER_LINEAR); 113 | int top = 0; // int(round(dh - 0.1)); 114 | int bottom = int(dh); // int(round(dh + 0.1)); 115 | int left = 0; // int(round(dw - 0.1)); 116 | int right = int(dw); // 
int(round(dw + 0.1)); 117 | cv::Mat pad_mat; 118 | cv::copyMakeBorder(resize_mat, pad_mat, top, bottom, left, right, cv::BORDER_CONSTANT, color_); 119 | 120 | return std::move(pad_mat); 121 | } 122 | 123 | private: 124 | cv::Size new_shape_; 125 | cv::Scalar color_; 126 | int stride_; 127 | }; 128 | 129 | #endif // ! MAT_TRANSFORM_H_ 130 | -------------------------------------------------------------------------------- /trt_bisenet.cpp: -------------------------------------------------------------------------------- 1 | #include "trt_bisenet.h" 2 | #include "mat_transform.hpp" 3 | #include "gpu_func.cuh" 4 | 5 | BiSeNet::BiSeNet(const OnnxInitParam& params) : TRTOnnxBase(params) 6 | { 7 | } 8 | 9 | cv::Mat BiSeNet::Extract(const cv::Mat& img) 10 | { 11 | if (img.empty()) 12 | return img; 13 | 14 | std::lock_guard lock(mtx_); 15 | /*PreProcessCpu(img);*/ 16 | ProProcessGPU(img); 17 | Forward(); 18 | 19 | /*cv::Mat res = PostProcessCpu();*/ 20 | cv::Mat res = PostProcessGpu(); 21 | return std::move(res); 22 | } 23 | 24 | BiSeNet::~BiSeNet() 25 | { 26 | } 27 | 28 | void BiSeNet::PreProcessCpu(const cv::Mat& img) 29 | { 30 | cv::Mat img_tmp = img; 31 | 32 | ComposeMatLambda compose({ 33 | LetterResize(cv::Size(crop_size_, crop_size_), cv::Scalar(114, 114, 114), 32), 34 | MatDivConstant(255), 35 | MatNormalize(mean_, std_), 36 | }); 37 | 38 | cv::Mat sample_float = compose(img_tmp); 39 | input_shape_.Reshape(1, sample_float.channels(), sample_float.rows, sample_float.cols); 40 | output_shape_.Reshape(1, _params.num_classes, sample_float.rows, sample_float.cols); 41 | 42 | Tensor2VecMat tensor_2_mat; 43 | std::vector channels = tensor_2_mat(h_input_tensor_, input_shape_); 44 | cv::split(sample_float, channels); 45 | 46 | cudaMemcpy(d_input_tensor_, h_input_tensor_, input_shape_.count() * sizeof(float), 47 | cudaMemcpyHostToDevice); 48 | } 49 | 50 | void BiSeNet::ProProcessGPU(const cv::Mat& img) 51 | { 52 | int src_h = img.rows; 53 | int src_w = img.cols; 54 | int 
channels = img.channels(); 55 | 56 | float r = MIN(float(crop_size_) / src_h, float(crop_size_) / src_w); 57 | 58 | int dst_h = int(r * src_h); 59 | int dst_w = int(r * src_w); 60 | 61 | int pad_h = (crop_size_ - dst_h) % stride_; 62 | int pad_w = (crop_size_ - dst_w) % stride_; 63 | 64 | dst_h += pad_h; 65 | dst_w += pad_w; 66 | 67 | input_shape_.Reshape(1, channels, dst_h, dst_w); 68 | output_shape_.Reshape(1, _params.num_classes, dst_h, dst_w); 69 | 70 | biliresize_normalize(d_input_tensor_, channels, src_h, src_w, dst_h, dst_w, 71 | pad_h, pad_w, r, img.data, mean_.data(), std_.data()); 72 | } 73 | 74 | cv::Mat BiSeNet::PostProcessCpu() 75 | { 76 | int num = output_shape_.num(); 77 | int channels = output_shape_.channels(); 78 | int height = output_shape_.height(); 79 | int width = output_shape_.width(); 80 | int count = output_shape_.count(); 81 | 82 | cudaMemcpy(h_output_tensor_, d_output_tensor_, 83 | count * sizeof(float), cudaMemcpyDeviceToHost); 84 | 85 | cv::Mat res = cv::Mat::zeros(height, width, CV_8UC1); 86 | for (int row = 0; row < height; row++) 87 | { 88 | for (int col = 0; col < width; col++) 89 | { 90 | vector vec; 91 | for (int c = 0; c < channels; c++) 92 | { 93 | int index = row * width + col + c * height * width; 94 | float val = h_output_tensor_[index]; 95 | vec.push_back(val); 96 | } 97 | 98 | int idx = findMaxIdx(vec); 99 | if (idx == -1) 100 | continue; 101 | res.at(row, col) = uchar(idx); 102 | } 103 | } 104 | 105 | return std::move(res); 106 | } 107 | 108 | cv::Mat BiSeNet::PostProcessGpu() 109 | { 110 | int num = output_shape_.num(); 111 | int channels = output_shape_.channels(); 112 | int height = output_shape_.height(); 113 | int width = output_shape_.width(); 114 | 115 | unsigned char* cpu_dst; 116 | cudaHostAlloc((void**)&cpu_dst, height * width * sizeof(float), cudaHostAllocDefault); 117 | //==> segmentation(output_tensor_, channels, height, width, cpu_dst); 118 | segmentation(d_output_tensor_, channels, height, width, cpu_dst); 
119 | 120 | cv::Mat res = cv::Mat(height, width, CV_8UC1, cpu_dst); 121 | 122 | cudaFreeHost(cpu_dst); 123 | return std::move(res); 124 | } 125 | 126 | void BiSeNet::softmax(vector& vec) 127 | { 128 | float tol = 0.0; 129 | for (int i = 0; i < vec.size(); i++) 130 | { 131 | vec[i] = exp(vec[i]); 132 | tol += vec[i]; 133 | } 134 | 135 | for (int i = 0; i < vec.size(); i++) 136 | vec[i] = vec[i] / tol; 137 | } 138 | 139 | int BiSeNet::findMaxIdx(const vector& vec) 140 | { 141 | if (vec.empty()) 142 | return -1; 143 | auto pos = max_element(vec.begin(), vec.end()); 144 | return std::distance(vec.begin(), pos); 145 | } -------------------------------------------------------------------------------- /trt_bisenet.h: -------------------------------------------------------------------------------- 1 | #ifndef TRT_BISENET_H_ 2 | #define TRT_BISENET_H_ 3 | 4 | #include "trt_onnx_base.h" 5 | #include 6 | 7 | #define MIN(x, y) (x) < (y) ? (x) : (y) 8 | 9 | using namespace std; 10 | 11 | class BiSeNet : public TRTOnnxBase 12 | { 13 | public: 14 | BiSeNet() = delete; 15 | BiSeNet(const OnnxInitParam& params); 16 | 17 | virtual ~BiSeNet(); 18 | 19 | cv::Mat Extract(const cv::Mat& img); 20 | 21 | private: 22 | // cpu预处理 23 | void PreProcessCpu(const cv::Mat& img); 24 | // gpu预处理 25 | void ProProcessGPU(const cv::Mat& img); 26 | 27 | // cpu后处理 28 | cv::Mat PostProcessCpu(); 29 | // gpu后处理 30 | cv::Mat PostProcessGpu(); 31 | 32 | // softmax函数 33 | static void softmax(vector& vec); 34 | static int findMaxIdx(const vector& vec); 35 | 36 | private: 37 | int crop_size_ = 640; 38 | int stride_ = 32; 39 | 40 | std::vector mean_{ 0.485, 0.456, 0.406 }; 41 | std::vector std_{ 0.229, 0.224, 0.225 }; 42 | 43 | std::mutex mtx_; 44 | }; 45 | 46 | #endif -------------------------------------------------------------------------------- /trt_onnx_base.cpp: -------------------------------------------------------------------------------- 1 | #include "trt_onnx_base.h" 2 | 3 | 
TRTOnnxBase::TRTOnnxBase(const OnnxInitParam& params) : _params(params) 4 | { 5 | cudaSetDevice(params.gpu_id); 6 | 7 | cudaStreamCreate(&stream_); 8 | 9 | Initial(); 10 | } 11 | 12 | void TRTOnnxBase::Initial() 13 | { 14 | if (CheckFileExist(_params.rt_stream_path + _params.rt_model_name)) 15 | { 16 | std::cout << "read rt model..." << std::endl; 17 | LoadGieStreamBuildContext(_params.rt_stream_path + _params.rt_model_name); 18 | } 19 | else 20 | LoadOnnxModel(); 21 | } 22 | 23 | void TRTOnnxBase::LoadOnnxModel() 24 | { 25 | if (!CheckFileExist(_params.onnx_model_path)) 26 | { 27 | cout << "onnx_model_path: " << _params.onnx_model_path << endl; 28 | std::cerr << "onnx file is not found " << _params.onnx_model_path << std::endl; 29 | exit(0); 30 | } 31 | 32 | nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger); 33 | assert(builder != nullptr); 34 | 35 | const auto explicitBatch = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); 36 | nvinfer1::INetworkDefinition* network = builder->createNetworkV2(explicitBatch); 37 | 38 | nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, logger); 39 | assert(parser->parseFromFile(_params.onnx_model_path.c_str(), 2)); 40 | 41 | nvinfer1::IBuilderConfig* build_config = builder->createBuilderConfig(); 42 | nvinfer1::IOptimizationProfile* profile = builder->createOptimizationProfile(); 43 | nvinfer1::ITensor* input = network->getInput(0); 44 | nvinfer1::Dims input_dims = input->getDimensions(); 45 | std::cout << "batch_size: " << input_dims.d[0] 46 | << " channels: " << input_dims.d[1] 47 | << " height: " << input_dims.d[2] 48 | << " width: " << input_dims.d[3] << std::endl; 49 | 50 | { 51 | profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMIN, nvinfer1::Dims4{ 1, input_dims.d[1], 1, 1 }); 52 | profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kOPT, nvinfer1::Dims4{ 1, input_dims.d[1], 640, 480 }); // 640 53 | 
profile->setDimensions(input->getName(), nvinfer1::OptProfileSelector::kMAX, nvinfer1::Dims4{ 1, input_dims.d[1], 640, 640 }); 54 | build_config->addOptimizationProfile(profile); 55 | } 56 | 57 | build_config->setMaxWorkspaceSize(1 << 30); 58 | if (_params.use_fp16) 59 | { 60 | if (builder->platformHasFastFp16()) 61 | { 62 | builder->setHalf2Mode(true); 63 | std::cout << "useFP16 : " << true << std::endl; 64 | } 65 | } 66 | else 67 | std::cout << "Using GPU FP32 !" << std::endl; 68 | 69 | nvinfer1::ICudaEngine* engine = builder->buildEngineWithConfig(*network, *build_config); 70 | assert(engine != nullptr); 71 | 72 | nvinfer1::IHostMemory* gie_model_stream = engine->serialize(); 73 | SaveRTModel(gie_model_stream, _params.rt_stream_path + _params.rt_model_name); 74 | 75 | deserializeCudaEngine(gie_model_stream->data(), gie_model_stream->size()); 76 | 77 | builder->destroy(); 78 | network->destroy(); 79 | parser->destroy(); 80 | build_config->destroy(); 81 | engine->destroy(); 82 | } 83 | 84 | void TRTOnnxBase::LoadGieStreamBuildContext(const std::string& gie_file) 85 | { 86 | std::ifstream fgie(gie_file, std::ios_base::in | std::ios_base::binary); 87 | if (!fgie) 88 | { 89 | std::cerr << "Can't read rt model from " << gie_file << std::endl; 90 | return; 91 | } 92 | 93 | std::stringstream buffer; 94 | buffer << fgie.rdbuf(); 95 | 96 | std::string stream_model(buffer.str()); 97 | 98 | deserializeCudaEngine(stream_model.data(), stream_model.size()); 99 | } 100 | 101 | void TRTOnnxBase::deserializeCudaEngine(const void* blob_data, std::size_t size) 102 | { 103 | _runtime = nvinfer1::createInferRuntime(logger); 104 | assert(_runtime != nullptr); 105 | _engine = _runtime->deserializeCudaEngine(blob_data, size, nullptr); 106 | assert(_engine != nullptr); 107 | 108 | _context = _engine->createExecutionContext(); 109 | assert(_context != nullptr); 110 | 111 | mallocInputOutput(); 112 | } 113 | 114 | void TRTOnnxBase::mallocInputOutput() 115 | { 116 | int in_counts = 
_params.max_shape.count(); 117 | cudaHostAlloc((void**)&h_input_tensor_, in_counts * sizeof(float), cudaHostAllocDefault); 118 | cudaMalloc((void**)&d_input_tensor_, in_counts * sizeof(float)); 119 | 120 | int out_counts = _params.max_shape.num() * _params.num_classes * 121 | _params.max_shape.height() * _params.max_shape.width(); 122 | cudaHostAlloc((void**)&h_output_tensor_, out_counts * sizeof(float), cudaHostAllocDefault); 123 | cudaMalloc((void**)&d_output_tensor_, out_counts * sizeof(float)); 124 | 125 | buffer_queue_.push_back(d_input_tensor_); 126 | buffer_queue_.push_back(d_output_tensor_); 127 | } 128 | 129 | void TRTOnnxBase::SaveRTModel(nvinfer1::IHostMemory* gie_model_stream, const std::string& path) 130 | { 131 | std::ofstream outfile(path, std::ios_base::out | std::ios_base::binary); 132 | outfile.write((const char*)gie_model_stream->data(), gie_model_stream->size()); 133 | outfile.close(); 134 | } 135 | 136 | void TRTOnnxBase::Forward() 137 | { 138 | nvinfer1::Dims4 input_dims{ 1, input_shape_.channels(), 139 | input_shape_.height(), input_shape_.width() }; 140 | _context->setBindingDimensions(0, input_dims); 141 | _context->enqueueV2(buffer_queue_.data(), stream_, nullptr); 142 | 143 | cudaStreamSynchronize(stream_); 144 | } -------------------------------------------------------------------------------- /trt_onnx_base.h: -------------------------------------------------------------------------------- 1 | #ifndef TRT_ONNX_BASE_H_ 2 | #define TRT_ONNX_BASE_H_ 3 | 4 | #include 5 | #include 6 | #include "NvOnnxParser.h" 7 | #include "NvInfer.h" 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include "common.hpp" 13 | 14 | using namespace std; 15 | 16 | struct OnnxInitParam 17 | { 18 | std::string onnx_model_path; 19 | std::string rt_stream_path = "./"; 20 | std::string rt_model_name = "bisenetv3.engine"; 21 | bool use_fp16 = false; 22 | int gpu_id = 0; 23 | int num_classes; 24 | Shape max_shape{ 1, 3, 640, 640 }; 25 | }; 26 | 27 | class 
TRTOnnxBase 28 | { 29 | public: 30 | TRTOnnxBase() = delete; 31 | TRTOnnxBase(const OnnxInitParam& params); 32 | 33 | protected: 34 | class Logger : public nvinfer1::ILogger 35 | { 36 | public: 37 | void log(nvinfer1::ILogger::Severity severity, const char* msg) 38 | { 39 | switch (severity) 40 | { 41 | case nvinfer1::ILogger::Severity::kINTERNAL_ERROR: 42 | std::cerr << "kINTERNAL_ERROR: " << msg << std::endl; 43 | break; 44 | case nvinfer1::ILogger::Severity::kERROR: 45 | std::cerr << "kERROR: " << msg << std::endl; 46 | break; 47 | case nvinfer1::ILogger::Severity::kWARNING: 48 | std::cerr << "kWARNING: " << msg << std::endl; 49 | break; 50 | case nvinfer1::ILogger::Severity::kINFO: 51 | std::cerr << "kINFO: " << msg << std::endl; 52 | break; 53 | case nvinfer1::ILogger::Severity::kVERBOSE: 54 | std::cerr << "kVERBOSE: " << msg << std::endl; 55 | break; 56 | default: 57 | break; 58 | } 59 | } 60 | }; 61 | 62 | // tensorrt推理 63 | void Forward(); 64 | 65 | private: 66 | // void mallocInputOutput(const Shape &input_shape, const Shape &output_shape); 67 | // 模型初始化 68 | void Initial(); 69 | // 加载onnx模型 70 | void LoadOnnxModel(); 71 | // 加载tensorrt模型 72 | void LoadGieStreamBuildContext(const std::string& gie_file); 73 | // 分配执行预测所需的cpu内存与gpu显存 74 | void mallocInputOutput(); 75 | // 保存序列化的模型 76 | void SaveRTModel(nvinfer1::IHostMemory* gie_model_stream, const std::string& path); 77 | // 反序列化tensorrt模型 78 | void deserializeCudaEngine(const void* blob_data, std::size_t size); 79 | 80 | bool CheckFileExist(const std::string& path) 81 | { 82 | std::ifstream check_file(path); 83 | return check_file.is_open(); 84 | } 85 | 86 | private: 87 | Logger logger; 88 | nvinfer1::IRuntime* _runtime{ nullptr }; 89 | nvinfer1::ICudaEngine* _engine{ nullptr }; 90 | nvinfer1::IExecutionContext* _context{ nullptr }; 91 | 92 | cudaStream_t stream_; 93 | 94 | protected: 95 | std::vector buffer_queue_; 96 | 97 | float* h_input_tensor_; 98 | float* d_input_tensor_; 99 | Shape input_shape_; // 
记录每次前向预测的输入样本的shape 100 | Shape output_shape_; // 记录每次前向预测的输出样本的shape 101 | float* h_output_tensor_; 102 | float* d_output_tensor_; 103 | 104 | OnnxInitParam _params; 105 | }; 106 | 107 | #endif --------------------------------------------------------------------------------