├── CMakeLists.txt ├── README.md ├── calibrator.cpp ├── calibrator.h ├── common.hpp ├── ddrnet.cpp ├── getwts.py ├── images └── mainz_000001_009328_leftImg8bit.png ├── logging.h ├── results ├── Screenshot from 2021-04-21 19-25-48.png ├── Screenshot from 2021-04-21 19-26-08.png └── result_mainz_000001_009328_leftImg8bit.png └── utils.h /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.6) 2 | 3 | project(DDRNet) 4 | 5 | add_definitions(-std=c++11) 6 | 7 | option(CUDA_USE_STATIC_CUDA_RUNTIME OFF) 8 | set(CMAKE_CXX_STANDARD 11) 9 | set(CMAKE_BUILD_TYPE Debug) 10 | 11 | find_package(CUDA REQUIRED) 12 | include_directories(/usr/local/cuda/include) 13 | link_directories(/usr/local/cuda/lib64) 14 | 15 | include_directories(${PROJECT_SOURCE_DIR}/include) 16 | 17 | include_directories(/home/midas/TensorRT-7.0.0.11/include/) 18 | link_directories(/home/midas/TensorRT-7.0.0.11/lib/) 19 | 20 | find_package(OpenCV 3.4.8 REQUIRED) 21 | include_directories(${OpenCV_INCLUDE_DIRS}) 22 | 23 | add_executable(ddrnet ${PROJECT_SOURCE_DIR}/ddrnet.cpp) 24 | target_link_libraries(ddrnet nvinfer) 25 | target_link_libraries(ddrnet cudart) 26 | target_link_libraries(ddrnet ${OpenCV_LIBS}) 27 | 28 | add_definitions(-O2 -pthread) 29 | 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DDRNet 2 | 3 | TensorRT implementation of the official [DDRNet](https://github.com/ydhongHIT/DDRNet) 4 | 5 |

6 | 7 |

8 | 9 | [DDRNet-23-slim](https://paperswithcode.com/paper/deep-dual-resolution-networks-for-real-time) outperforms other lightweight segmentation methods; see the [Cityscapes real-time benchmark](https://paperswithcode.com/sota/real-time-semantic-segmentation-on-cityscapes). 10 |

11 | 12 |

13 |
14 |
15 |
16 | ## Compile & Run
17 |
18 | * 1. Get model.wts
19 |
20 | Convert the PyTorch model to a .wts file using getwts.py, or download the .wts [model](url: https://pan.baidu.com/s/1Cm1A2mq6RxCFhUJrOJBSrw ; password: p6hy) converted from the official implementation.
21 |
22 | Note that we do not use the extra segmentation head during inference (set augment=False in https://github.com/ydhongHIT/DDRNet/blob/76a875084afdc7dedd20e2c2bdc0a93f8f481e81/segmentation/DDRNet_23_slim.py#L345).
23 |
24 | * 2. cmake and make
25 |
26 | Configure the TensorRT, CUDA and OpenCV paths in CMakeLists.txt, then:
27 |
28 | ```
29 | mkdir build
30 | cd build
31 | cmake ..
32 | make -j8
33 | ./ddrnet -s // serialize model to plan file, i.e. 'DDRNet.engine'
34 | ./ddrnet -d ../images // deserialize plan file and run inference on the images in that folder
35 | ```
36 |
37 | For INT8 support, switch the precision macros at the top of ddrnet.cpp:
38 |
39 | ```
40 | #define USE_INT8 // enable INT8 quantization (needs calibration images, see below)
41 | //#define USE_FP16 // disable FP16 (leave both disabled for FP32)
42 | ```
43 |
44 | Create a folder named "calib" and put around 1k images (Cityscapes val/test images) into it; calibration runs when the engine is serialized with ./ddrnet -s.
45 |
46 | ## FPS
47 |
48 | Tested on an RTX 2070.
49 |
50 | | model | input | FPS |
51 | | -------------- | --------------- | ---- |
52 | | Pytorch-aug | (3,1024,1024) | 107 |
53 | | Pytorch-no-aug | (3,1024,1024) | 108 |
54 | | TensorRT-FP32 | (3,1024,1024) | 117 |
55 | | TensorRT-FP16 | (3,1024,1024) | 215 |
56 | | TensorRT-INT8 | (3,1024,1024) | 232 |
57 |
58 | Pytorch-aug means augment=True.
59 |
60 | ## Difference with official
61 |
62 | We use "nearest" upsampling instead of "bilinear", which may lead to slightly lower accuracy (see the sketch at the end of this README for how the upsampling is implemented).
63 |
64 | Fine-tuning with "nearest" upsampling may recover the accuracy.
65 |
66 | Here we convert the official model directly.
67 |
68 | ## Train
69 |
70 | 1. Refer to https://github.com/chenjun2hao/DDRNet.pytorch
71 | 2. Generate the .wts model with getwts.py
72 |
73 | ## Train on custom data
74 | See https://github.com/midasklr/DDRNet.Pytorch
75 | Write your own dataset and fine-tune the Cityscapes-pretrained model.
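## Note: how the upsampling is implemented

The pure network-definition API used here has no direct equivalent of PyTorch's F.interpolate(mode='nearest'), so ddrnet.cpp and common.hpp emulate nearest-neighbour upsampling with a grouped IDeconvolutionLayer whose kernel is all ones (kernel size = stride = scale factor, one group per channel); each input pixel is simply replicated into a scale x scale block. Below is a minimal sketch of that pattern; the helper name addNearestUpsample and its signature are ours for illustration, not part of this repo.

```cpp
#include <NvInfer.h>
#include <cstdlib>

using namespace nvinfer1;

// Emulate nearest-neighbour upsampling by `scale` with a grouped deconvolution.
// Because stride == kernel size, the all-ones depthwise kernel produces no overlap,
// so every input pixel is copied unchanged into a scale x scale output block.
ILayer* addNearestUpsample(INetworkDefinition* network, ITensor& x, int channels, int scale) {
    int n = channels * scale * scale;
    float* ones = reinterpret_cast<float*>(malloc(sizeof(float) * n)); // must stay alive until the engine is built
    for (int i = 0; i < n; i++) ones[i] = 1.0f;
    Weights kernel{ DataType::kFLOAT, ones, n };
    Weights emptyBias{ DataType::kFLOAT, nullptr, 0 };

    IDeconvolutionLayer* up = network->addDeconvolutionNd(x, channels, DimsHW{ scale, scale }, kernel, emptyBias);
    up->setStrideNd(DimsHW{ scale, scale });
    up->setNbGroups(channels); // depthwise: one filter per channel
    return up;
}
```

This is the same pattern used for compression3_up, compression4_up, spp_up and the DAPPM scale branches, which is why the result matches nearest rather than bilinear interpolation.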
76 | -------------------------------------------------------------------------------- /calibrator.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "calibrator.h" 6 | #include "cuda_runtime_api.h" 7 | #include "utils.h" 8 | 9 | Int8EntropyCalibrator2::Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache) 10 | : batchsize_(batchsize) 11 | , input_w_(input_w) 12 | , input_h_(input_h) 13 | , img_idx_(0) 14 | , img_dir_(img_dir) 15 | , calib_table_name_(calib_table_name) 16 | , input_blob_name_(input_blob_name) 17 | , read_cache_(read_cache) 18 | { 19 | input_count_ = 3 * input_w * input_h * batchsize; 20 | CUDA_CHECK(cudaMalloc(&device_input_, input_count_ * sizeof(float))); 21 | read_files_in_dir(img_dir, img_files_); 22 | } 23 | 24 | Int8EntropyCalibrator2::~Int8EntropyCalibrator2() 25 | { 26 | CUDA_CHECK(cudaFree(device_input_)); 27 | } 28 | 29 | int Int8EntropyCalibrator2::getBatchSize() const 30 | { 31 | return batchsize_; 32 | } 33 | 34 | bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings) 35 | { 36 | if (img_idx_ + batchsize_ > (int)img_files_.size()) { 37 | return false; 38 | } 39 | 40 | std::vector input_imgs_; 41 | for (int i = img_idx_; i < img_idx_ + batchsize_; i++) { 42 | std::cout << img_files_[i] << " " << i << std::endl; 43 | cv::Mat temp = cv::imread(img_dir_ + img_files_[i]); 44 | if (temp.empty()){ 45 | std::cerr << "Fatal error: image cannot open!" << std::endl; 46 | return false; 47 | } 48 | cv::Mat pr_img = preprocess_img(temp, input_w_, input_h_); 49 | input_imgs_.push_back(pr_img); 50 | } 51 | img_idx_ += batchsize_; 52 | cv::Mat blob = cv::dnn::blobFromImages(input_imgs_, 1.0 / 57.3750, cv::Size(input_w_, input_h_), cv::Scalar(1.80444, 2.0267, 2.1555), true, false); 53 | 54 | CUDA_CHECK(cudaMemcpy(device_input_, blob.ptr(0), input_count_ * sizeof(float), cudaMemcpyHostToDevice)); 55 | assert(!strcmp(names[0], input_blob_name_)); 56 | bindings[0] = device_input_; 57 | return true; 58 | } 59 | 60 | const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length) 61 | { 62 | std::cout << "reading calib cache: " << calib_table_name_ << std::endl; 63 | calib_cache_.clear(); 64 | std::ifstream input(calib_table_name_, std::ios::binary); 65 | input >> std::noskipws; 66 | if (read_cache_ && input.good()) 67 | { 68 | std::copy(std::istream_iterator(input), std::istream_iterator(), std::back_inserter(calib_cache_)); 69 | } 70 | length = calib_cache_.size(); 71 | return length ? calib_cache_.data() : nullptr; 72 | } 73 | 74 | void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length) 75 | { 76 | std::cout << "writing calib cache: " << calib_table_name_ << " size: " << length << std::endl; 77 | std::ofstream output(calib_table_name_, std::ios::binary); 78 | output.write(reinterpret_cast(cache), length); 79 | } 80 | 81 | -------------------------------------------------------------------------------- /calibrator.h: -------------------------------------------------------------------------------- 1 | #ifndef ENTROPY_CALIBRATOR_H 2 | #define ENTROPY_CALIBRATOR_H 3 | 4 | #include "NvInfer.h" 5 | #include 6 | #include 7 | 8 | //! \class Int8EntropyCalibrator2 9 | //! 10 | //! \brief Implements Entropy calibrator 2. 11 | //! CalibrationAlgoType is kENTROPY_CALIBRATION_2. 12 | //! 
13 | class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 14 | { 15 | public: 16 | Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache = true); 17 | 18 | virtual ~Int8EntropyCalibrator2(); 19 | int getBatchSize() const override; 20 | bool getBatch(void* bindings[], const char* names[], int nbBindings) override; 21 | const void* readCalibrationCache(size_t& length) override; 22 | void writeCalibrationCache(const void* cache, size_t length) override; 23 | 24 | private: 25 | int batchsize_; 26 | int input_w_; 27 | int input_h_; 28 | int img_idx_; 29 | std::string img_dir_; 30 | std::vector img_files_; 31 | size_t input_count_; 32 | std::string calib_table_name_; 33 | const char* input_blob_name_; 34 | bool read_cache_; 35 | void* device_input_; 36 | std::vector calib_cache_; 37 | }; 38 | 39 | #endif // ENTROPY_CALIBRATOR_H 40 | -------------------------------------------------------------------------------- /common.hpp: -------------------------------------------------------------------------------- 1 | #ifndef DDRNET_COMMON_H_ 2 | #define DDRNET_COMMON_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include "dirent.h" 11 | #include "NvInfer.h" 12 | #include 13 | 14 | #define CHECK(status) \ 15 | do\ 16 | {\ 17 | auto ret = (status);\ 18 | if (ret != 0)\ 19 | {\ 20 | std::cerr << "Cuda failure: " << ret << std::endl;\ 21 | abort();\ 22 | }\ 23 | } while (0) 24 | 25 | using namespace nvinfer1; 26 | 27 | // TensorRT weight files have a simple space delimited format: 28 | // [type] [size] 29 | std::map loadWeights(const std::string file) { 30 | std::cout << "Loading weights: " << file << std::endl; 31 | std::map weightMap; 32 | 33 | // Open weights file 34 | std::ifstream input(file); 35 | assert(input.is_open() && "Unable to load weight file."); 36 | 37 | // Read number of weight blobs 38 | int32_t count; 39 | input >> count; 40 | assert(count > 0 && "Invalid weight map file."); 41 | 42 | while (count--) { 43 | Weights wt{ DataType::kFLOAT, nullptr, 0 }; 44 | uint32_t size; 45 | 46 | // Read name and type of blob 47 | std::string name; 48 | input >> name >> std::dec >> size; 49 | wt.type = DataType::kFLOAT; 50 | 51 | // Load blob 52 | uint32_t* val = reinterpret_cast(malloc(sizeof(val) * size)); 53 | for (uint32_t x = 0, y = size; x < y; ++x) { 54 | input >> std::hex >> val[x]; 55 | } 56 | wt.values = val; 57 | 58 | wt.count = size; 59 | weightMap[name] = wt; 60 | } 61 | 62 | return weightMap; 63 | } 64 | 65 | IScaleLayer* addBatchNorm2d(INetworkDefinition *network, std::map& weightMap, ITensor& input, std::string lname, float eps) { 66 | float *gamma = (float*)weightMap[lname + ".weight"].values; 67 | float *beta = (float*)weightMap[lname + ".bias"].values; 68 | float *mean = (float*)weightMap[lname + ".running_mean"].values; 69 | float *var = (float*)weightMap[lname + ".running_var"].values; 70 | int len = weightMap[lname + ".running_var"].count; 71 | 72 | float *scval = reinterpret_cast(malloc(sizeof(float) * len)); 73 | for (int i = 0; i < len; i++) { 74 | scval[i] = gamma[i] / sqrt(var[i] + eps); 75 | } 76 | Weights scale{ DataType::kFLOAT, scval, len }; 77 | 78 | float *shval = reinterpret_cast(malloc(sizeof(float) * len)); 79 | for (int i = 0; i < len; i++) { 80 | shval[i] = beta[i] - mean[i] * gamma[i] / sqrt(var[i] + eps); 81 | } 82 | Weights shift{ DataType::kFLOAT, shval, len }; 83 | 84 | float *pval = 
reinterpret_cast(malloc(sizeof(float) * len)); 85 | for (int i = 0; i < len; i++) { 86 | pval[i] = 1.0; 87 | } 88 | Weights power{ DataType::kFLOAT, pval, len }; 89 | 90 | weightMap[lname + ".scale"] = scale; 91 | weightMap[lname + ".shift"] = shift; 92 | weightMap[lname + ".power"] = power; 93 | IScaleLayer* scale_1 = network->addScale(input, ScaleMode::kCHANNEL, shift, scale, power); 94 | assert(scale_1); 95 | return scale_1; 96 | } 97 | 98 | 99 | ILayer* basicBlock(INetworkDefinition *network, std::map& weightMap, ITensor& input, int inch, int outch, int stride, bool downsample, bool no_relu, std::string lname) { 100 | Weights emptywts{ DataType::kFLOAT, nullptr, 0 }; 101 | 102 | IConvolutionLayer* conv1 = network->addConvolution(input, outch, DimsHW{ 3, 3 }, weightMap[lname + "conv1.weight"], emptywts); 103 | assert(conv1); 104 | conv1->setStride(DimsHW{ stride, stride }); 105 | conv1->setPadding(DimsHW{ 1, 1 }); 106 | 107 | IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + "bn1", 1e-5); 108 | 109 | IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU); 110 | assert(relu1); 111 | 112 | IConvolutionLayer* conv2 = network->addConvolution(*relu1->getOutput(0), outch, DimsHW{ 3, 3 }, weightMap[lname + "conv2.weight"], emptywts); 113 | assert(conv2); 114 | conv2->setPadding(DimsHW{ 1, 1 }); 115 | 116 | IScaleLayer* bn2 = addBatchNorm2d(network, weightMap, *conv2->getOutput(0), lname + "bn2", 1e-5); 117 | 118 | if(downsample){ 119 | IConvolutionLayer* convdown = network->addConvolution(input, outch, DimsHW{ 1, 1 }, weightMap[lname + "downsample.0.weight"], emptywts); 120 | assert(convdown); 121 | convdown->setStride(DimsHW{ stride, stride}); 122 | convdown->setPadding(DimsHW{ 0, 0 }); 123 | 124 | IScaleLayer* bndown = addBatchNorm2d(network, weightMap, *convdown->getOutput(0), lname + "downsample.1", 1e-5); 125 | 126 | IElementWiseLayer* ew1 = network->addElementWise(*bn2->getOutput(0), *bndown->getOutput(0), ElementWiseOperation::kSUM); 127 | if(no_relu){ 128 | return ew1; 129 | }else{ 130 | IActivationLayer* relu3 = network->addActivation(*ew1->getOutput(0), ActivationType::kRELU); 131 | assert(relu3); 132 | return relu3; 133 | } 134 | } 135 | IElementWiseLayer* ew2 = network->addElementWise(input, *bn2->getOutput(0), ElementWiseOperation::kSUM); 136 | if(no_relu){ 137 | return ew2; 138 | }else{ 139 | IActivationLayer* relu3 = network->addActivation(*ew2->getOutput(0), ActivationType::kRELU); 140 | assert(relu1); 141 | return relu3; 142 | } 143 | } 144 | 145 | ILayer* Bottleneck(INetworkDefinition *network, std::map& weightMap, ITensor& input, int inch, int outch, int stride, bool downsample, bool no_relu, std::string lname) { 146 | Weights emptywts{ DataType::kFLOAT, nullptr, 0 }; 147 | 148 | IConvolutionLayer* conv1 = network->addConvolution(input, outch, DimsHW{ 1, 1 }, weightMap[lname + "conv1.weight"], emptywts); 149 | assert(conv1); 150 | conv1->setStride(DimsHW{ 1, 1 }); 151 | conv1->setPadding(DimsHW{ 0, 0 }); 152 | 153 | IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + "bn1", 1e-5); 154 | 155 | IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU); 156 | assert(relu1); 157 | 158 | IConvolutionLayer* conv2 = network->addConvolution(*relu1->getOutput(0), outch, DimsHW{ 3, 3 }, weightMap[lname + "conv2.weight"], emptywts); 159 | assert(conv2); 160 | conv2->setStride(DimsHW{ stride, stride }); 161 | conv2->setPadding(DimsHW{ 1, 1 }); 162 | 
163 | 164 | IScaleLayer* bn2 = addBatchNorm2d(network, weightMap, *conv2->getOutput(0), lname + "bn2", 1e-5); 165 | 166 | IActivationLayer* relu2 = network->addActivation(*bn2->getOutput(0), ActivationType::kRELU); 167 | assert(relu2); 168 | 169 | IConvolutionLayer* conv3 = network->addConvolution(*relu2->getOutput(0), outch*2, DimsHW{ 1, 1 }, weightMap[lname + "conv3.weight"], emptywts); 170 | assert(conv3); 171 | conv3->setStride(DimsHW{ 1, 1 }); 172 | conv3->setPadding(DimsHW{ 0, 0 }); 173 | 174 | IScaleLayer* bn3 = addBatchNorm2d(network, weightMap, *conv3->getOutput(0), lname + "bn3", 1e-5); 175 | 176 | if(downsample){ 177 | IConvolutionLayer* convdown = network->addConvolution(input, outch*2, DimsHW{ 1, 1 }, weightMap[lname + "downsample.0.weight"], emptywts); 178 | assert(convdown); 179 | convdown->setStride(DimsHW{ stride, stride }); 180 | conv1->setPadding(DimsHW{ 0, 0 }); 181 | 182 | 183 | IScaleLayer* bndown = addBatchNorm2d(network, weightMap, *convdown->getOutput(0), lname + "downsample.1", 1e-5); 184 | 185 | IElementWiseLayer* ew1 = network->addElementWise(*bn3->getOutput(0), *bndown->getOutput(0), ElementWiseOperation::kSUM); 186 | if(no_relu){ 187 | return ew1; 188 | }else{ 189 | IActivationLayer* relu3 = network->addActivation(*ew1->getOutput(0), ActivationType::kRELU); 190 | assert(relu1); 191 | return relu3; 192 | } 193 | } 194 | IElementWiseLayer* ew2 = network->addElementWise(input, *bn3->getOutput(0), ElementWiseOperation::kSUM); 195 | if(no_relu){ 196 | return ew2; 197 | }else{ 198 | IActivationLayer* relu3 = network->addActivation(*ew2->getOutput(0), ActivationType::kRELU); 199 | assert(relu1); 200 | return relu3; 201 | } 202 | } 203 | 204 | 205 | ILayer* compression3(INetworkDefinition *network, std::map& weightMap, ITensor& input, int highres_planes, std::string lname) { 206 | Weights emptywts{ DataType::kFLOAT, nullptr, 0 }; 207 | 208 | IConvolutionLayer* conv1 = network->addConvolution(input, highres_planes , DimsHW{ 1, 1 }, weightMap[lname + "0.weight"], emptywts); 209 | assert(conv1); 210 | conv1->setPadding(DimsHW{ 0, 0 }); 211 | 212 | IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + "1", 1e-5); 213 | 214 | return bn1; 215 | } 216 | 217 | ILayer* compression4(INetworkDefinition *network, std::map& weightMap, ITensor& input, int highres_planes, std::string lname) { 218 | Weights emptywts{ DataType::kFLOAT, nullptr, 0 }; 219 | 220 | IConvolutionLayer* conv1 = network->addConvolution(input, highres_planes , DimsHW{ 1, 1 }, weightMap[lname + "0.weight"], emptywts); 221 | assert(conv1); 222 | conv1->setPadding(DimsHW{ 0, 0 }); 223 | 224 | IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + "1", 1e-5); 225 | 226 | return bn1; 227 | } 228 | 229 | ILayer* down3(INetworkDefinition *network, std::map& weightMap, ITensor& input, int planes, std::string lname) { 230 | Weights emptywts{ DataType::kFLOAT, nullptr, 0 }; 231 | 232 | IConvolutionLayer* conv1 = network->addConvolution(input, planes , DimsHW{ 3, 3 }, weightMap[lname + "0.weight"], emptywts); 233 | assert(conv1); 234 | conv1->setStride(DimsHW{ 2, 2 }); 235 | conv1->setPadding(DimsHW{ 1, 1 }); 236 | 237 | IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + "1", 1e-5); 238 | 239 | return bn1; 240 | } 241 | 242 | ILayer* down4(INetworkDefinition *network, std::map& weightMap, ITensor& input, int planes, std::string lname) { 243 | Weights emptywts{ DataType::kFLOAT, nullptr, 0 }; 244 | 245 | IConvolutionLayer* conv1 = 
network->addConvolution(input, planes , DimsHW{ 3, 3 }, weightMap[lname + "0.weight"], emptywts); 246 | assert(conv1); 247 | conv1->setStride(DimsHW{ 2, 2 }); 248 | conv1->setPadding(DimsHW{ 1, 1 }); 249 | 250 | IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + "1", 1e-5); 251 | 252 | IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU); 253 | 254 | IConvolutionLayer* conv2 = network->addConvolution(*relu1->getOutput(0), planes*2 , DimsHW{ 3, 3 }, weightMap[lname + "3.weight"], emptywts); 255 | assert(conv2); 256 | conv2->setStride(DimsHW{ 2, 2 }); 257 | conv2->setPadding(DimsHW{ 1, 1 }); 258 | 259 | IScaleLayer* bn2 = addBatchNorm2d(network, weightMap, *conv2->getOutput(0), lname + "4", 1e-5); 260 | 261 | return bn2; 262 | } 263 | 264 | ILayer* DAPPM(INetworkDefinition *network, std::map& weightMap, ITensor& input, int inplanes, int branch_planes, int outplanes, std::string lname) { 265 | Weights emptywts{ DataType::kFLOAT, nullptr, 0 }; 266 | 267 | IScaleLayer* scale0bn = addBatchNorm2d(network, weightMap, input, lname + "scale0.0", 1e-5); 268 | 269 | IActivationLayer* scale0relu = network->addActivation(*scale0bn->getOutput(0), ActivationType::kRELU); 270 | 271 | IConvolutionLayer* scale0conv = network->addConvolution(*scale0relu->getOutput(0), branch_planes , DimsHW{ 1, 1 }, weightMap[lname + "scale0.2.weight"], emptywts); 272 | assert(scale0conv); 273 | scale0conv->setPadding(DimsHW{ 0, 0 }); 274 | 275 | // x_list[1] 276 | IPoolingLayer* scale1pool = network->addPooling(input, PoolingType::kAVERAGE, DimsHW{ 5, 5 }); 277 | assert(scale1pool); 278 | scale1pool->setStride(DimsHW{ 2, 2 }); 279 | scale1pool->setPadding(DimsHW{ 2, 2 }); 280 | 281 | IScaleLayer* scale1bn = addBatchNorm2d(network, weightMap, *scale1pool->getOutput(0), lname + "scale1.1", 1e-5); 282 | 283 | IActivationLayer* scale1relu = network->addActivation(*scale1bn->getOutput(0), ActivationType::kRELU); 284 | IConvolutionLayer* scale1conv = network->addConvolution(*scale1relu->getOutput(0), branch_planes , DimsHW{ 1, 1 }, weightMap[lname + "scale1.3.weight"], emptywts); 285 | assert(scale1conv); 286 | scale1conv->setPadding(DimsHW{ 0, 0 }); 287 | 288 | float *deval = reinterpret_cast(malloc(sizeof(float) * branch_planes * 2 * 2)); 289 | for (int i = 0; i < branch_planes * 2 * 2; i++) { 290 | deval[i] = 1.0; 291 | } 292 | Weights deconvwts1{ DataType::kFLOAT, deval, branch_planes * 2 * 2 }; 293 | IDeconvolutionLayer* scale1_interpolate = network->addDeconvolutionNd(*scale1conv->getOutput(0), branch_planes, DimsHW{ 2, 2 }, deconvwts1, emptywts); 294 | scale1_interpolate->setStrideNd(DimsHW{ 2, 2 }); 295 | scale1_interpolate->setNbGroups(branch_planes); 296 | 297 | IElementWiseLayer* process1_input = network->addElementWise(*scale1_interpolate->getOutput(0), *scale0conv->getOutput(0), ElementWiseOperation::kSUM); 298 | 299 | IScaleLayer* process1bn = addBatchNorm2d(network, weightMap, *process1_input->getOutput(0), lname + "process1.0", 1e-5); 300 | 301 | IActivationLayer* process1relu = network->addActivation(*process1bn->getOutput(0), ActivationType::kRELU); 302 | 303 | IConvolutionLayer* process1conv = network->addConvolution(*process1relu->getOutput(0), branch_planes , DimsHW{ 3, 3 }, weightMap[lname + "process1.2.weight"], emptywts); 304 | assert(process1conv); 305 | process1conv->setPadding(DimsHW{ 1, 1 }); 306 | 307 | // x_list[2] 308 | IPoolingLayer* scale2pool = network->addPooling(input, PoolingType::kAVERAGE, DimsHW{ 9, 9 }); 309 | 
assert(scale2pool); 310 | scale2pool->setStride(DimsHW{ 4, 4 }); 311 | scale2pool->setPadding(DimsHW{ 4, 4 }); 312 | 313 | 314 | IScaleLayer* scale2bn = addBatchNorm2d(network, weightMap, *scale2pool->getOutput(0), lname + "scale2.1", 1e-5); 315 | 316 | IActivationLayer* scale2relu = network->addActivation(*scale2bn->getOutput(0), ActivationType::kRELU); 317 | IConvolutionLayer* scale2conv = network->addConvolution(*scale2relu->getOutput(0), branch_planes , DimsHW{ 1, 1 }, weightMap[lname + "scale2.3.weight"], emptywts); 318 | assert(scale2conv); 319 | scale2conv->setPadding(DimsHW{ 0, 0 }); 320 | 321 | float *deval2 = reinterpret_cast(malloc(sizeof(float) * branch_planes * 4 * 4)); 322 | for (int i = 0; i < branch_planes * 4 * 4; i++) { 323 | deval2[i] = 1.0; 324 | } 325 | Weights deconvwts2{ DataType::kFLOAT, deval2, branch_planes * 4 * 4 }; 326 | IDeconvolutionLayer* scale2_interpolate = network->addDeconvolutionNd(*scale2conv->getOutput(0), branch_planes, DimsHW{ 4, 4 }, deconvwts2, emptywts); 327 | scale2_interpolate->setStrideNd(DimsHW{ 4, 4 }); 328 | scale2_interpolate->setNbGroups(branch_planes); 329 | 330 | IElementWiseLayer* process2_input = network->addElementWise(*scale2_interpolate->getOutput(0), *process1conv->getOutput(0), ElementWiseOperation::kSUM); 331 | // process2 332 | IScaleLayer* process2bn = addBatchNorm2d(network, weightMap, *process2_input->getOutput(0), lname + "process2.0", 1e-5); 333 | 334 | IActivationLayer* process2relu = network->addActivation(*process2bn->getOutput(0), ActivationType::kRELU); 335 | 336 | IConvolutionLayer* process2conv = network->addConvolution(*process2relu->getOutput(0), branch_planes , DimsHW{ 3, 3 }, weightMap[lname + "process2.2.weight"], emptywts); 337 | assert(process2conv); 338 | process2conv->setPadding(DimsHW{ 1, 1 }); 339 | 340 | // scale3 341 | IPoolingLayer* scale3pool = network->addPooling(input, PoolingType::kAVERAGE, DimsHW{ 17, 17 }); 342 | assert(scale3pool); 343 | scale3pool->setStride(DimsHW{ 8, 8 }); 344 | scale3pool->setPadding(DimsHW{ 8, 8 }); 345 | 346 | IScaleLayer* scale3bn = addBatchNorm2d(network, weightMap, *scale3pool->getOutput(0), lname + "scale3.1", 1e-5); 347 | 348 | IActivationLayer* scale3relu = network->addActivation(*scale3bn->getOutput(0), ActivationType::kRELU); 349 | IConvolutionLayer* scale3conv = network->addConvolution(*scale3relu->getOutput(0), branch_planes , DimsHW{ 1, 1 }, weightMap[lname + "scale3.3.weight"], emptywts); 350 | assert(scale3conv); 351 | scale3conv->setPadding(DimsHW{ 0, 0 }); 352 | float *deval3 = reinterpret_cast(malloc(sizeof(float) * branch_planes * 8 * 8)); 353 | for (int i = 0; i < branch_planes * 8 * 8; i++) { 354 | deval3[i] = 1.0; 355 | } 356 | Weights deconvwts3{ DataType::kFLOAT, deval3, branch_planes * 8 * 8 }; 357 | IDeconvolutionLayer* scale3_interpolate = network->addDeconvolutionNd(*scale3conv->getOutput(0), branch_planes, DimsHW{ 8, 8 }, deconvwts3, emptywts); 358 | scale3_interpolate->setStrideNd(DimsHW{ 8, 8 }); 359 | scale3_interpolate->setNbGroups(branch_planes); 360 | 361 | IElementWiseLayer* process3_input = network->addElementWise(*scale3_interpolate->getOutput(0), *process2conv->getOutput(0), ElementWiseOperation::kSUM); 362 | // process3 363 | IScaleLayer* process3bn = addBatchNorm2d(network, weightMap, *process3_input->getOutput(0), lname + "process3.0", 1e-5); 364 | 365 | IActivationLayer* process3relu = network->addActivation(*process3bn->getOutput(0), ActivationType::kRELU); 366 | 367 | IConvolutionLayer* process3conv = 
network->addConvolution(*process3relu->getOutput(0), branch_planes , DimsHW{ 3, 3 }, weightMap[lname + "process3.2.weight"], emptywts); 368 | assert(process3conv); 369 | process3conv->setPadding(DimsHW{ 1, 1 }); 370 | 371 | // scale4 372 | int input_w = input.getDimensions().d[3]; 373 | int input_h = input.getDimensions().d[2]; 374 | IPoolingLayer* scale4pool = network->addPooling(input, PoolingType::kAVERAGE, DimsHW{ input_h, input_w }); 375 | assert(scale4pool); 376 | scale4pool->setStride(DimsHW{ input_h, input_w }); 377 | scale4pool->setPadding(DimsHW{ 0, 0 }); 378 | 379 | IScaleLayer* scale4bn = addBatchNorm2d(network, weightMap, *scale4pool->getOutput(0), lname + "scale4.1", 1e-5); 380 | 381 | IActivationLayer* scale4relu = network->addActivation(*scale4bn->getOutput(0), ActivationType::kRELU); 382 | IConvolutionLayer* scale4conv = network->addConvolution(*scale4relu->getOutput(0), branch_planes , DimsHW{ 1, 1 }, weightMap[lname + "scale4.3.weight"], emptywts); 383 | assert(scale4conv); 384 | scale4conv->setPadding(DimsHW{ 0, 0 }); 385 | 386 | float *deval4 = reinterpret_cast(malloc(sizeof(float) * branch_planes * input_h * input_w)); 387 | for (int i = 0; i < branch_planes * input_h * input_w; i++) { 388 | deval4[i] = 1.0; 389 | } 390 | Weights deconvwts4{ DataType::kFLOAT, deval4, branch_planes * input_h * input_w }; 391 | IDeconvolutionLayer* scale4_interpolate = network->addDeconvolutionNd(*scale4conv->getOutput(0), branch_planes, DimsHW{ input_h, input_w }, deconvwts4, emptywts); 392 | scale4_interpolate->setStrideNd(DimsHW{ input_h, input_w }); 393 | scale4_interpolate->setNbGroups(branch_planes); 394 | 395 | IElementWiseLayer* process4_input = network->addElementWise(*scale4_interpolate->getOutput(0), *process3conv->getOutput(0), ElementWiseOperation::kSUM); 396 | // process4 397 | IScaleLayer* process4bn = addBatchNorm2d(network, weightMap, *process4_input->getOutput(0), lname + "process4.0", 1e-5); 398 | 399 | IActivationLayer* process4relu = network->addActivation(*process4bn->getOutput(0), ActivationType::kRELU); 400 | 401 | IConvolutionLayer* process4conv = network->addConvolution(*process4relu->getOutput(0), branch_planes , DimsHW{ 3, 3 }, weightMap[lname + "process4.2.weight"], emptywts); 402 | assert(process4conv); 403 | process4conv->setPadding(DimsHW{ 1, 1 }); 404 | 405 | // compression 406 | // out = self.compression(torch.cat(x_list, 1)) + self.shortcut(x) 407 | ITensor* inputTensors[] = {scale0conv->getOutput(0), process1conv->getOutput(0) , process2conv->getOutput(0), process3conv->getOutput(0), process4conv->getOutput(0)}; 408 | IConcatenationLayer* neck_cat = network->addConcatenation(inputTensors, 5); 409 | 410 | IScaleLayer* compressionbn = addBatchNorm2d(network, weightMap, *neck_cat->getOutput(0), lname + "compression.0", 1e-5); 411 | 412 | IActivationLayer* compressionrelu = network->addActivation(*compressionbn->getOutput(0), ActivationType::kRELU); 413 | 414 | IConvolutionLayer* compressionconv = network->addConvolution(*compressionrelu->getOutput(0), outplanes , DimsHW{ 1, 1 }, weightMap[lname + "compression.2.weight"], emptywts); 415 | assert(compressionconv); 416 | compressionconv->setPadding(DimsHW{ 0, 0 }); 417 | 418 | // shortcut 419 | IScaleLayer* shortcutbn = addBatchNorm2d(network, weightMap, input, lname + "shortcut.0", 1e-5); 420 | 421 | IActivationLayer* shortcutrelu = network->addActivation(*shortcutbn->getOutput(0), ActivationType::kRELU); 422 | 423 | IConvolutionLayer* shortcutconv = network->addConvolution(*shortcutrelu->getOutput(0), 
outplanes , DimsHW{ 1, 1 }, weightMap[lname + "shortcut.2.weight"], emptywts); 424 | assert(shortcutconv); 425 | shortcutconv->setPadding(DimsHW{ 0, 0 }); 426 | 427 | IElementWiseLayer* out = network->addElementWise(*compressionconv->getOutput(0), *shortcutconv->getOutput(0), ElementWiseOperation::kSUM); 428 | return out; 429 | } 430 | 431 | ILayer* segmenthead(INetworkDefinition *network, std::map& weightMap, ITensor& input, int interplanes, int outplanes, std::string lname) { 432 | Weights emptywts{ DataType::kFLOAT, nullptr, 0 }; 433 | 434 | IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, input, lname + "bn1", 1e-5); 435 | 436 | IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU); 437 | 438 | 439 | IConvolutionLayer* conv1 = network->addConvolution(*relu1->getOutput(0), interplanes , DimsHW{ 3, 3 }, weightMap[lname + "conv1.weight"], emptywts); 440 | assert(conv1); 441 | conv1->setPadding(DimsHW{ 1, 1 }); 442 | 443 | IScaleLayer* bn2 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + "bn2", 1e-5); 444 | 445 | IActivationLayer* relu2 = network->addActivation(*bn2->getOutput(0), ActivationType::kRELU); 446 | IConvolutionLayer* conv2 = network->addConvolution(*relu2->getOutput(0), outplanes , DimsHW{ 1, 1 }, weightMap[lname + "conv2.weight"], weightMap[lname + "conv2.bias"]); 447 | assert(conv2); 448 | conv2->setPadding(DimsHW{ 0, 0 }); 449 | 450 | return conv2; 451 | } 452 | 453 | int read_files_in_dir(const char *p_dir_name, std::vector &file_names) { 454 | DIR *p_dir = opendir(p_dir_name); 455 | if (p_dir == nullptr) { 456 | return -1; 457 | } 458 | 459 | struct dirent* p_file = nullptr; 460 | while ((p_file = readdir(p_dir)) != nullptr) { 461 | if (strcmp(p_file->d_name, ".") != 0 && 462 | strcmp(p_file->d_name, "..") != 0) { 463 | //std::string cur_file_name(p_dir_name); 464 | //cur_file_name += "/"; 465 | //cur_file_name += p_file->d_name; 466 | std::string cur_file_name(p_file->d_name); 467 | file_names.push_back(cur_file_name); 468 | } 469 | } 470 | closedir(p_dir); 471 | return 0; 472 | } 473 | 474 | static const int map_[19][3] = { {255,0,0} , 475 | {128,0,0}, 476 | {0,128,0}, 477 | {0,0,128}, 478 | {128,128,0}, 479 | {128,0,128}, 480 | {0,128,128}, 481 | {0,255,0}, 482 | {0,0,255}, 483 | {255,0,0}, 484 | {255,255,0}, 485 | {0,255,255}, 486 | {255,0,255}, 487 | {255,0,128}, 488 | {128,255,0}, 489 | {128,0,255}, 490 | {0,255,128}, 491 | {0,255,255}, 492 | {255,0,255},}; 493 | 494 | cv::Mat map2cityscape(cv::Mat real_out,cv::Mat real_out_) 495 | { 496 | for (int i = 0; i < 128; ++i) 497 | { 498 | cv::Vec *p1 = real_out.ptr>(i); 499 | cv::Vec3b *p2 = real_out_.ptr(i); 500 | for (int j = 0; j < 128; ++j) 501 | { 502 | int index = 0; 503 | float swap; 504 | for (int c = 0; c < 19; ++c) 505 | { 506 | if (p1[j][0] < p1[j][c]) 507 | { 508 | swap = p1[j][0]; 509 | p1[j][0] = p1[j][c]; 510 | p1[j][c] = swap; 511 | index = c; 512 | } 513 | } 514 | p2[j][0] = map_[index][2]; 515 | p2[j][1] = map_[index][1]; 516 | p2[j][2] = map_[index][0]; 517 | 518 | } 519 | } 520 | return real_out_; 521 | } 522 | 523 | cv::Mat read2mat(float * prob, cv::Mat out) 524 | { 525 | for (int i = 0; i < 128; ++i) 526 | { 527 | cv::Vec *p1 = out.ptr>(i); 528 | for (int j = 0; j < 128; ++j) 529 | { 530 | for (int c = 0; c < 19; ++c) 531 | { 532 | p1[j][c] = prob[c * 128 * 128 + i * 128 + j]; 533 | } 534 | } 535 | } 536 | return out; 537 | } 538 | 539 | 540 | #endif 541 | 542 | 
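A quick aside on addBatchNorm2d above: at inference time BatchNorm is just an affine map, so it is folded into a single IScaleLayer with scale = gamma / sqrt(var + eps), shift = beta - mean * gamma / sqrt(var + eps), and power = 1 (the layer computes (scale * x + shift)^power). A tiny standalone check of that identity, with made-up values and no TensorRT dependency:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // One channel's BatchNorm parameters (illustrative values only).
    float gamma = 1.3f, beta = -0.2f, mean = 0.5f, var = 2.0f, eps = 1e-5f;
    float x = 0.75f;

    // Reference BatchNorm: y = gamma * (x - mean) / sqrt(var + eps) + beta
    float bn = gamma * (x - mean) / std::sqrt(var + eps) + beta;

    // Folded form used by addBatchNorm2d: y = scale * x + shift, power = 1
    float scale = gamma / std::sqrt(var + eps);
    float shift = beta - mean * gamma / std::sqrt(var + eps);
    float folded = scale * x + shift;

    std::printf("bn = %f, folded = %f\n", bn, folded); // both print the same value
    return 0;
}
```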
-------------------------------------------------------------------------------- /ddrnet.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "cuda_runtime_api.h" 4 | #include "logging.h" 5 | #include "common.hpp" 6 | #include 7 | #include "calibrator.h" 8 | 9 | 10 | //#define USE_INT8 // comment out this if want to use INT8 11 | #define USE_FP16 // comment out this if want to use FP32 12 | #define DEVICE 0 // GPU id 13 | static const int INPUT_H = 1024; 14 | static const int INPUT_W = 1024; 15 | static const int OUT_MAP_H = 128; 16 | static const int OUT_MAP_W = 128; 17 | const char* INPUT_BLOB_NAME = "input_0"; 18 | const char* OUTPUT_BLOB_NAME = "output_0"; 19 | static Logger gLogger; 20 | 21 | // Creat the engine using only the API and not any parser. 22 | ICudaEngine* createEngine(unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt) { 23 | INetworkDefinition* network = builder->createNetworkV2(0U); 24 | // Create input tensor of shape {1, 3, INPUT_H, INPUT_W} with name INPUT_BLOB_NAME 25 | ITensor* data = network->addInput(INPUT_BLOB_NAME, dt, Dims4{ 1, 3, INPUT_H, INPUT_W }); 26 | assert(data); 27 | 28 | std::map weightMap = loadWeights("../DDRNet_CS.wts"); 29 | Weights emptywts{ DataType::kFLOAT, nullptr, 0 }; 30 | 31 | IConvolutionLayer* conv1 = network->addConvolution(*data, 32, DimsHW{ 3, 3 }, weightMap["conv1.0.weight"], weightMap["conv1.0.bias"]); 32 | assert(conv1); 33 | conv1->setStride(DimsHW{ 2, 2 }); 34 | conv1->setPadding(DimsHW{ 1, 1 }); 35 | 36 | IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), "conv1.1", 1e-5); 37 | 38 | IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU); 39 | assert(relu1); 40 | 41 | IConvolutionLayer* conv2 = network->addConvolution(*relu1->getOutput(0), 32, DimsHW{ 3, 3 }, weightMap["conv1.3.weight"], weightMap["conv1.3.bias"]); 42 | assert(conv2); 43 | conv2->setStride(DimsHW{ 2, 2 }); 44 | conv2->setPadding(DimsHW{ 1, 1 }); 45 | 46 | IScaleLayer* bn2 = addBatchNorm2d(network, weightMap, *conv2->getOutput(0), "conv1.4", 1e-5); 47 | 48 | IActivationLayer* relu2 = network->addActivation(*bn2->getOutput(0), ActivationType::kRELU); 49 | assert(relu2); 50 | 51 | // layer1 52 | ILayer* layer1_0 = basicBlock(network, weightMap, *relu2->getOutput(0), 32, 32, 1, false, false, "layer1.0."); 53 | ILayer* layer1_1 = basicBlock(network, weightMap, *layer1_0->getOutput(0), 32, 32, 1, false, true, "layer1.1."); 54 | IActivationLayer* layer1_relu = network->addActivation(*layer1_1->getOutput(0), ActivationType::kRELU); 55 | assert(layer1_relu); 56 | 57 | // layer2 58 | ILayer* layer2_0 = basicBlock(network, weightMap, *layer1_relu->getOutput(0), 32, 64, 2, true, false, "layer2.0."); 59 | ILayer* layer2_1 = basicBlock(network, weightMap, *layer2_0->getOutput(0), 64, 64, 1, false, true, "layer2.1."); // 1/8 60 | IActivationLayer* layer2_relu = network->addActivation(*layer2_1->getOutput(0), ActivationType::kRELU); 61 | assert(layer2_relu); 62 | 63 | // layer3 64 | ILayer* layer3_0 = basicBlock(network, weightMap, *layer2_relu->getOutput(0), 64, 128, 2, true, false, "layer3.0."); 65 | ILayer* layer3_1 = basicBlock(network, weightMap, *layer3_0->getOutput(0), 128, 128, 1, false, true, "layer3.1."); // 1/16 66 | IActivationLayer* layer3_relu = network->addActivation(*layer3_1->getOutput(0), ActivationType::kRELU); 67 | assert(layer3_relu); // layer[2] 68 | 69 | // x_ = 
self.layer3_(self.relu(layers[1])) 70 | // layer3_ 71 | ILayer* layer3_10 = basicBlock(network, weightMap, *layer2_relu->getOutput(0), 64, 64, 1, false, false, "layer3_.0."); 72 | ILayer* layer3_11 = basicBlock(network, weightMap, *layer3_10->getOutput(0), 64, 64, 1, false, true, "layer3_.1."); // x_ = self.layer3_(self.relu(layers[1])) 73 | // 74 | 75 | // down3 76 | IActivationLayer* down3_input_relu = network->addActivation(*layer3_11->getOutput(0), ActivationType::kRELU); 77 | assert(down3_input_relu); 78 | 79 | ILayer* down3_out = down3(network, weightMap, *down3_input_relu->getOutput(0), 128, "down3."); 80 | // x = x + self.down3(self.relu(x_)) 81 | 82 | IElementWiseLayer* down3_add = network->addElementWise(*layer3_1->getOutput(0), *down3_out->getOutput(0), ElementWiseOperation::kSUM); 83 | 84 | //x_ = x_ + F.interpolate(self.compression3(self.relu(layers[2])), size=[height_output, width_output], mode='bilinear',align_corners=True) 85 | ILayer* compression3_input = compression3(network, weightMap, *layer3_relu->getOutput(0), 64, "compression3."); 86 | 87 | float *deval = reinterpret_cast(malloc(sizeof(float) * 64 * 2 * 2)); 88 | for (int i = 0; i < 64 * 2 * 2; i++) { 89 | deval[i] = 1.0; 90 | } 91 | Weights deconvwts1{ DataType::kFLOAT, deval, 64 * 2 * 2 }; 92 | IDeconvolutionLayer* compression3_up = network->addDeconvolutionNd(*compression3_input->getOutput(0), 64, DimsHW{ 2, 2 }, deconvwts1, emptywts); 93 | compression3_up->setStrideNd(DimsHW{ 2, 2 }); 94 | compression3_up->setNbGroups(64); 95 | IElementWiseLayer* compression3_add = network->addElementWise(*layer3_11->getOutput(0), *compression3_up->getOutput(0), ElementWiseOperation::kSUM); 96 | // x_ = self.layer4_(self.relu(x_)) 97 | // layer4 98 | IActivationLayer* layer4_input = network->addActivation(*down3_add->getOutput(0), ActivationType::kRELU); 99 | ILayer* layer4_0 = basicBlock(network, weightMap, *layer4_input->getOutput(0), 128, 256, 2, true, false, "layer4.0."); 100 | // x = self.layer4(self.relu(x)) 101 | ILayer* layer4_1 = basicBlock(network, weightMap, *layer4_0->getOutput(0), 256, 256, 1, false, true, "layer4.1."); // 1/32 102 | IActivationLayer* layer4_relu = network->addActivation(*layer4_1->getOutput(0), ActivationType::kRELU); 103 | assert(layer4_relu); 104 | 105 | // layer4_ 106 | IActivationLayer* layer4_1_input = network->addActivation(*compression3_add->getOutput(0), ActivationType::kRELU); 107 | ILayer* layer4_10 = basicBlock(network, weightMap, *layer4_1_input->getOutput(0), 64, 64, 1, false, false, "layer4_.0."); 108 | // x_ = self.layer4_(self.relu(x_)) 109 | ILayer* layer4_11 = basicBlock(network, weightMap, *layer4_10->getOutput(0), 64, 64, 1, false, true, "layer4_.1."); // 1/8 110 | // down4 111 | IActivationLayer* down4_input_relu = network->addActivation(*layer4_11->getOutput(0), ActivationType::kRELU); 112 | assert(down4_input_relu); 113 | ILayer* down4_out = down4(network, weightMap, *down4_input_relu->getOutput(0), 128, "down4."); 114 | 115 | IElementWiseLayer* down4_add = network->addElementWise(*layer4_1->getOutput(0), *down4_out->getOutput(0), ElementWiseOperation::kSUM); 116 | // x_ = x_ + F.interpolate(self.compression4(self.relu(layers[3])),size=[height_output, width_output],mode='bilinear',align_corners=True) 117 | ILayer* compression4_input = compression4(network, weightMap, *layer4_relu->getOutput(0), 64, "compression4."); 118 | 119 | float *deval2 = reinterpret_cast(malloc(sizeof(float) * 64 * 4 * 4)); 120 | for (int i = 0; i < 64 * 4 * 4; i++) { 121 | deval2[i] = 1.0; 122 | } 123 | 
Weights deconvwts2{ DataType::kFLOAT, deval2, 64 * 4 * 4 }; 124 | IDeconvolutionLayer* compression4_up = network->addDeconvolutionNd(*compression4_input->getOutput(0), 64, DimsHW{ 4, 4 }, deconvwts2, emptywts); 125 | compression4_up->setStrideNd(DimsHW{ 4, 4 }); 126 | compression4_up->setNbGroups(64); 127 | 128 | IElementWiseLayer* compression4_add = network->addElementWise(*layer4_11->getOutput(0), *compression4_up->getOutput(0), ElementWiseOperation::kSUM); 129 | IActivationLayer* compression4_add_relu = network->addActivation(*compression4_add->getOutput(0), ActivationType::kRELU); 130 | assert(compression4_add_relu); 131 | // layer5_ 132 | // x_ = self.layer5_(self.relu(x_)) 133 | ILayer* layer5_ = Bottleneck(network, weightMap, *compression4_add_relu->getOutput(0), 64, 64, 1, true, true, "layer5_.0."); 134 | 135 | // layer5 136 | IActivationLayer* layer5_input = network->addActivation(*down4_add->getOutput(0), ActivationType::kRELU); 137 | assert(layer5_input); 138 | ILayer* layer5 = Bottleneck(network, weightMap, *layer5_input->getOutput(0), 256, 256, 2, true, true, "layer5.0."); 139 | ILayer* ssp = DAPPM(network, weightMap, *layer5->getOutput(0), 512, 128, 128, "spp."); 140 | 141 | float *deval3 = reinterpret_cast(malloc(sizeof(float) * 128 * 8 * 8)); 142 | for (int i = 0; i < 128 * 8 * 8; i++) { 143 | deval3[i] = 1.0; 144 | } 145 | Weights deconvwts3{ DataType::kFLOAT, deval3, 128 * 8 * 8 }; 146 | IDeconvolutionLayer* spp_up = network->addDeconvolutionNd(*ssp->getOutput(0), 128, DimsHW{ 8, 8 }, deconvwts3, emptywts); 147 | spp_up->setStrideNd(DimsHW{ 8, 8 }); 148 | spp_up->setNbGroups(128); 149 | // x_ = self.final_layer(x + x_) 150 | 151 | IElementWiseLayer* final_in = network->addElementWise(*spp_up->getOutput(0), *layer5_->getOutput(0), ElementWiseOperation::kSUM); 152 | 153 | ILayer* seg_out= segmenthead(network, weightMap, *final_in->getOutput(0), 64, 19, "final_layer."); 154 | 155 | // IActivationLayer* thresh = network->addActivation(*seg_out->getOutput(0), ActivationType::kSIGMOID); 156 | // assert(thresh); 157 | 158 | // y = F.interpolate(y, size=(H, W)) 159 | seg_out->getOutput(0)->setName(OUTPUT_BLOB_NAME); 160 | network->markOutput(*seg_out->getOutput(0)); 161 | 162 | // IOptimizationProfile* profile = builder->createOptimizationProfile(); 163 | // profile->setDimensions(INPUT_BLOB_NAME, OptProfileSelector::kMIN, Dims4(1, 3, MIN_INPUT_SIZE, MIN_INPUT_SIZE)); 164 | // profile->setDimensions(INPUT_BLOB_NAME, OptProfileSelector::kOPT, Dims4(1, 3, OPT_INPUT_H, OPT_INPUT_W)); 165 | // profile->setDimensions(INPUT_BLOB_NAME, OptProfileSelector::kMAX, Dims4(1, 3, MAX_INPUT_SIZE, MAX_INPUT_SIZE)); 166 | // config->addOptimizationProfile(profile); 167 | 168 | // Build engine 169 | builder->setMaxBatchSize(maxBatchSize); 170 | config->setMaxWorkspaceSize(16 * (1 << 20)); // 16MB 171 | 172 | #if defined(USE_INT8) 173 | std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl; 174 | assert(builder->platformHasFastInt8()); 175 | config->setFlag(BuilderFlag::kINT8); 176 | Int8EntropyCalibrator2 *calibrator = new Int8EntropyCalibrator2(1, INPUT_W, INPUT_H, "../calib/", "int8calib.table", INPUT_BLOB_NAME); 177 | config->setInt8Calibrator(calibrator); 178 | #elif defined(USE_FP16) 179 | config->setFlag(BuilderFlag::kFP16); 180 | #endif 181 | std::cout << "Building engine, please wait for a while..." 
<< std::endl; 182 | ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config); 183 | std::cout << "Build engine successfully!" << std::endl; 184 | 185 | // Don't need the network any more 186 | network->destroy(); 187 | 188 | // Release host memory 189 | for (auto& mem : weightMap) { 190 | free((void*)(mem.second.values)); 191 | } 192 | 193 | return engine; 194 | } 195 | 196 | void APIToModel(unsigned int maxBatchSize, IHostMemory** modelStream) { 197 | // Create builder 198 | IBuilder* builder = createInferBuilder(gLogger); 199 | IBuilderConfig* config = builder->createBuilderConfig(); 200 | 201 | // Create model to populate the network, then set the outputs and create an engine 202 | ICudaEngine* engine = createEngine(maxBatchSize, builder, config, DataType::kFLOAT); 203 | //ICudaEngine* engine = createEngine(maxBatchSize, builder, config, DataType::kFLOAT); 204 | assert(engine != nullptr); 205 | 206 | // Serialize the engine 207 | (*modelStream) = engine->serialize(); 208 | 209 | // Close everything down 210 | engine->destroy(); 211 | builder->destroy(); 212 | } 213 | 214 | void doInference(IExecutionContext& context, float* input, float* output) { 215 | const ICudaEngine& engine = context.getEngine(); 216 | 217 | // Pointers to input and output device buffers to pass to engine. 218 | // Engine requires exactly IEngine::getNbBindings() number of buffers. 219 | // std::cout<<"engine.getNbBindings():"<(modelStream->data()), modelStream->size()); 266 | modelStream->destroy(); 267 | return 0; 268 | } 269 | else if (argc == 3 && std::string(argv[1]) == "-d") { 270 | std::ifstream file("DDRNet.engine", std::ios::binary); 271 | if (file.good()) { 272 | file.seekg(0, file.end); 273 | size = file.tellg(); 274 | file.seekg(0, file.beg); 275 | trtModelStream = new char[size]; 276 | assert(trtModelStream); 277 | file.read(trtModelStream, size); 278 | file.close(); 279 | } 280 | } 281 | else { 282 | std::cerr << "arguments not right!" << std::endl; 283 | std::cerr << "./debnet -s // serialize model to plan file" << std::endl; 284 | std::cerr << "./debnet -d ../samples // deserialize plan file and run inference" << std::endl; 285 | return -1; 286 | } 287 | 288 | // prepare input data --------------------------- 289 | IRuntime* runtime = createInferRuntime(gLogger); 290 | assert(runtime != nullptr); 291 | ICudaEngine* engine = runtime->deserializeCudaEngine(trtModelStream, size); 292 | assert(engine != nullptr); 293 | IExecutionContext* context = engine->createExecutionContext(); 294 | assert(context != nullptr); 295 | delete[] trtModelStream; 296 | 297 | std::vector file_names; 298 | if (read_files_in_dir(argv[2], file_names) < 0) { 299 | std::cout << "read_files_in_dir failed." 
<< std::endl; 300 | return -1; 301 | } 302 | 303 | std::vector mean_value{ 0.406, 0.456, 0.485 }; // BGR 304 | std::vector std_value{ 0.225, 0.224, 0.229 }; 305 | int fcount = 0; 306 | for (auto f : file_names) { 307 | fcount++; 308 | std::cout << fcount << " " << f << std::endl; 309 | cv::Mat pr_img = cv::imread(std::string(argv[2]) + "/" + f); 310 | cv::resize(pr_img,pr_img,cv::Size(INPUT_W,INPUT_H)); 311 | if (pr_img.empty()) continue; 312 | float* data = new float[3 * pr_img.rows * pr_img.cols]; 313 | int i = 0; 314 | for (int row = 0; row < pr_img.rows; ++row) { 315 | uchar* uc_pixel = pr_img.data + row * pr_img.step; 316 | for (int col = 0; col < pr_img.cols; ++col) { 317 | data[i] = (uc_pixel[2] / 255.0 - mean_value[2]) / std_value[2]; 318 | data[i + pr_img.rows * pr_img.cols] = (uc_pixel[1] / 255.0 - mean_value[1]) / std_value[1]; 319 | data[i + 2 * pr_img.rows * pr_img.cols] = (uc_pixel[0] / 255.0 - mean_value[0]) / std_value[0]; 320 | uc_pixel += 3; 321 | ++i; 322 | } 323 | } 324 | float* prob = new float[ 19* OUT_MAP_H* OUT_MAP_W]; 325 | // Run inference 326 | auto start = std::chrono::system_clock::now(); 327 | doInference(*context, data, prob); 328 | auto end = std::chrono::system_clock::now(); 329 | std::cout << std::chrono::duration_cast(end - start).count() << "ms" << std::endl; 330 | 331 | // show mask 332 | cv::Mat out; 333 | out.create(OUT_MAP_H, OUT_MAP_W, CV_32FC(19)); 334 | out = read2mat(prob, out); 335 | // cv::resize(out, real_out, real_out.size()); 336 | cv::Mat mask; 337 | mask.create(OUT_MAP_H, OUT_MAP_W, CV_8UC3); 338 | mask = map2cityscape(out, mask); 339 | cv::resize(mask,mask,cv::Size(INPUT_W,INPUT_H)); 340 | cv::Mat result; 341 | cv::addWeighted(pr_img,0.7,mask,0.3,1,result); 342 | cv::resize(result,result,cv::Size(1024,512)); 343 | cv::imwrite("result_" + f, result); 344 | delete prob; 345 | delete data; 346 | } 347 | return 0; 348 | } 349 | -------------------------------------------------------------------------------- /getwts.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | import argparse 8 | import os 9 | import pprint 10 | import shutil 11 | import sys 12 | 13 | import logging 14 | import time 15 | import timeit 16 | from pathlib import Path 17 | import time 18 | import numpy as np 19 | import torch 20 | import torch.nn as nn 21 | import torch.backends.cudnn as cudnn 22 | import struct 23 | import _init_paths 24 | import models 25 | import cv2 26 | import torch.nn.functional as F 27 | import datasets 28 | from config import config 29 | from config import update_config 30 | from core.function import testval, test 31 | from utils.modelsummary import get_model_summary 32 | from utils.utils import create_logger, FullModel, speed_test 33 | 34 | def parse_args(): 35 | parser = argparse.ArgumentParser(description='Train segmentation network') 36 | 37 | parser.add_argument('--cfg', 38 | help='experiment configure file name', 39 | default="experiments/cityscapes/ddrnet23_slim.yaml", 40 | type=str) 41 | parser.add_argument('opts', 42 | help="Modify config options using the command-line", 43 | default=None, 44 | nargs=argparse.REMAINDER) 45 | 46 | args = parser.parse_args() 47 | update_config(config, args) 48 | 49 | return args 50 | 51 | def main(): 52 | mean=[0.485, 0.456, 0.406], 53 | std=[0.229, 0.224, 0.225] 54 | args = parse_args() 55 | 56 | logger, final_output_dir, _ = create_logger( 57 | config, args.cfg, 'test') 58 | 59 | logger.info(pprint.pformat(args)) 60 | logger.info(pprint.pformat(config)) 61 | 62 | # cudnn related setting 63 | cudnn.benchmark = config.CUDNN.BENCHMARK 64 | cudnn.deterministic = config.CUDNN.DETERMINISTIC 65 | cudnn.enabled = config.CUDNN.ENABLED 66 | 67 | # build model 68 | if torch.__version__.startswith('1'): 69 | module = eval('models.'+config.MODEL.NAME) 70 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d 71 | model = eval('models.'+config.MODEL.NAME + 72 | '.get_seg_model')(config) 73 | 74 | dump_input = torch.rand( 75 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]) 76 | ) 77 | logger.info(get_model_summary(model.cuda(), dump_input.cuda())) 78 | 79 | if config.TEST.MODEL_FILE: 80 | model_state_file = config.TEST.MODEL_FILE 81 | else: 82 | # model_state_file = os.path.join(final_output_dir, 'best_0.7589.pth') 83 | model_state_file = os.path.join(final_output_dir, 'best.pth') 84 | logger.info('=> loading model from {}'.format(model_state_file)) 85 | 86 | pretrained_dict = torch.load('model_best_bacc.pth.tar') 87 | if 'state_dict' in pretrained_dict: 88 | pretrained_dict = pretrained_dict['state_dict'] 89 | newstate_dict = {k:v for k,v in pretrained_dict.items() if k in model.state_dict()} 90 | # print(pretrained_dict.keys()) 91 | 92 | model.load_state_dict(newstate_dict) 93 | model = model.cuda() 94 | 95 | if True: 96 | save_wts = True 97 | print(model) 98 | if save_wts: 99 | f = open('DDRNet_CS.wts', 'w') 100 | f.write('{}\n'.format(len(model.state_dict().keys()))) 101 | for k, v in model.state_dict().items(): 102 | print("Layer {} ; Size {}".format(k,v.cpu().numpy().shape)) 103 | vr = v.reshape(-1).cpu().numpy() 104 | f.write('{} {} '.format(k, len(vr))) 105 | for vv in vr: 106 | f.write(' ') 107 | f.write(struct.pack('>f', float(vv)).hex()) 108 | f.write('\n') 109 | 110 | 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /images/mainz_000001_009328_leftImg8bit.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/midasklr/DDRNet.TensorRT/f3dbdb77d9a5fb412dd42dcc0acd56a4c75dd0d4/images/mainz_000001_009328_leftImg8bit.png -------------------------------------------------------------------------------- /logging.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #ifndef TENSORRT_LOGGING_H 18 | #define TENSORRT_LOGGING_H 19 | 20 | #include "NvInferRuntimeCommon.h" 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | 29 | using Severity = nvinfer1::ILogger::Severity; 30 | 31 | class LogStreamConsumerBuffer : public std::stringbuf 32 | { 33 | public: 34 | LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) 35 | : mOutput(stream) 36 | , mPrefix(prefix) 37 | , mShouldLog(shouldLog) 38 | { 39 | } 40 | 41 | LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) 42 | : mOutput(other.mOutput) 43 | { 44 | } 45 | 46 | ~LogStreamConsumerBuffer() 47 | { 48 | // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence 49 | // std::streambuf::pptr() gives a pointer to the current position of the output sequence 50 | // if the pointer to the beginning is not equal to the pointer to the current position, 51 | // call putOutput() to log the output to the stream 52 | if (pbase() != pptr()) 53 | { 54 | putOutput(); 55 | } 56 | } 57 | 58 | // synchronizes the stream buffer and returns 0 on success 59 | // synchronizing the stream buffer consists of inserting the buffer contents into the stream, 60 | // resetting the buffer and flushing the stream 61 | virtual int sync() 62 | { 63 | putOutput(); 64 | return 0; 65 | } 66 | 67 | void putOutput() 68 | { 69 | if (mShouldLog) 70 | { 71 | // prepend timestamp 72 | std::time_t timestamp = std::time(nullptr); 73 | tm* tm_local = std::localtime(×tamp); 74 | std::cout << "["; 75 | std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/"; 76 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/"; 77 | std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-"; 78 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":"; 79 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":"; 80 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] "; 81 | // std::stringbuf::str() gets the string contents of the buffer 82 | // insert the buffer contents pre-appended by the appropriate prefix into the stream 83 | mOutput << mPrefix << str(); 84 | // set the buffer to empty 85 | str(""); 86 | // flush the stream 87 | mOutput.flush(); 88 | } 89 | } 90 | 91 | void setShouldLog(bool shouldLog) 92 | { 93 | mShouldLog = shouldLog; 
94 | } 95 | 96 | private: 97 | std::ostream& mOutput; 98 | std::string mPrefix; 99 | bool mShouldLog; 100 | }; 101 | 102 | //! 103 | //! \class LogStreamConsumerBase 104 | //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer 105 | //! 106 | class LogStreamConsumerBase 107 | { 108 | public: 109 | LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) 110 | : mBuffer(stream, prefix, shouldLog) 111 | { 112 | } 113 | 114 | protected: 115 | LogStreamConsumerBuffer mBuffer; 116 | }; 117 | 118 | //! 119 | //! \class LogStreamConsumer 120 | //! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages. 121 | //! Order of base classes is LogStreamConsumerBase and then std::ostream. 122 | //! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field 123 | //! in LogStreamConsumer and then the address of the buffer is passed to std::ostream. 124 | //! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. 125 | //! Please do not change the order of the parent classes. 126 | //! 127 | class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream 128 | { 129 | public: 130 | //! \brief Creates a LogStreamConsumer which logs messages with level severity. 131 | //! Reportable severity determines if the messages are severe enough to be logged. 132 | LogStreamConsumer(Severity reportableSeverity, Severity severity) 133 | : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) 134 | , std::ostream(&mBuffer) // links the stream buffer with the stream 135 | , mShouldLog(severity <= reportableSeverity) 136 | , mSeverity(severity) 137 | { 138 | } 139 | 140 | LogStreamConsumer(LogStreamConsumer&& other) 141 | : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) 142 | , std::ostream(&mBuffer) // links the stream buffer with the stream 143 | , mShouldLog(other.mShouldLog) 144 | , mSeverity(other.mSeverity) 145 | { 146 | } 147 | 148 | void setReportableSeverity(Severity reportableSeverity) 149 | { 150 | mShouldLog = mSeverity <= reportableSeverity; 151 | mBuffer.setShouldLog(mShouldLog); 152 | } 153 | 154 | private: 155 | static std::ostream& severityOstream(Severity severity) 156 | { 157 | return severity >= Severity::kINFO ? std::cout : std::cerr; 158 | } 159 | 160 | static std::string severityPrefix(Severity severity) 161 | { 162 | switch (severity) 163 | { 164 | case Severity::kINTERNAL_ERROR: return "[F] "; 165 | case Severity::kERROR: return "[E] "; 166 | case Severity::kWARNING: return "[W] "; 167 | case Severity::kINFO: return "[I] "; 168 | case Severity::kVERBOSE: return "[V] "; 169 | default: assert(0); return ""; 170 | } 171 | } 172 | 173 | bool mShouldLog; 174 | Severity mSeverity; 175 | }; 176 | 177 | //! \class Logger 178 | //! 179 | //! \brief Class which manages logging of TensorRT tools and samples 180 | //! 181 | //! \details This class provides a common interface for TensorRT tools and samples to log information to the console, 182 | //! and supports logging two types of messages: 183 | //! 184 | //! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) 185 | //! - Test pass/fail messages 186 | //! 187 | //! 
The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is 188 | //! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. 189 | //! 190 | //! In the future, this class could be extended to support dumping test results to a file in some standard format 191 | //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). 192 | //! 193 | //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger 194 | //! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT 195 | //! library and messages coming from the sample. 196 | //! 197 | //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the 198 | //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger 199 | //! object. 200 | 201 | class Logger : public nvinfer1::ILogger 202 | { 203 | public: 204 | Logger(Severity severity = Severity::kWARNING) 205 | : mReportableSeverity(severity) 206 | { 207 | } 208 | 209 | //! 210 | //! \enum TestResult 211 | //! \brief Represents the state of a given test 212 | //! 213 | enum class TestResult 214 | { 215 | kRUNNING, //!< The test is running 216 | kPASSED, //!< The test passed 217 | kFAILED, //!< The test failed 218 | kWAIVED //!< The test was waived 219 | }; 220 | 221 | //! 222 | //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger 223 | //! \return The nvinfer1::ILogger associated with this Logger 224 | //! 225 | //! TODO Once all samples are updated to use this method to register the logger with TensorRT, 226 | //! we can eliminate the inheritance of Logger from ILogger 227 | //! 228 | nvinfer1::ILogger& getTRTLogger() 229 | { 230 | return *this; 231 | } 232 | 233 | //! 234 | //! \brief Implementation of the nvinfer1::ILogger::log() virtual method 235 | //! 236 | //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the 237 | //! inheritance from nvinfer1::ILogger 238 | //! 239 | void log(Severity severity, const char* msg) override 240 | { 241 | LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; 242 | } 243 | 244 | //! 245 | //! \brief Method for controlling the verbosity of logging output 246 | //! 247 | //! \param severity The logger will only emit messages that have severity of this level or higher. 248 | //! 249 | void setReportableSeverity(Severity severity) 250 | { 251 | mReportableSeverity = severity; 252 | } 253 | 254 | //! 255 | //! \brief Opaque handle that holds logging information for a particular test 256 | //! 257 | //! This object is an opaque handle to information used by the Logger to print test results. 258 | //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used 259 | //! with Logger::reportTest{Start,End}(). 260 | //! 261 | class TestAtom 262 | { 263 | public: 264 | TestAtom(TestAtom&&) = default; 265 | 266 | private: 267 | friend class Logger; 268 | 269 | TestAtom(bool started, const std::string& name, const std::string& cmdline) 270 | : mStarted(started) 271 | , mName(name) 272 | , mCmdline(cmdline) 273 | { 274 | } 275 | 276 | bool mStarted; 277 | std::string mName; 278 | std::string mCmdline; 279 | }; 280 | 281 | //! 
282 | //! \brief Define a test for logging 283 | //! 284 | //! \param[in] name The name of the test. This should be a string starting with 285 | //! "TensorRT" and containing dot-separated strings containing 286 | //! the characters [A-Za-z0-9_]. 287 | //! For example, "TensorRT.sample_googlenet" 288 | //! \param[in] cmdline The command line used to reproduce the test 289 | // 290 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 291 | //! 292 | static TestAtom defineTest(const std::string& name, const std::string& cmdline) 293 | { 294 | return TestAtom(false, name, cmdline); 295 | } 296 | 297 | //! 298 | //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments 299 | //! as input 300 | //! 301 | //! \param[in] name The name of the test 302 | //! \param[in] argc The number of command-line arguments 303 | //! \param[in] argv The array of command-line arguments (given as C strings) 304 | //! 305 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 306 | static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) 307 | { 308 | auto cmdline = genCmdlineString(argc, argv); 309 | return defineTest(name, cmdline); 310 | } 311 | 312 | //! 313 | //! \brief Report that a test has started. 314 | //! 315 | //! \pre reportTestStart() has not been called yet for the given testAtom 316 | //! 317 | //! \param[in] testAtom The handle to the test that has started 318 | //! 319 | static void reportTestStart(TestAtom& testAtom) 320 | { 321 | reportTestResult(testAtom, TestResult::kRUNNING); 322 | assert(!testAtom.mStarted); 323 | testAtom.mStarted = true; 324 | } 325 | 326 | //! 327 | //! \brief Report that a test has ended. 328 | //! 329 | //! \pre reportTestStart() has been called for the given testAtom 330 | //! 331 | //! \param[in] testAtom The handle to the test that has ended 332 | //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, 333 | //! TestResult::kFAILED, TestResult::kWAIVED 334 | //! 335 | static void reportTestEnd(const TestAtom& testAtom, TestResult result) 336 | { 337 | assert(result != TestResult::kRUNNING); 338 | assert(testAtom.mStarted); 339 | reportTestResult(testAtom, result); 340 | } 341 | 342 | static int reportPass(const TestAtom& testAtom) 343 | { 344 | reportTestEnd(testAtom, TestResult::kPASSED); 345 | return EXIT_SUCCESS; 346 | } 347 | 348 | static int reportFail(const TestAtom& testAtom) 349 | { 350 | reportTestEnd(testAtom, TestResult::kFAILED); 351 | return EXIT_FAILURE; 352 | } 353 | 354 | static int reportWaive(const TestAtom& testAtom) 355 | { 356 | reportTestEnd(testAtom, TestResult::kWAIVED); 357 | return EXIT_SUCCESS; 358 | } 359 | 360 | static int reportTest(const TestAtom& testAtom, bool pass) 361 | { 362 | return pass ? reportPass(testAtom) : reportFail(testAtom); 363 | } 364 | 365 | Severity getReportableSeverity() const 366 | { 367 | return mReportableSeverity; 368 | } 369 | 370 | private: 371 | //! 372 | //! \brief returns an appropriate string for prefixing a log message with the given severity 373 | //! 374 | static const char* severityPrefix(Severity severity) 375 | { 376 | switch (severity) 377 | { 378 | case Severity::kINTERNAL_ERROR: return "[F] "; 379 | case Severity::kERROR: return "[E] "; 380 | case Severity::kWARNING: return "[W] "; 381 | case Severity::kINFO: return "[I] "; 382 | case Severity::kVERBOSE: return "[V] "; 383 | default: assert(0); return ""; 384 | } 385 | } 386 | 387 | //! 
388 | //! \brief returns an appropriate string for prefixing a test result message with the given result 389 | //! 390 | static const char* testResultString(TestResult result) 391 | { 392 | switch (result) 393 | { 394 | case TestResult::kRUNNING: return "RUNNING"; 395 | case TestResult::kPASSED: return "PASSED"; 396 | case TestResult::kFAILED: return "FAILED"; 397 | case TestResult::kWAIVED: return "WAIVED"; 398 | default: assert(0); return ""; 399 | } 400 | } 401 | 402 | //! 403 | //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity 404 | //! 405 | static std::ostream& severityOstream(Severity severity) 406 | { 407 | return severity >= Severity::kINFO ? std::cout : std::cerr; 408 | } 409 | 410 | //! 411 | //! \brief method that implements logging test results 412 | //! 413 | static void reportTestResult(const TestAtom& testAtom, TestResult result) 414 | { 415 | severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " 416 | << testAtom.mCmdline << std::endl; 417 | } 418 | 419 | //! 420 | //! \brief generate a command line string from the given (argc, argv) values 421 | //! 422 | static std::string genCmdlineString(int argc, char const* const* argv) 423 | { 424 | std::stringstream ss; 425 | for (int i = 0; i < argc; i++) 426 | { 427 | if (i > 0) 428 | ss << " "; 429 | ss << argv[i]; 430 | } 431 | return ss.str(); 432 | } 433 | 434 | Severity mReportableSeverity; 435 | }; 436 | 437 | namespace 438 | { 439 | 440 | //! 441 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE 442 | //! 443 | //! Example usage: 444 | //! 445 | //! LOG_VERBOSE(logger) << "hello world" << std::endl; 446 | //! 447 | inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) 448 | { 449 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); 450 | } 451 | 452 | //! 453 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO 454 | //! 455 | //! Example usage: 456 | //! 457 | //! LOG_INFO(logger) << "hello world" << std::endl; 458 | //! 459 | inline LogStreamConsumer LOG_INFO(const Logger& logger) 460 | { 461 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); 462 | } 463 | 464 | //! 465 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING 466 | //! 467 | //! Example usage: 468 | //! 469 | //! LOG_WARN(logger) << "hello world" << std::endl; 470 | //! 471 | inline LogStreamConsumer LOG_WARN(const Logger& logger) 472 | { 473 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); 474 | } 475 | 476 | //! 477 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR 478 | //! 479 | //! Example usage: 480 | //! 481 | //! LOG_ERROR(logger) << "hello world" << std::endl; 482 | //! 483 | inline LogStreamConsumer LOG_ERROR(const Logger& logger) 484 | { 485 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); 486 | } 487 | 488 | //! 489 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR 490 | // ("fatal" severity) 491 | //! 492 | //! Example usage: 493 | //! 494 | //! LOG_FATAL(logger) << "hello world" << std::endl; 495 | //! 
496 | inline LogStreamConsumer LOG_FATAL(const Logger& logger)
497 | {
498 |     return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
499 | }
500 | 
501 | } // anonymous namespace
502 | 
503 | #endif // TENSORRT_LOGGING_H
504 | 
--------------------------------------------------------------------------------
/results/Screenshot from 2021-04-21 19-25-48.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.TensorRT/f3dbdb77d9a5fb412dd42dcc0acd56a4c75dd0d4/results/Screenshot from 2021-04-21 19-25-48.png
--------------------------------------------------------------------------------
/results/Screenshot from 2021-04-21 19-26-08.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.TensorRT/f3dbdb77d9a5fb412dd42dcc0acd56a4c75dd0d4/results/Screenshot from 2021-04-21 19-26-08.png
--------------------------------------------------------------------------------
/results/result_mainz_000001_009328_leftImg8bit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/midasklr/DDRNet.TensorRT/f3dbdb77d9a5fb412dd42dcc0acd56a4c75dd0d4/results/result_mainz_000001_009328_leftImg8bit.png
--------------------------------------------------------------------------------
/utils.h:
--------------------------------------------------------------------------------
1 | #ifndef __TRT_UTILS_H_
2 | #define __TRT_UTILS_H_
3 | 
4 | #include <dirent.h>
5 | #include <cstring>
6 | #include <iostream>
7 | #include <string>
8 | #include <vector>
9 | #include <cuda_runtime_api.h>
10 | #include <opencv2/opencv.hpp>
11 | #ifndef CUDA_CHECK
12 | 
13 | #define CUDA_CHECK(callstr) \
14 |     { \
15 |         cudaError_t error_code = callstr; \
16 |         if (error_code != cudaSuccess) { \
17 |             std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \
18 |             assert(0); \
19 |         } \
20 |     }
21 | 
22 | #endif
23 | 
24 | namespace Tn
25 | {
26 |     template<typename T>
27 |     void write(char*& buffer, const T& val)
28 |     {
29 |         *reinterpret_cast<T*>(buffer) = val;
30 |         buffer += sizeof(T);
31 |     }
32 | 
33 |     template<typename T>
34 |     void read(const char*& buffer, T& val)
35 |     {
36 |         val = *reinterpret_cast<const T*>(buffer);
37 |         buffer += sizeof(T);
38 |     }
39 | }
40 | 
41 | static inline cv::Mat preprocess_img(cv::Mat& img, int input_w, int input_h) {
42 |     int w, h, x, y;
43 |     float r_w = input_w / (img.cols*1.0);
44 |     float r_h = input_h / (img.rows*1.0);
45 |     if (r_h > r_w) {
46 |         w = input_w;
47 |         h = r_w * img.rows;
48 |         x = 0;
49 |         y = (input_h - h) / 2;
50 |     } else {
51 |         w = r_h * img.cols;
52 |         h = input_h;
53 |         x = (input_w - w) / 2;
54 |         y = 0;
55 |     }
56 |     cv::Mat re(h, w, CV_8UC3);
57 |     cv::resize(img, re, re.size(), 0, 0, cv::INTER_LINEAR);
58 |     cv::Mat out(input_h, input_w, CV_8UC3, cv::Scalar(128, 128, 128));
59 |     re.copyTo(out(cv::Rect(x, y, re.cols, re.rows)));
60 |     return out;
61 | }
62 | 
63 | static inline int read_files_in_dir(const char *p_dir_name, std::vector<std::string> &file_names) {
64 |     DIR *p_dir = opendir(p_dir_name);
65 |     if (p_dir == nullptr) {
66 |         return -1;
67 |     }
68 | 
69 |     struct dirent* p_file = nullptr;
70 |     while ((p_file = readdir(p_dir)) != nullptr) {
71 |         if (strcmp(p_file->d_name, ".") != 0 &&
72 |             strcmp(p_file->d_name, "..") != 0) {
73 |             //std::string cur_file_name(p_dir_name);
74 |             //cur_file_name += "/";
75 |             //cur_file_name += p_file->d_name;
76 |             std::string cur_file_name(p_file->d_name);
77 |             file_names.push_back(cur_file_name);
78 |         }
79 |     }
80 | 
81 |     closedir(p_dir);
82 |     return 0;
83 | }
84 | 
85 | #endif 86 | --------------------------------------------------------------------------------
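
The two headers above are consumed by ddrnet.cpp, which is not reproduced in this listing. The sketch below is a minimal, hypothetical driver (not a file from this repository) showing how the pieces fit together: `read_files_in_dir()` gathers the image names, `preprocess_img()` letterboxes each frame to the (3,1024,1024) input size quoted in the README's FPS table, and the `Logger` from logging.h reports progress. The file name `demo_utils.cpp`, the 1024x1024 constants, and the output handling are assumptions; the actual TensorRT enqueue step is only indicated by a comment, and the build needs the same TensorRT/OpenCV include and library paths as CMakeLists.txt.

```cpp
// demo_utils.cpp -- standalone sketch, not the repository's ddrnet.cpp.
// Exercises Logger (logging.h) plus read_files_in_dir() and preprocess_img() (utils.h).
#include <iostream>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>

#include "logging.h"
#include "utils.h"

static Logger gLogger(Severity::kINFO);

int main(int argc, char** argv)
{
    if (argc != 2) {
        std::cerr << "usage: ./demo_utils <image_dir>" << std::endl;
        return -1;
    }

    // Collect the file names in the directory, as ./ddrnet -d ../images would.
    std::vector<std::string> file_names;
    if (read_files_in_dir(argv[1], file_names) < 0) {
        gLogger.log(Severity::kERROR, "could not open image directory");
        return -1;
    }

    const int kInputW = 1024;  // assumed network input width
    const int kInputH = 1024;  // assumed network input height

    for (const auto& name : file_names) {
        cv::Mat img = cv::imread(std::string(argv[1]) + "/" + name);
        if (img.empty()) {
            gLogger.log(Severity::kWARNING, ("skipping unreadable file: " + name).c_str());
            continue;
        }
        // Letterbox to the network resolution; gray (128,128,128) padding keeps the aspect ratio.
        cv::Mat input = preprocess_img(img, kInputW, kInputH);
        gLogger.log(Severity::kINFO, ("preprocessed " + name).c_str());
        // ... copy `input` into the TensorRT input buffer and enqueue the engine here ...
    }
    return 0;
}
```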
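
utils.h also ships the `Tn::write()` / `Tn::read()` helpers, which copy a trivially-copyable value into a raw `char` buffer and advance the pointer by `sizeof(T)`, so fixed-size metadata can be packed next to a raw blob during serialization. The snippet below is a self-contained illustration only; the packed fields (width, height, scale) are invented for the example and are not what DDRNet.engine actually stores.

```cpp
// tn_pack_demo.cpp -- standalone illustration of Tn::write / Tn::read (utils.h).
#include <cassert>
#include <vector>

#include "utils.h"

int main()
{
    // Room for two ints and one float.
    std::vector<char> blob(2 * sizeof(int) + sizeof(float));

    // Pack: width, height, scale (example values only).
    char* wp = blob.data();
    Tn::write(wp, 1024);      // int
    Tn::write(wp, 1024);      // int
    Tn::write(wp, 0.0039f);   // float, roughly 1/255

    // Unpack in the same order; each read advances the pointer by sizeof(T).
    const char* rp = blob.data();
    int w = 0, h = 0;
    float scale = 0.f;
    Tn::read(rp, w);
    Tn::read(rp, h);
    Tn::read(rp, scale);

    assert(w == 1024 && h == 1024);
    return 0;
}
```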