├── .gitignore ├── CMakeLists.txt ├── CMakeLists.txt.user ├── README.md ├── RefineNet.cpp ├── image.cpp ├── image.hpp └── image ├── Screenshot from 2020-08-16 13-39-14.png ├── d94be52120f2aa2cfbd7c12f10817b04.jpeg └── result.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | /vid/ 2 | -------------------------------------------------------------------------------- /CMakeLists.txt: --------------------------------------------------------------------------------
# Build script for the RefineNet TensorRT demo (RefineNet.cpp + image.cpp).
#
# Dependencies: OpenCV 3.4.8+, CUDA runtime, and TensorRT (nvinfer,
# nvinfer_plugin, nvparsers, nvonnxparser, plus the TensorRT samples' common.h).
# Point TENSORRT_ROOT at your TensorRT install, e.g.
#   cmake -DTENSORRT_ROOT=/opt/TensorRT-5.1.5.0 ..
#
# 3.5 is the oldest release accepted by current CMake; the ...3.26 upper bound
# keeps the FindCUDA module usable on newer CMake versions.
cmake_minimum_required(VERSION 3.5...3.26)

project(refinenet LANGUAGES C CXX)

# Default to an optimised build, but never override a build type the user
# chose, and leave multi-config generators alone.
get_property(refinenet_multi_config GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
if(NOT refinenet_multi_config AND NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
endif()

# TensorRT location; the default is the original author's install path.
set(TENSORRT_ROOT "/home/kong/TensorRT-5.1.5.0"
  CACHE PATH "Root directory of the TensorRT installation")

# videoio is required for cv::VideoCapture, imgcodecs for cv::imwrite.
find_package(OpenCV 3.4.8 REQUIRED
  COMPONENTS core highgui imgproc imgcodecs videoio)

# FindCUDA provides CUDA_INCLUDE_DIRS and CUDA_LIBRARIES (the CUDA runtime),
# so no CUDA version or install path has to be hard-coded here.
find_package(CUDA REQUIRED)

find_path(TENSORRT_INCLUDE_DIR NvInfer.h
  HINTS "${TENSORRT_ROOT}" "${CUDA_TOOLKIT_ROOT_DIR}"
  PATH_SUFFIXES include)
find_library(TENSORRT_LIBRARY_INFER nvinfer
  HINTS "${TENSORRT_ROOT}" "${CUDA_TOOLKIT_ROOT_DIR}"
  PATH_SUFFIXES lib lib64 lib/x64)
find_library(TENSORRT_LIBRARY_INFER_PLUGIN nvinfer_plugin
  HINTS "${TENSORRT_ROOT}" "${CUDA_TOOLKIT_ROOT_DIR}"
  PATH_SUFFIXES lib lib64 lib/x64)
find_library(TENSORRT_LIBRARY_PARSER nvparsers
  HINTS "${TENSORRT_ROOT}" "${CUDA_TOOLKIT_ROOT_DIR}"
  PATH_SUFFIXES lib lib64 lib/x64)
find_library(TENSORRT_LIBRARY_ONNXPARSER nvonnxparser
  HINTS "${TENSORRT_ROOT}" "${CUDA_TOOLKIT_ROOT_DIR}"
  PATH_SUFFIXES lib lib64 lib/x64)

# common.h (Logger, CHECK, samplesCommon helpers) ships with the TensorRT samples.
set(TENSORRT_SAMPLES_COMMON_DIR "${TENSORRT_ROOT}/samples/common")

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(TENSORRT DEFAULT_MSG
  TENSORRT_INCLUDE_DIR
  TENSORRT_LIBRARY_INFER
  TENSORRT_LIBRARY_INFER_PLUGIN
  TENSORRT_LIBRARY_PARSER
  TENSORRT_LIBRARY_ONNXPARSER)
if(NOT TENSORRT_FOUND)
  # FATAL_ERROR aborts configuration; "ERROR" is not a message() mode.
  message(FATAL_ERROR "Cannot find TensorRT library.")
endif()

set(TENSORRT_LIBRARY
  ${TENSORRT_LIBRARY_INFER}
  ${TENSORRT_LIBRARY_INFER_PLUGIN}
  ${TENSORRT_LIBRARY_PARSER}
  ${TENSORRT_LIBRARY_ONNXPARSER})
message(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}")

add_executable(RefineNet
  RefineNet.cpp
  image.cpp
  image.hpp)

set_target_properties(RefineNet PROPERTIES
  CXX_STANDARD 11
  CXX_STANDARD_REQUIRED ON
  CXX_EXTENSIONS OFF)

target_include_directories(RefineNet PRIVATE
  "${PROJECT_SOURCE_DIR}"
  "${TENSORRT_INCLUDE_DIR}"
  "${TENSORRT_SAMPLES_COMMON_DIR}"
  ${CUDA_INCLUDE_DIRS}
  ${OpenCV_INCLUDE_DIRS})

target_link_libraries(RefineNet PRIVATE
  ${OpenCV_LIBS}
  ${TENSORRT_LIBRARY}
  ${CUDA_LIBRARIES})
-------------------------------------------------------------------------------- /CMakeLists.txt.user: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | EnvironmentId 7 | {732b7e0e-7c9f-4d54-b38b-a7d9ccadbb1a} 8 | 9 | 10 | ProjectExplorer.Project.ActiveTarget 11 | 0 12 | 13 | 14 | ProjectExplorer.Project.EditorSettings 15 | 16 | true 17 | false 18 | true 19 | 20 | Cpp 21 | 22 | CppGlobal 23 | 24 | 25 | 26 | QmlJS 27 | 28 | QmlJSGlobal 29 | 30 | 31 | 2 32 | UTF-8 33 | false 34 | 4 35 | false 36 | 80 37 | true 38 | true 39 | 1 40 | true 41 | false 42 | 0 43 | true 44 | 0 45 | 8 46
| true 47 | 1 48 | true 49 | true 50 | true 51 | false 52 | 53 | 54 | 55 | ProjectExplorer.Project.PluginSettings 56 | 57 | 58 | 59 | ProjectExplorer.Project.Target.0 60 | 61 | Desktop 62 | Desktop 63 | {367d1864-7c43-4609-907c-d3292055d365} 64 | 0 65 | 0 66 | 0 67 | 68 | false 69 | /home/kong/Documents/RefineNet_TensorRT-build 70 | 71 | 72 | 73 | 74 | false 75 | 76 | true 77 | Make 78 | 79 | CMakeProjectManager.MakeStep 80 | 81 | 1 82 | Build 83 | 84 | ProjectExplorer.BuildSteps.Build 85 | 86 | 87 | 88 | clean 89 | 90 | true 91 | 92 | true 93 | Make 94 | 95 | CMakeProjectManager.MakeStep 96 | 97 | 1 98 | Clean 99 | 100 | ProjectExplorer.BuildSteps.Clean 101 | 102 | 2 103 | false 104 | 105 | all 106 | 107 | CMakeProjectManager.CMakeBuildConfiguration 108 | 109 | 1 110 | 111 | 112 | 0 113 | Deploy 114 | 115 | ProjectExplorer.BuildSteps.Deploy 116 | 117 | 1 118 | Deploy locally 119 | 120 | ProjectExplorer.DefaultDeployConfiguration 121 | 122 | 1 123 | 124 | 125 | 126 | false 127 | false 128 | false 129 | false 130 | true 131 | 0.01 132 | 10 133 | true 134 | 1 135 | 25 136 | 137 | 1 138 | true 139 | false 140 | true 141 | valgrind 142 | 143 | 0 144 | 1 145 | 2 146 | 3 147 | 4 148 | 5 149 | 6 150 | 7 151 | 8 152 | 9 153 | 10 154 | 11 155 | 12 156 | 13 157 | 14 158 | 159 | RefineNet 160 | 161 | 162 | 2 163 | 164 | RefineNet 165 | 166 | CMakeProjectManager.CMakeRunConfiguration.RefineNet 167 | 3768 168 | false 169 | true 170 | false 171 | false 172 | true 173 | 174 | 1 175 | 176 | 177 | 178 | ProjectExplorer.Project.TargetCount 179 | 1 180 | 181 | 182 | ProjectExplorer.Project.Updater.FileVersion 183 | 18 184 | 185 | 186 | Version 187 | 18 188 | 189 | 190 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RefineNet_TensorRT 2 | 3 | ## Update 2021-04-27 4 | Recommended: DDRNet, which achieves 334 FPS with TensorRT FP16: 5 |
https://github.com/midasklr/DDRNet.TensorRT 6 | 7 | TensorRT for RefineNet Segmentation. 8 | 9 | This is a TensorRT C++ project for [RefineNet](https://github.com/midasklr/RefineNet). See that repository for details on how to train RefineNet for semantic segmentation and export the ONNX model. 10 | 11 | 这里是RefineNet语义分割模型的TensorRT工程,参考之前的[RefineNet](https://github.com/midasklr/RefineNet),这里我使用Helen人脸分割数据集训练,对原始的RefineNet模型进行了一些修改,修改了部分3*3卷积用于轻量化模型,修改了Upsample为转置卷积。 12 | 13 | ## Environment 14 | 15 | Ubuntu 16.04 16 | 17 | PyTorch 0.4.1 18 | 19 | TensorRT 5.1.5 20 | 21 | OpenCV 3.4.8 22 | 23 | ## Build 24 | 25 | 1. Configure your TensorRT path in CMakeLists.txt. 26 | 27 | 2. Build: 28 | 29 | ``` 30 | mkdir build && cd build 31 | cmake .. 32 | make -j8 33 | ``` 34 | 35 | 3. Run on the demo video. Mode `s` serializes the engine from the ONNX model and then runs inference. Choose `float16` or `float32`, and pass `cam` instead of a video path to use the webcam: 36 | 37 | ``` 38 | ./RefineNet s float16(float32) refinenet.engine ../vid/demo.mp4 ../refinenet.onnx 39 | ``` 40 | 41 | 4. Serialize the engine from the ONNX model: 42 | 43 | ``` 44 | ./RefineNet s float16(float32) refinenet.engine ../vid/face.mp4 ../refinenet.onnx 45 | ``` 46 | 47 | 5.
deserialize the engine and infer: 48 | 49 | ``` 50 | ./RefineNet infer float16 refinenet.engine ../vid/face.mp4 51 | ``` 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | ## Performance 60 | 61 | | Model | FPS | 62 | | ------- | ---- | 63 | | Pytorch | 5 | 64 | | FP32 | 27 | 65 | | FP16 | 33 | 66 | 67 | -------------------------------------------------------------------------------- /RefineNet.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include "NvInfer.h" 12 | #include "NvOnnxParser.h" 13 | #include "common.h" 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include "image.hpp" 20 | 21 | #include 22 | 23 | 24 | using namespace nvinfer1; 25 | std::stringstream gieModelStream; 26 | static const int INPUT_H = 512; 27 | static const int INPUT_W = 512; 28 | static const int INPUT_C = 3; 29 | static const int OUTPUT_SIZE = 180224; 30 | static const int INPUT_SIZE = 786432; 31 | static Logger gLogger1; 32 | static int gUseDLACore{ -1 }; 33 | const char* INPUT_BLOB_NAME = "0"; 34 | 35 | 36 | void onnxToTRTModel(const std::string& modelFile, // name of the onnx model 37 | unsigned int maxBatchSize, // batch size - NB must be at least as large as the batch we want to run with 38 | IHostMemory*& trtModelStream, 39 | DataType dataType, 40 | IInt8Calibrator* calibrator, 41 | std::string save_name) // output buffer for the TensorRT model 42 | { 43 | int verbosity = (int)nvinfer1::ILogger::Severity::kWARNING; 44 | // create the builder 45 | IBuilder* builder = createInferBuilder(gLogger1); 46 | nvinfer1::INetworkDefinition* network = builder->createNetwork(); 47 | 48 | auto parser = nvonnxparser::createParser(*network, gLogger1); 49 | 50 | bool useFp16 = builder->platformHasFastFp16(); 51 | std::cout << "platformHasFastFp16: " << useFp16 << "\n"; 52 | 53 | if (!parser->parseFromFile(modelFile.c_str(), 
verbosity)) 54 | { 55 | string msg("failed to parse onnx file"); 56 | gLogger1.log(nvinfer1::ILogger::Severity::kERROR, msg.c_str()); 57 | exit(EXIT_FAILURE); 58 | } 59 | if ((dataType == DataType::kINT8 && !builder->platformHasFastInt8()) ) 60 | exit(EXIT_FAILURE); //如果不支持kint8或不支持khalf就返回false 61 | // Build the engine 62 | 63 | builder->setMaxBatchSize(maxBatchSize); 64 | builder->setMaxWorkspaceSize(4_GB); //不能超过你的实际能用的显存的大小,例如我的1060的可用为4.98GB,超过4.98GB会报错 65 | // builder->setInt8Mode(dataType == DataType::kINT8); // 66 | // builder->setInt8Calibrator(calibrator); // 67 | samplesCommon::enableDLA(builder, gUseDLACore); 68 | if(dataType == DataType::kHALF){ 69 | builder->setFp16Mode(true); 70 | std::cout<<"Now we use FP16 mode ..." <buildCudaEngine(*network); 73 | assert(engine); 74 | 75 | // we can destroy the parser 76 | parser->destroy(); 77 | 78 | // serialize the engine, then close everything down 序列化 79 | trtModelStream = engine->serialize(); 80 | 81 | gieModelStream.write((const char*)trtModelStream->data(), trtModelStream->size()); 82 | std::ofstream SaveFile(save_name, std::ios::out | std::ios::binary); 83 | SaveFile.seekp(0, std::ios::beg); 84 | SaveFile << gieModelStream.rdbuf(); 85 | gieModelStream.seekg(0, gieModelStream.beg); 86 | 87 | 88 | engine->destroy(); 89 | network->destroy(); 90 | builder->destroy(); 91 | } 92 | 93 | void doInference(IExecutionContext& context, float* input, float* output, int batchSize) 94 | { 95 | const ICudaEngine& engine = context.getEngine(); 96 | // input and output buffer pointers that we pass to the engine - the engine requires exactly IEngine::getNbBindings(), 97 | // of these, but in this case we know that there is exactly one input and one output. 98 | assert(engine.getNbBindings() == 2); 99 | void* buffers[2]; 100 | 101 | // In order to bind the buffers, we need to know the names of the input and output tensors. 
102 | // note that indices are guaranteed to be less than IEngine::getNbBindings() 103 | int inputIndex, outputIndex; 104 | for (int b = 0; b < engine.getNbBindings(); ++b) 105 | { 106 | if (engine.bindingIsInput(b)) 107 | inputIndex = b; 108 | else 109 | outputIndex = b; 110 | } 111 | // create GPU buffers and a stream 创建GPU缓冲区和流 112 | CHECK(cudaMalloc(&buffers[inputIndex], batchSize *INPUT_C* INPUT_H * INPUT_W * sizeof(float))); 113 | CHECK(cudaMalloc(&buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float))); 114 | 115 | cudaStream_t stream; 116 | CHECK(cudaStreamCreate(&stream)); 117 | 118 | // DMA the input to the GPU, execute the batch asynchronously, and DMA it back: 119 | CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize *INPUT_C* INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream)); 120 | context.enqueue(batchSize, buffers, stream, nullptr);//TensorRT的执行通常是异步的,因此将核加入队列放在CUDA流上 121 | CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream)); 122 | cudaStreamSynchronize(stream); 123 | 124 | // release the stream and the buffers 125 | cudaStreamDestroy(stream); 126 | CHECK(cudaFree(buffers[inputIndex])); 127 | CHECK(cudaFree(buffers[outputIndex])); 128 | } 129 | 130 | 131 | 132 | int do_serialize(int argc, char** argv) 133 | { 134 | IHostMemory* trtModelStream{ nullptr }; 135 | //gUseDLACore = samplesCommon::parseDLA(argc, argv); 136 | // create a TensorRT model from the onnx model and serialize it to a stream 137 | std::string file_name=argv[3]; 138 | std::string modelFile = argv[5]; 139 | 140 | if (argc != 6) 141 | { 142 | std::cout << "s or infer" << std::endl; 143 | std::cout << "float32 or float16" << std::endl; 144 | std::cout << "cam or video file" << std::endl; 145 | std::cout << "save serialize name" << std::endl; 146 | std::cout << "onnx name" << std::endl; 147 | return 1; 148 | } 149 | if(0 == strcmp(argv[2],"float32")){ 150 | std::cout << "using 
float32 mode" << std::endl; 151 | onnxToTRTModel(modelFile, 1, trtModelStream, DataType::kFLOAT, nullptr, file_name); //读onnx模型,序列化引擎 152 | 153 | }else if(0 == strcmp(argv[2],"float16")){ 154 | // std::cout << "using float16 mode" << std::endl; 155 | onnxToTRTModel(modelFile, 1, trtModelStream, DataType::kHALF, nullptr, file_name); //读onnx模型,序列化引擎 156 | } 157 | // std::cout << "using float32 mode" << std::endl; 158 | // onnxToTRTModel(modelFile, 1, trtModelStream, DataType::kFLOAT, nullptr, file_name); //读onnx模型,序列化引擎 159 | 160 | std::cout << "rialize model ready" << std::endl; 161 | assert(trtModelStream != nullptr); 162 | 163 | // deserialize the engine DLA加速 164 | //反序列化引擎 165 | IRuntime* runtime = createInferRuntime(gLogger1); 166 | assert(runtime != nullptr); 167 | if (gUseDLACore >= 0) 168 | { 169 | runtime->setDLACore(gUseDLACore); 170 | } 171 | //反序列化 172 | ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream->data(), trtModelStream->size(), nullptr); 173 | 174 | assert(engine != nullptr); 175 | trtModelStream->destroy(); 176 | IExecutionContext* context = engine->createExecutionContext(); 177 | assert(context != nullptr); 178 | 179 | int cam_index = 0; 180 | cv::VideoCapture cap; 181 | if (0 == strcmp(argv[4], "cam")) { 182 | cap.open(cam_index); } 183 | else 184 | { cap.open(argv[4]); } 185 | 186 | if (!cap.isOpened()) { 187 | std::cout << "Error: video-stream can't be opened! 
\n"; 188 | return 1; } 189 | cv::namedWindow("RefineNet", CV_WINDOW_NORMAL); 190 | cv::resizeWindow("RefineNet", 512, 512); 191 | cv::Mat frame; 192 | float prob[OUTPUT_SIZE]; 193 | float* data; 194 | float fps = 0; 195 | 196 | cv::Mat out; 197 | out.create(128, 128, CV_32FC(11)); 198 | cv::Mat real_out; 199 | real_out.create(512, 512, CV_32FC(11)); 200 | cv::Mat real_out_; 201 | real_out_.create(512, 512, CV_8UC3); 202 | while (1) 203 | { 204 | struct timeval tval_before, tval_after, tval_result; 205 | gettimeofday(&tval_before, NULL); 206 | cap >> frame; 207 | cv::cvtColor(frame, frame, cv::COLOR_BGR2RGB); 208 | cv::Mat dst = cv::Mat::zeros(512, 512, CV_32FC3);//新建一张512x512尺寸的图片Mat 209 | cv::resize(frame, dst, dst.size()); 210 | data = normal(dst); 211 | doInference(*context, data, prob, 1);//chw 212 | out = read2mat(prob, out); 213 | //hwc 214 | cv::resize(out, real_out, real_out.size()); 215 | real_out_ = map2threeunchar(real_out, real_out_); 216 | cv::imshow("somename", real_out_); 217 | //show_image(real_out, "Segmentation"); //显示图片 218 | std::free(data); 219 | //free_image(real_out); 220 | if (cvWaitKey(10) == 27) break; 221 | gettimeofday(&tval_after, NULL); 222 | timersub(&tval_after, &tval_before, &tval_result); 223 | float curr = 1000000.f / ((long int)tval_result.tv_usec); 224 | //std::cout << (float)tval_result.tv_usec << std::endl; 225 | printf("\nFPS:%.0f\n", fps); 226 | fps = .9*fps + .1*curr; 227 | //fps = curr; 228 | } 229 | cv::destroyAllWindows(); 230 | cap.release(); 231 | 232 | // destroy the engine 233 | context->destroy(); 234 | engine->destroy(); 235 | runtime->destroy(); 236 | std::cout << "shut down" << std::endl; 237 | //nvcaffeparser1::shutdownProtobufLibrary(); 238 | 239 | return 0; 240 | } 241 | 242 | 243 | int do_deserialize(int argc, char ** argv) 244 | { 245 | gieModelStream.seekg(0, gieModelStream.beg); 246 | std::ifstream serialize_iutput_stream(argv[3], std::ios::in | std::ios::binary); 247 | if (!serialize_iutput_stream) { 248 
| std::cout << "cannot find serialize file" << std::endl; 249 | return 1; 250 | } 251 | serialize_iutput_stream.seekg(0); 252 | 253 | gieModelStream << serialize_iutput_stream.rdbuf(); 254 | gieModelStream.seekg(0, std::ios::end); 255 | const int modelSize = gieModelStream.tellg(); 256 | gieModelStream.seekg(0, std::ios::beg); 257 | void* modelMem = malloc(modelSize); 258 | gieModelStream.read((char*)modelMem, modelSize); 259 | 260 | IHostMemory* trtModelStream{ nullptr }; 261 | IBuilder* builder = createInferBuilder(gLogger1); 262 | 263 | if (argc != 5) 264 | { 265 | std::cout << "have_serialize_txt" << std::endl; 266 | std::cout << "float" << std::endl; 267 | std::cout << "cam or video file" << std::endl; 268 | std::cout << "saved serialize name" << std::endl; 269 | return 1; 270 | } 271 | std::cout << "using float32 mode" << std::endl; 272 | 273 | builder->destroy(); 274 | IRuntime* runtime = createInferRuntime(gLogger1); 275 | 276 | ICudaEngine* engine = runtime->deserializeCudaEngine(modelMem, modelSize, NULL); 277 | std::free(modelMem); 278 | assert(engine != nullptr); 279 | IExecutionContext* context = engine->createExecutionContext(); 280 | assert(context != nullptr); 281 | int cam_index = 0; 282 | char *filename = (argc > 3) ? argv[3] : 0; 283 | std::cout << "Hello World!\n"; 284 | cv::VideoCapture cap; 285 | if (0 == strcmp(argv[4], "cam")) { 286 | cap.open(cam_index); } 287 | else { cap.open(argv[4]);} 288 | if (!cap.isOpened()){ 289 | std::cout << "Error: video-stream can't be opened! 
\n"; 290 | return 1;} 291 | cv::namedWindow("RefineNet", CV_WINDOW_NORMAL); 292 | cv::resizeWindow("RefineNet", 512, 512); 293 | cv::Mat frame; 294 | float prob[OUTPUT_SIZE]; 295 | float* data; 296 | float fps = 0; 297 | 298 | cv::Mat out; 299 | out.create(128, 128, CV_32FC(11)); 300 | cv::Mat real_out; 301 | real_out.create(512, 512, CV_32FC(11)); 302 | cv::Mat real_out_; 303 | real_out_.create(512, 512, CV_8UC3); 304 | while (1) 305 | { 306 | struct timeval tval_before, tval_after, tval_result; 307 | gettimeofday(&tval_before, NULL); 308 | cap >> frame; 309 | cv::cvtColor(frame, frame, cv::COLOR_BGR2RGB); 310 | cv::Mat dst = cv::Mat::zeros(512, 512, CV_32FC3);//新建一张512x512尺寸的图片Mat 311 | cv::resize(frame, dst, dst.size()); 312 | data = normal(dst); 313 | 314 | doInference(*context, data, prob, 1);//chw 315 | out = read2mat(prob, out); 316 | //hwc 317 | cv::resize(out, real_out, real_out.size()); 318 | real_out_ = map2threeunchar(real_out, real_out_); 319 | 320 | cv::imshow("RefineNet", real_out_); 321 | cv::imwrite("result.jpg",real_out_); 322 | std::free(data); 323 | if (cvWaitKey(10) == 27) break; 324 | gettimeofday(&tval_after, NULL); 325 | timersub(&tval_after, &tval_before, &tval_result); 326 | float curr = 1000000.f / ((long int)tval_result.tv_usec); 327 | printf("\nFPS:%.0f\n", fps); 328 | fps = .9*fps + .1*curr; 329 | } 330 | cv::destroyAllWindows(); 331 | cap.release(); 332 | // destroy the engine 333 | context->destroy(); 334 | engine->destroy(); 335 | runtime->destroy(); 336 | std::cout << "shut down" << std::endl; 337 | return 0; 338 | } 339 | 340 | int main(int argc, char** argv) 341 | { 342 | 343 | if (0 == strcmp(argv[1], "infer")) 344 | { 345 | do_deserialize(argc, argv); 346 | } 347 | else 348 | { 349 | do_serialize(argc, argv); 350 | } 351 | return 0; 352 | } 353 | -------------------------------------------------------------------------------- /image.cpp: -------------------------------------------------------------------------------- 1 | 
#include 2 | #include "image.hpp" 3 | 4 | static const float kMean[3] = { 0.485f, 0.456f, 0.406f }; 5 | static const float kStdDev[3] = { 0.229f, 0.224f, 0.225f }; 6 | static const int map_[11][3] = { {0,0,0} , 7 | {128,0,0}, 8 | {0,128,0}, 9 | {0,0,128}, 10 | {128,128,0}, 11 | {128,0,128}, 12 | {0,128,0}, 13 | {0,128,128}, 14 | {0,255,0}, 15 | {0,0,255}, 16 | {255,0,0} }; 17 | 18 | 19 | float* normal(cv::Mat img) { 20 | //cv::Mat image(img.rows, img.cols, CV_32FC3); 21 | float * data; 22 | data = (float*)calloc(img.rows*img.cols * 3, sizeof(float)); 23 | 24 | for (int c = 0; c < 3; ++c) 25 | { 26 | for (int i = 0; i < img.rows; ++i) 27 | { //获取第i行首像素指针 28 | cv::Vec3b *p1 = img.ptr(i); 29 | //cv::Vec3b *p2 = image.ptr(i); 30 | for (int j = 0; j < img.cols; ++j) 31 | { 32 | data[c * img.cols * img.rows + i * img.cols + j] = (p1[j][c] / 255. - kMean[c]) / kStdDev[c]; 33 | } 34 | } 35 | } 36 | return data; 37 | } 38 | 39 | 40 | 41 | cv::Mat read2mat(float * prob,cv::Mat out) 42 | { 43 | for (int i = 0; i < 128; ++i) 44 | { 45 | cv::Vec *p1 = out.ptr>(i); 46 | for (int j = 0; j < 128; ++j) 47 | { 48 | for (int c = 0; c < 11; ++c) 49 | { 50 | p1[j][c] = prob[c * 128 * 128 + i * 128 + j]; 51 | } 52 | } 53 | } 54 | return out; 55 | } 56 | 57 | 58 | cv::Mat map2threeunchar(cv::Mat real_out,cv::Mat real_out_) 59 | { 60 | for (int i = 0; i < 512; ++i) 61 | { 62 | cv::Vec *p1 = real_out.ptr>(i); 63 | cv::Vec3b *p2 = real_out_.ptr(i); 64 | for (int j = 0; j < 512; ++j) 65 | { 66 | int index = 0; 67 | float swap; 68 | for (int c = 0; c < 11; ++c) 69 | { 70 | if (p1[j][0] < p1[j][c]) 71 | { 72 | swap = p1[j][0]; 73 | p1[j][0] = p1[j][c]; 74 | p1[j][c] = swap; 75 | index = c; 76 | } 77 | } 78 | p2[j][0] = map_[index][2]; 79 | p2[j][1] = map_[index][1]; 80 | p2[j][2] = map_[index][0]; 81 | 82 | } 83 | } 84 | return real_out_; 85 | } 86 | 87 | -------------------------------------------------------------------------------- /image.hpp: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | 4 | typedef struct { 5 | int w; 6 | int h; 7 | int c; 8 | float *data; 9 | } image; 10 | 11 | float* normal(cv::Mat img); 12 | 13 | 14 | cv::Mat read2mat(float * prob, cv::Mat out); 15 | 16 | 17 | cv::Mat map2threeunchar(cv::Mat real_out, cv::Mat real_out_); 18 | -------------------------------------------------------------------------------- /image/Screenshot from 2020-08-16 13-39-14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/RefineNet_TensorRT/59724989fb82c96f0dddb1a337fc8bcf90650dda/image/Screenshot from 2020-08-16 13-39-14.png -------------------------------------------------------------------------------- /image/d94be52120f2aa2cfbd7c12f10817b04.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/RefineNet_TensorRT/59724989fb82c96f0dddb1a337fc8bcf90650dda/image/d94be52120f2aa2cfbd7c12f10817b04.jpeg -------------------------------------------------------------------------------- /image/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/RefineNet_TensorRT/59724989fb82c96f0dddb1a337fc8bcf90650dda/image/result.jpg --------------------------------------------------------------------------------