├── CMakeLists.txt
├── README.md
├── calibrator.cpp
├── calibrator.h
├── common.hpp
├── ddrnet.cpp
├── getwts.py
├── images
└── mainz_000001_009328_leftImg8bit.png
├── logging.h
├── results
├── Screenshot from 2021-04-21 19-25-48.png
├── Screenshot from 2021-04-21 19-26-08.png
└── result_mainz_000001_009328_leftImg8bit.png
└── utils.h
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 2.6)
2 | # Build script for the ddrnet executable (CUDA + TensorRT + OpenCV).
3 | project(DDRNet)
4 | # Request C++11 through a raw compiler flag; CMAKE_CXX_STANDARD below asks for the same standard.
5 | add_definitions(-std=c++11)
6 | # NOTE(review): option() expects (name, help text, value); here OFF sits in the help-text slot - confirm the intent.
7 | option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
8 | set(CMAKE_CXX_STANDARD 11)
9 | set(CMAKE_BUILD_TYPE Debug)
10 | # CUDA: headers and libraries are taken from the fixed /usr/local/cuda prefix.
11 | find_package(CUDA REQUIRED)
12 | include_directories(/usr/local/cuda/include)
13 | link_directories(/usr/local/cuda/lib64)
14 |
15 | include_directories(${PROJECT_SOURCE_DIR}/include)
16 | # TensorRT: machine-specific absolute paths; edit the next two lines to match the local install.
17 | include_directories(/home/midas/TensorRT-7.0.0.11/include/)
18 | link_directories(/home/midas/TensorRT-7.0.0.11/lib/)
19 | # OpenCV: version 3.4.8 is requested.
20 | find_package(OpenCV 3.4.8 REQUIRED)
21 | include_directories(${OpenCV_INCLUDE_DIRS})
22 | # Only ddrnet.cpp is listed as a source. NOTE(review): calibrator.cpp is not listed - confirm how it reaches the build.
23 | add_executable(ddrnet ${PROJECT_SOURCE_DIR}/ddrnet.cpp)
24 | target_link_libraries(ddrnet nvinfer)
25 | target_link_libraries(ddrnet cudart)
26 | target_link_libraries(ddrnet ${OpenCV_LIBS})
27 | # NOTE(review): these flags are added after the target is declared, and CMAKE_BUILD_TYPE above is Debug - confirm the intended flags.
28 | add_definitions(-O2 -pthread)
29 |
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DDRNet
2 |
3 | TensorRT implementation of the official [DDRNet](https://github.com/ydhongHIT/DDRNet)
4 |
5 |
6 |
7 |
8 |
9 | [DDRNet-23-slim](https://paperswithcode.com/paper/deep-dual-resolution-networks-for-real-time) outperforms other lightweight segmentation methods; see the [Cityscapes real-time leaderboard](https://paperswithcode.com/sota/real-time-semantic-segmentation-on-cityscapes).
10 |
11 |
12 |
13 |
14 |
15 |
16 | ## Compile&Run
17 |
18 | * 1. get model.wts
19 |
20 | Convert the PyTorch model to a wts model using getwts.py, or download the wts [model](https://pan.baidu.com/s/1Cm1A2mq6RxCFhUJrOJBSrw) (password: p6hy) converted from the official implementation.
21 |
22 | Note that we do not use the extra segmentation head during inference (set augment=False in https://github.com/ydhongHIT/DDRNet/blob/76a875084afdc7dedd20e2c2bdc0a93f8f481e81/segmentation/DDRNet_23_slim.py#L345).
23 |
24 | * 2. cmake and make
25 |
26 | Configure your CMakeLists.txt (e.g. the TensorRT include/lib paths), then run:
27 |
28 | ```
29 | mkdir build
30 | cd build
31 | cmake ..
32 | make -j8
33 | ./ddrnet -s // serialize model to plan file i.e. 'DDRNet.engine'
34 | ./ddrnet -d ../images // deserialize plan file and run inference; the images in the given folder will be processed.
35 | ```
36 |
37 | for INT8 support:
38 |
39 | ```
40 | #define USE_INT8 // keep this defined to use INT8
41 | //#define USE_FP16 // comment out this if want to use FP32
42 | ```
43 |
44 | Create a folder named "calib" and put around 1k images (Cityscapes val/test images) into it.
45 |
46 | ## FPS
47 |
48 | Test on RTX2070
49 |
50 | | model | input | FPS |
51 | | -------------- | --------------- | ---- |
52 | | Pytorch-aug | (3,1024,1024) | 107 |
53 | | Pytorch-no-aug | (3,1024,1024) | 108 |
54 | | TensorRT-FP32 | (3,1024,1024) | 117 |
55 | | TensorRT-FP16 | (3,1024,1024) | 215 |
56 | | TensorRT-INT8 | (3,1024,1024) | 232 |
57 |
58 | Pytorch-aug means augment=True.
59 |
60 | ## Difference with official
61 |
62 | We use Upsample with "nearest" rather than "bilinear", which may lead to lower accuracy.
63 |
64 | Fine-tuning with "nearest" upsampling may recover the accuracy.
65 |
66 | Here we convert from the official model directly.
67 |
68 | ## Train
69 |
70 | 1. refer to:https://github.com/chenjun2hao/DDRNet.pytorch
71 | 2. generate wts model with getwts.py
72 |
73 | ## Train custom data
74 | https://github.com/midasklr/DDRNet.Pytorch
75 | Write your own dataset and fine-tune the model trained on Cityscapes.
76 |
--------------------------------------------------------------------------------
/calibrator.cpp:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include <iterator>
3 | #include <fstream>
4 | #include <opencv2/dnn/dnn.hpp>
5 | #include "calibrator.h"
6 | #include "cuda_runtime_api.h"
7 | #include "utils.h"
8 | // Builds the calibrator: records the batch geometry and names, allocates the device input buffer, and lists the files found in img_dir.
9 | Int8EntropyCalibrator2::Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache)
10 |     : batchsize_(batchsize),
11 |       input_w_(input_w),
12 |       input_h_(input_h),
13 |       img_idx_(0),
14 |       img_dir_(img_dir),
15 |       input_count_(3 * input_w * input_h * batchsize),  // float elements in one batch: 3 * width * height per image
16 |       calib_table_name_(calib_table_name),
17 |       input_blob_name_(input_blob_name),
18 |       read_cache_(read_cache)
19 | {
20 |     CUDA_CHECK(cudaMalloc(&device_input_, input_count_ * sizeof(float)));
21 |     read_files_in_dir(img_dir, img_files_);
22 | }
23 | // Releases the device input buffer that the constructor allocated with cudaMalloc.
24 | Int8EntropyCalibrator2::~Int8EntropyCalibrator2()
25 | {
26 | CUDA_CHECK(cudaFree(device_input_));
27 | }
28 | // Returns the batch size that was passed to the constructor.
29 | int Int8EntropyCalibrator2::getBatchSize() const
30 | {
31 | return batchsize_;
32 | }
33 | // Supplies the next calibration batch: loads batchsize_ images, preprocesses them, copies the blob to the device buffer and stores that buffer in bindings[0]. Returns false when fewer than batchsize_ images remain or an image cannot be read.
34 | bool Int8EntropyCalibrator2::getBatch(void* bindings[], const char* names[], int nbBindings)
35 | {
36 |     if (img_idx_ + batchsize_ > static_cast<int>(img_files_.size())) {
37 |         return false;
38 |     }
39 |     // Preprocessed images of this batch, in file order.
40 |     std::vector<cv::Mat> input_imgs_;
41 |     for (int i = img_idx_; i < img_idx_ + batchsize_; i++) {
42 |         std::cout << img_files_[i] << " " << i << std::endl;
43 |         cv::Mat temp = cv::imread(img_dir_ + img_files_[i]);  // plain concatenation; presumably img_dir_ must end with a path separator - confirm against the caller
44 |         if (temp.empty()){
45 |             std::cerr << "Fatal error: image cannot open!" << std::endl;
46 |             return false;
47 |         }
48 |         cv::Mat pr_img = preprocess_img(temp, input_w_, input_h_);
49 |         input_imgs_.push_back(pr_img);
50 |     }
51 |     img_idx_ += batchsize_;
52 |     cv::Mat blob = cv::dnn::blobFromImages(input_imgs_, 1.0 / 57.3750, cv::Size(input_w_, input_h_), cv::Scalar(1.80444, 2.0267, 2.1555), true, false);
53 |     // NOTE(review): blobFromImages subtracts the mean before applying the scale factor; the means above look like mean/std ratios rather than 0-255 pixel means - confirm they match the inference-time preprocessing.
54 |     CUDA_CHECK(cudaMemcpy(device_input_, blob.ptr(0), input_count_ * sizeof(float), cudaMemcpyHostToDevice));
55 |     assert(!strcmp(names[0], input_blob_name_));
56 |     bindings[0] = device_input_;
57 |     return true;
58 | }
59 | // Loads the calibration table file. Sets length to the number of bytes read and returns a pointer to them, or nullptr when reading is disabled, the file cannot be opened, or it is empty.
60 | const void* Int8EntropyCalibrator2::readCalibrationCache(size_t& length)
61 | {
62 |     std::cout << "reading calib cache: " << calib_table_name_ << std::endl;
63 |     calib_cache_.clear();
64 |     std::ifstream input(calib_table_name_, std::ios::binary);
65 |     input >> std::noskipws;  // keep whitespace bytes so the table is copied verbatim
66 |     if (read_cache_ && input.good())
67 |     {
68 |         std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(), std::back_inserter(calib_cache_));
69 |     }
70 |     length = calib_cache_.size();
71 |     return length ? calib_cache_.data() : nullptr;
72 | }
73 | // Writes length bytes starting at cache to the file named calib_table_name_, opened in binary mode.
74 | void Int8EntropyCalibrator2::writeCalibrationCache(const void* cache, size_t length)
75 | {
76 |     std::cout << "writing calib cache: " << calib_table_name_ << " size: " << length << std::endl;
77 |     std::ofstream output(calib_table_name_, std::ios::binary);
78 |     output.write(reinterpret_cast<const char*>(cache), length);
79 | }
80 |
81 |
--------------------------------------------------------------------------------
/calibrator.h:
--------------------------------------------------------------------------------
1 | #ifndef ENTROPY_CALIBRATOR_H
2 | #define ENTROPY_CALIBRATOR_H
3 |
4 | #include "NvInfer.h"
5 | #include <string>
6 | #include <vector>
7 |
8 | //! \class Int8EntropyCalibrator2
9 | //!
10 | //! \brief Implements Entropy calibrator 2.
11 | //! CalibrationAlgoType is kENTROPY_CALIBRATION_2.
12 | //! Serves batches of images read from a directory and reads/writes the calibration table file.
13 | class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2
14 | {
15 | public:
16 |     Int8EntropyCalibrator2(int batchsize, int input_w, int input_h, const char* img_dir, const char* calib_table_name, const char* input_blob_name, bool read_cache = true);  //!< read_cache = false ignores an existing table file
17 |
18 |     virtual ~Int8EntropyCalibrator2();
19 |     int getBatchSize() const override;
20 |     bool getBatch(void* bindings[], const char* names[], int nbBindings) override;
21 |     const void* readCalibrationCache(size_t& length) override;
22 |     void writeCalibrationCache(const void* cache, size_t length) override;
23 |
24 | private:
25 |     int batchsize_;                       //!< images per calibration batch
26 |     int input_w_;                         //!< input width handed to preprocessing
27 |     int input_h_;                         //!< input height handed to preprocessing
28 |     int img_idx_;                         //!< index of the next image to serve from img_files_
29 |     std::string img_dir_;                 //!< prefix prepended to every entry of img_files_
30 |     std::vector<std::string> img_files_;  //!< calibration image file names
31 |     size_t input_count_;                  //!< number of float elements in one batch
32 |     std::string calib_table_name_;        //!< path of the calibration table file
33 |     const char* input_blob_name_;         //!< stored pointer, not a copy: the caller's string must outlive this object
34 |     bool read_cache_;                     //!< whether readCalibrationCache may load an existing table
35 |     void* device_input_;                  //!< device buffer holding one batch
36 |     std::vector<char> calib_cache_;       //!< bytes of the table loaded by readCalibrationCache
37 | };
38 |
39 | #endif // ENTROPY_CALIBRATOR_H
40 |
--------------------------------------------------------------------------------
/common.hpp:
--------------------------------------------------------------------------------
1 | #ifndef DDRNET_COMMON_H_
2 | #define DDRNET_COMMON_H_
3 |
4 | #include
5 | #include
6 | #include